diff --git a/requirements.txt b/requirements.txt index c027ec57e07..32e3f725316 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ adal==0.4.3 applicationinsights==0.10.0 argcomplete==1.8.0 colorama==0.3.7 +humanfriendly==2.4 jmespath mock==1.3.0 paramiko==2.0.2 diff --git a/src/azure-cli-core/azure/cli/core/application.py b/src/azure-cli-core/azure/cli/core/application.py index 52dacd3b6b6..577449feba7 100644 --- a/src/azure-cli-core/azure/cli/core/application.py +++ b/src/azure-cli-core/azure/cli/core/application.py @@ -15,6 +15,7 @@ import azure.cli.core.azlogging as azlogging from azure.cli.core.util import todict, truncate_text, CLIError, read_file_content from azure.cli.core._config import az_config +import azure.cli.core.commands.progress as progress import azure.cli.core.telemetry as telemetry @@ -127,6 +128,11 @@ def __init__(self, configuration=None): self.parser = AzCliCommandParser(prog='az', parents=[self.global_parser]) self.configuration = configuration + self.progress_controller = progress.ProgressHook() + + def get_progress_controller(self): + self.progress_controller.init_progress(progress.get_progress_view()) + return self.progress_controller def initialize(self, configuration): self.configuration = configuration diff --git a/src/azure-cli-core/azure/cli/core/commands/__init__.py b/src/azure-cli-core/azure/cli/core/commands/__init__.py index daaf117fd4b..ebcbb3ca385 100644 --- a/src/azure-cli-core/azure/cli/core/commands/__init__.py +++ b/src/azure-cli-core/azure/cli/core/commands/__init__.py @@ -21,7 +21,6 @@ import azure.cli.core.azlogging as azlogging import azure.cli.core.telemetry as telemetry from azure.cli.core.util import CLIError -from azure.cli.core.application import APPLICATION from azure.cli.core.prompting import prompt_y_n, NoTTYException from azure.cli.core._config import az_config, DEFAULTS_SECTION from azure.cli.core.profiles import ResourceType, supported_api_version @@ -126,10 +125,13 @@ def __setattr__(self, name, value): class LongRunningOperation(object): # pylint: disable=too-few-public-methods - def __init__(self, start_msg='', finish_msg='', poller_done_interval_ms=1000.0): + def __init__(self, start_msg='', finish_msg='', + poller_done_interval_ms=1000.0, progress_controller=None): self.start_msg = start_msg self.finish_msg = finish_msg self.poller_done_interval_ms = poller_done_interval_ms + from azure.cli.core.application import APPLICATION + self.progress_controller = progress_controller or APPLICATION.get_progress_controller() def _delay(self): time.sleep(self.poller_done_interval_ms / 1000.0) @@ -138,7 +140,9 @@ def __call__(self, poller): from msrest.exceptions import ClientException logger.info("Starting long running operation '%s'", self.start_msg) correlation_message = '' + self.progress_controller.begin() while not poller.done(): + self.progress_controller.add(message='Running') try: # pylint: disable=protected-access correlation_id = json.loads( @@ -151,8 +155,10 @@ def __call__(self, poller): try: self._delay() except KeyboardInterrupt: + self.progress_controller.stop() logger.error('Long running operation wait cancelled. %s', correlation_message) raise + try: result = poller.result() except ClientException as client_exception: @@ -161,6 +167,7 @@ def __call__(self, poller): fault_type='failed-long-running-operation', summary='Unexpected client exception in {}.'.format(LongRunningOperation.__name__)) message = getattr(client_exception, 'message', client_exception) + self.progress_controller.stop() try: message = '{} {}'.format( @@ -176,6 +183,7 @@ def __call__(self, poller): logger.info("Long running operation '%s' completed with result %s", self.start_msg, result) + self.progress_controller.end() return result @@ -232,6 +240,8 @@ def __init__(self, name, handler, description=None, table_transformer=None, @staticmethod def _should_load_description(): + from azure.cli.core.application import APPLICATION + return not APPLICATION.session['completer_active'] def load_arguments(self): diff --git a/src/azure-cli-core/azure/cli/core/commands/progress.py b/src/azure-cli-core/azure/cli/core/commands/progress.py new file mode 100644 index 00000000000..b2e642c348d --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/commands/progress.py @@ -0,0 +1,142 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from __future__ import division +import sys + +import humanfriendly + +BAR_LEN = 70 + + +class ProgressViewBase(object): + """ a view base for progress reporting """ + def __init__(self, out): + self.out = out + + def write(self, args): + """ writes the progress """ + raise NotImplementedError + + def flush(self): + """ flushes the message out the pipeline""" + self.out.flush() + + +class ProgressReporter(object): + """ generic progress reporter """ + def __init__(self, message='', value=None, total_value=None): + self.message = message + self.value = value + self.total_val = total_value + self.closed = False + + def add(self, **kwargs): + """ + adds a progress report + :param kwargs: dictionary containing 'message', 'total_val', 'value' + """ + message = kwargs.get('message', self.message) + total_val = kwargs.get('total_val', self.total_val) + value = kwargs.get('value', self.value) + if value and total_val: + assert value >= 0 and value <= total_val and total_val >= 0 + self.closed = value == total_val + self.total_val = total_val + self.value = value + self.message = message + + def report(self): + """ report the progress """ + percent = self.value / self.total_val if self.value is not None and self.total_val else None + return {'message': self.message, 'percent': percent} + + +class ProgressHook(object): + """ sends the progress to the view """ + def __init__(self): + self.reporter = ProgressReporter() + self.active_progress = None + + def init_progress(self, progress_view): + """ activate a view """ + self.active_progress = progress_view + + def add(self, **kwargs): + """ adds a progress report """ + self.reporter.add(**kwargs) + self.update() + + def update(self): + """ updates the view with the progress """ + self.active_progress.write(self.reporter.report()) + self.active_progress.flush() + + def stop(self): + """ if there is an abupt stop before ending """ + self.add(message='Interrupted') + + def begin(self, **kwargs): + """ start reporting progress """ + kwargs['message'] = kwargs.get('message', 'Starting') + self.add(**kwargs) + + def end(self, **kwargs): + """ ending reporting of progress """ + kwargs['message'] = kwargs.get('message', 'Finished') + self.add(**kwargs) + + +class IndeterminateStandardOut(ProgressViewBase): + """ custom output for progress reporting """ + def __init__(self, out=None): + super(IndeterminateStandardOut, self).__init__( + out if out else sys.stderr) + self.spinner = humanfriendly.Spinner(label='In Progress', stream=self.out) + self.spinner.hide_cursor = False + + def write(self, args): + """ + writes the progress + :param args: dictionary containing key 'message' + """ + msg = args.get('message', 'In Progress') + self.spinner.step(label=msg) + + +def _format_value(msg, percent): + bar_len = BAR_LEN - len(msg) - 1 + completed = int(bar_len * percent) + + message = '\r{}['.format(msg) + message += ('#' * completed).ljust(bar_len) + message += '] {:.4%}'.format(percent) + return message + + +class DeterminateStandardOut(ProgressViewBase): + """ custom output for progress reporting """ + def __init__(self, out=None): + super(DeterminateStandardOut, self).__init__(out if out else sys.stderr) + + def write(self, args): + """ + writes the progress + :param args: args is a dictionary containing 'percent', 'message' + """ + percent = args.get('percent', 0) + message = args.get('message', '') + + if percent: + percent = percent + progress = _format_value(message, percent) + self.out.write(progress) + + +def get_progress_view(determinant=False, outstream=sys.stderr): + """ gets your view """ + if determinant: + return DeterminateStandardOut(out=outstream) + else: + return IndeterminateStandardOut(out=outstream) diff --git a/src/azure-cli-core/azure/cli/core/test_utils/vcr_test_base.py b/src/azure-cli-core/azure/cli/core/test_utils/vcr_test_base.py index f54373e1674..2e9303f0336 100644 --- a/src/azure-cli-core/azure/cli/core/test_utils/vcr_test_base.py +++ b/src/azure-cli-core/azure/cli/core/test_utils/vcr_test_base.py @@ -117,6 +117,22 @@ def _mock_operation_delay(_): return +class _MockOutstream(object): + """ mock outstream for testing """ + def __init__(self): + self.string = '' + + def write(self, message): + self.string = message + + def flush(self): + pass + + +def _mock_get_progress_view(determinant=False, out=None): # pylint: disable=unused-argument + return _MockOutstream() + + # TEST CHECKS @@ -373,6 +389,7 @@ def _execute_live_or_recording(self): if callable(tear_down) and not self.skip_teardown: self.tear_down() + @mock.patch('azure.cli.core.commands.progress.get_progress_view', _mock_get_progress_view) @mock.patch('azure.cli.core._profile.Profile.load_cached_subscriptions', _mock_subscriptions) @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', _mock_user_access_token) # pylint: disable=line-too-long diff --git a/src/azure-cli-core/setup.py b/src/azure-cli-core/setup.py index 3d83cbff86b..5161101267d 100644 --- a/src/azure-cli-core/setup.py +++ b/src/azure-cli-core/setup.py @@ -55,6 +55,7 @@ 'applicationinsights', 'argcomplete>=1.8.0', 'colorama', + 'humanfriendly', 'jmespath', 'msrest>=0.4.4', 'msrestazure>=0.4.7', diff --git a/src/azure-cli-core/tests/test_progress.py b/src/azure-cli-core/tests/test_progress.py new file mode 100644 index 00000000000..8a3b3c5356b --- /dev/null +++ b/src/azure-cli-core/tests/test_progress.py @@ -0,0 +1,99 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import unittest + +import azure.cli.core.commands.progress as progress + + +class MockOutstream(progress.ProgressViewBase): + """ mock outstream for testing """ + def __init__(self): + self.string = '' + + def write(self, message): + self.string = message + + def flush(self): + pass + + +class TestProgress(unittest.TestCase): # pylint: disable=too-many-public-methods + """ test the progress reporting """ + + def test_progress_indicator_det_model(self): + """ test the progress reporter """ + reporter = progress.ProgressReporter() + args = reporter.report() + self.assertEqual(args['message'], '') + self.assertEqual(args['percent'], None) + + reporter.add(message='Progress', total_val=10, value=0) + self.assertEqual(reporter.message, 'Progress') + self.assertEqual(reporter.value, 0) + self.assertEqual(reporter.total_val, 10) + args = reporter.report() + self.assertEqual(args['message'], 'Progress') + self.assertEqual(args['percent'], 0) + + with self.assertRaises(AssertionError): + reporter.add(message='In words', total_val=-1, value=10) + with self.assertRaises(AssertionError): + reporter.add(message='In words', total_val=1, value=-10) + with self.assertRaises(AssertionError): + reporter.add(message='In words', total_val=30, value=12340) + + reporter = progress.ProgressReporter() + message = reporter.report() + self.assertEqual(message['message'], '') + + reporter.add(message='Progress') + self.assertEqual(reporter.message, 'Progress') + + message = reporter.report() + self.assertEqual(message['message'], 'Progress') + + def test_progress_indicator_indet_stdview(self): + """ tests the indeterminate progress standardout view """ + outstream = MockOutstream() + view = progress.IndeterminateStandardOut(out=outstream) + before = view.spinner.total + self.assertEqual(view.spinner.label, 'In Progress') + view.write({}) + after = view.spinner.total + self.assertTrue(after >= before) + view.write({'message': 'TESTING'}) + + def test_progress_indicator_det_stdview(self): + """ test the determinate progress standardout view """ + outstream = MockOutstream() + view = progress.DeterminateStandardOut(out=outstream) + view.write({'message': 'hihi', 'percent': .5}) + # 95 length, 48 complete, 4 dec percent + bar_str = ('#' * int(.5 * 65)).ljust(65) + self.assertEqual(outstream.string, '\rhihi[{}] {:.4%}'.format(bar_str, .5)) + + view.write({'message': '', 'percent': .9}) + # 99 length, 90 complete, 4 dec percent + bar_str = ('#' * int(.9 * 69)).ljust(69) + self.assertEqual(outstream.string, '\r[{}] {:.4%}'.format(bar_str, .9)) + + def test_progress_indicator_controller(self): + """ tests the controller for progress reporting """ + controller = progress.ProgressHook() + view = MockOutstream() + + controller.init_progress(view) + self.assertTrue(view == controller.active_progress) + + controller.begin() + + self.assertEqual(controller.active_progress.string['message'], 'Starting') + + controller.end() + self.assertEqual(controller.active_progress.string['message'], 'Finished') + + +if __name__ == '__main__': + unittest.main() diff --git a/src/azure-cli-testsdk/azure/cli/testsdk/base.py b/src/azure-cli-testsdk/azure/cli/testsdk/base.py index dfc843d4a6d..add4c8fa98e 100644 --- a/src/azure-cli-testsdk/azure/cli/testsdk/base.py +++ b/src/azure-cli-testsdk/azure/cli/testsdk/base.py @@ -19,7 +19,7 @@ from .patches import (patch_load_cached_subscriptions, patch_main_exception_handler, patch_retrieve_token_for_user, patch_long_run_operation_delay, - patch_time_sleep_api) + patch_time_sleep_api, patch_progress_controller) from .exceptions import CliExecutionError from .const import (ENV_LIVE_TEST, ENV_SKIP_ASSERT, ENV_TEST_DIAGNOSE, MOCKED_SUBSCRIPTION_ID) from .recording_processors import (SubscriptionRecordingProcessor, OAuthRequestResponsesFilter, @@ -160,6 +160,7 @@ def setUp(self): patch_long_run_operation_delay(self) patch_load_cached_subscriptions(self) patch_retrieve_token_for_user(self) + patch_progress_controller(self) def tearDown(self): os.environ = self.original_env diff --git a/src/azure-cli-testsdk/azure/cli/testsdk/patches.py b/src/azure-cli-testsdk/azure/cli/testsdk/patches.py index 54a351aa3df..a30d653f281 100644 --- a/src/azure-cli-testsdk/azure/cli/testsdk/patches.py +++ b/src/azure-cli-testsdk/azure/cli/testsdk/patches.py @@ -7,6 +7,19 @@ from .const import MOCKED_SUBSCRIPTION_ID, MOCKED_TENANT_ID +def patch_progress_controller(unit_test): + def _handle_progress_update(*args): # pylint: disable=unused-argument + pass + + def _handle_progress_add(*args, **kwargs): # pylint: disable=unused-argument + pass + + _mock_in_unit_test( + unit_test, 'azure.cli.core.commands.progress.ProgressHook.update', _handle_progress_update) + _mock_in_unit_test( + unit_test, 'azure.cli.core.commands.progress.ProgressHook.add', _handle_progress_add) + + def patch_main_exception_handler(unit_test): from vcr.errors import CannotOverwriteExistingCassetteException diff --git a/src/command_modules/azure-cli-shell/HISTORY.rst b/src/command_modules/azure-cli-shell/HISTORY.rst index 827c5800067..683ac7945d2 100644 --- a/src/command_modules/azure-cli-shell/HISTORY.rst +++ b/src/command_modules/azure-cli-shell/HISTORY.rst @@ -7,6 +7,8 @@ Release History ++++++++++++++++++ * Integrate shell into az +* Color options +* --progress flag 0.2.1 diff --git a/src/command_modules/azure-cli-shell/azclishell/app.py b/src/command_modules/azure-cli-shell/azclishell/app.py index 60b19f8c425..756cc07f47f 100644 --- a/src/command_modules/azure-cli-shell/azclishell/app.py +++ b/src/command_modules/azure-cli-shell/azclishell/app.py @@ -10,6 +10,7 @@ import os import subprocess import sys +import threading import jmespath from six.moves import configparser @@ -29,11 +30,13 @@ from azclishell.gather_commands import add_random_new_lines from azclishell.key_bindings import registry, get_section, sub_section from azclishell.layout import create_layout, create_tutorial_layout, set_scope +from azclishell.progress import get_progress_message, DONE_STR, progress_view from azclishell.telemetry import TC as telemetry from azclishell.util import get_window_dim, parse_quotes, get_os_clear_screen_word import azure.cli.core.azlogging as azlogging from azure.cli.core.application import Configuration +from azure.cli.core.commands import LongRunningOperation, get_op_handler from azure.cli.core.cloud import get_active_cloud_name from azure.cli.core._config import az_config, DEFAULTS_SECTION from azure.cli.core._environment import get_config_dir @@ -95,26 +98,6 @@ def space_examples(list_examples, rows): return example + page_number -def _toolbar_info(): - sub_name = "" - try: - sub_name = PROFILE.get_subscription()[_SUBSCRIPTION_NAME] - except CLIError: - pass - - curr_cloud = "Cloud: {}".format(get_active_cloud_name()) - tool_val = '{}'.format('Subscription: {}'.format(sub_name) if sub_name else curr_cloud) - - settings_items = [ - " [F1]Layout", - "[F2]Defaults", - "[F3]Keys", - "[Ctrl+D]Quit", - tool_val - ] - return settings_items - - def space_toolbar(settings_items, cols, empty_space): """ formats the toolbar """ counter = 0 @@ -157,6 +140,9 @@ def __init__(self, completer=None, styles=None, self.output = output_custom self.config_default = "" self.default_command = "" + self.threads = [] + self.curr_thread = None + self.spin_val = -1 @property def cli(self): @@ -170,13 +156,8 @@ def on_input_timeout(self, cli): """ brings up the metadata for the command if there is a valid command already typed """ - _, cols = get_window_dim() - cols = int(cols) document = cli.current_buffer.document text = document.text - empty_space = "" - for i in range(cols): # pylint: disable=unused-variable - empty_space += " " text = text.replace('az', '') if self.default_command: @@ -188,22 +169,52 @@ def on_input_timeout(self, cli): self.example_docs = u'{}'.format(example) self._update_default_info() - - settings, empty_space = space_toolbar(_toolbar_info(), cols, empty_space) - cli.buffers['description'].reset( initial_document=Document(self.description_docs, cursor_position=0)) cli.buffers['parameter'].reset( initial_document=Document(self.param_docs)) cli.buffers['examples'].reset( initial_document=Document(self.example_docs)) - cli.buffers['bottom_toolbar'].reset( - initial_document=Document(u'{}{}{}'.format(NOTIFICATIONS, settings, empty_space))) cli.buffers['default_values'].reset( initial_document=Document( u'{}'.format(self.config_default if self.config_default else 'No Default Values'))) + self._update_toolbar() cli.request_redraw() + def _update_toolbar(self): + cli = self.cli + _, cols = get_window_dim() + cols = int(cols) + + empty_space = "" + for _ in range(cols): + empty_space += " " + + settings = self._toolbar_info() + settings, empty_space = space_toolbar(settings, cols, empty_space) + cli.buffers['bottom_toolbar'].reset( + initial_document=Document(u'{}{}{}'.format(NOTIFICATIONS, settings, empty_space))) + + def _toolbar_info(self): + sub_name = "" + try: + sub_name = PROFILE.get_subscription()[_SUBSCRIPTION_NAME] + except CLIError: + pass + + curr_cloud = "Cloud: {}".format(get_active_cloud_name()) + tool_val = '{}'.format('Subscription: {}'.format(sub_name) if sub_name else curr_cloud) + + settings_items = [ + " [F1]Layout", + "[F2]Defaults", + "[F3]Keys", + "[Ctrl+D]Quit", + tool_val + # tool_val2 + ] + return settings_items + def generate_help_text(self, text): """ generates the help text based on commands typed """ command = param_descrip = example = "" @@ -267,7 +278,8 @@ def create_application(self, full_layout=True): 'bottom_toolbar': Buffer(is_multiline=True), 'example_line': Buffer(is_multiline=True), 'default_values': Buffer(), - 'symbols': Buffer() + 'symbols': Buffer(), + 'progress': Buffer(is_multiline=False) } writing_buffer = Buffer( @@ -394,7 +406,7 @@ def _special_cases(self, text, cmd, outside): break_flag = False continue_flag = False - if text and text.split()[0].lower() == 'az': + if text and len(text.split()) > 0 and text.split()[0].lower() == 'az': telemetry.track_ssg('az', text) cmd = ' '.join(text.split()[1:]) if self.default_command: @@ -521,6 +533,7 @@ def handle_scoping_input(self, continue_flag, cmd, text): return continue_flag, cmd def cli_execute(self, cmd): + """ sends the command to the CLI to be executed """ try: args = parse_quotes(cmd) azlogging.configure_logging(args) @@ -534,7 +547,23 @@ def cli_execute(self, cmd): config = Configuration() self.app.initialize(config) - result = self.app.execute(args) + + if '--progress' in args: + args.remove('--progress') + thread = ExecuteThread(self.app.execute, args) + thread.daemon = True + thread.start() + self.threads.append(thread) + self.curr_thread = thread + + thread = ProgressViewThread(progress_view, self) + thread.daemon = True + thread.start() + self.threads.append(thread) + result = None + + else: + result = self.app.execute(args) self.last_exit = 0 if result and result.result is not None: from azure.cli.core._output import OutputProducer @@ -552,9 +581,10 @@ def cli_execute(self, cmd): self.last_exit = int(ex.code) def run(self): - + """ starts the REPL """ telemetry.start() - + from azclishell.progress import ShellProgressView + self.app.progress_controller.init_progress(ShellProgressView()) from azclishell.configuration import SHELL_HELP self.cli.buffers['symbols'].reset( initial_document=Document(u'{}'.format(SHELL_HELP))) @@ -590,9 +620,40 @@ def run(self): subprocess.Popen(cmd, shell=True).communicate() else: self.cli_execute(cmd) + except KeyboardInterrupt: # CTRL C self.set_prompt() continue print('Have a lovely day!!') telemetry.conclude() + + +class ExecuteThread(threading.Thread): + """ thread for executing commands """ + def __init__(self, func, args): + super(ExecuteThread, self).__init__() + self.args = args + self.func = func + + def run(self): + self.func(self.args) + + +class ProgressViewThread(threading.Thread): + """ thread to keep the toolbar spinner spinning """ + def __init__(self, func, arg): + super(ProgressViewThread, self).__init__() + self.func = func + self.arg = arg + + def run(self): + import time + try: + while True: + if self.func(self.arg): + time.sleep(4) + break + time.sleep(.25) + except KeyboardInterrupt: + pass diff --git a/src/command_modules/azure-cli-shell/azclishell/argfinder.py b/src/command_modules/azure-cli-shell/azclishell/argfinder.py index 7008fcb8405..d04a269570e 100644 --- a/src/command_modules/azure-cli-shell/azclishell/argfinder.py +++ b/src/command_modules/azure-cli-shell/azclishell/argfinder.py @@ -5,8 +5,6 @@ import argparse import os -import io -import sys from argcomplete import CompletionFinder from argcomplete.compat import USING_PYTHON2, ensure_bytes @@ -14,6 +12,9 @@ class ArgsFinder(CompletionFinder): """ gets the parsed args """ + def __init__(self, parser, outstream=None): + super(ArgsFinder, self).__init__(parser) + self.outstream = outstream def get_parsed_args(self, comp_words): """ gets the parsed args from a patched parser """ @@ -27,13 +28,13 @@ def get_parsed_args(self, comp_words): comp_words = [ensure_bytes(word) for word in comp_words] try: - stderr = sys.stderr - sys.stderr = os.open(os.devnull, "w") + temp = self.outstream + self.outstream = os.fdopen(os.devnull, "w") active_parsers[0].parse_known_args(comp_words, namespace=parsed_args) - sys.stderr.close() - sys.stderr = stderr + self.outstream.close() + self.outstream = temp except BaseException: pass diff --git a/src/command_modules/azure-cli-shell/azclishell/az_completer.py b/src/command_modules/azure-cli-shell/azclishell/az_completer.py index 158b1514e51..739f074d87d 100644 --- a/src/command_modules/azure-cli-shell/azclishell/az_completer.py +++ b/src/command_modules/azure-cli-shell/azclishell/az_completer.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------------------------- from __future__ import absolute_import, division, print_function, unicode_literals +import sys from prompt_toolkit.completion import Completer, Completion @@ -82,7 +83,7 @@ def _get_weight(val): class AzCompleter(Completer): """ Completes Azure CLI commands """ - def __init__(self, commands, global_params=True): + def __init__(self, commands, global_params=True, outstream=sys.stderr): # dictionary of command to descriptions self.command_description = commands.descrip # from a command to a list of parameters @@ -116,7 +117,7 @@ def __init__(self, commands, global_params=True): from azclishell._dump_commands import CMD_TABLE self.cmdtab = CMD_TABLE self.parser.load_command_table(CMD_TABLE) - self.argsfinder = ArgsFinder(self.parser) + self.argsfinder = ArgsFinder(self.parser, outstream) def validate_completion(self, param, words, text_before_cursor, double=True): """ validates that a param should be completed """ @@ -199,7 +200,6 @@ def gen_dynamic_completions(self, text): for comp in gen_dyn_completion( comp, started_param, prefix, text): yield comp - except TypeError: try: for comp in self.cmdtab[self.curr_command].\ diff --git a/src/command_modules/azure-cli-shell/azclishell/gather_commands.py b/src/command_modules/azure-cli-shell/azclishell/gather_commands.py index 22433035cca..de821401056 100644 --- a/src/command_modules/azure-cli-shell/azclishell/gather_commands.py +++ b/src/command_modules/azure-cli-shell/azclishell/gather_commands.py @@ -40,8 +40,8 @@ def add_random_new_lines(long_phrase, line_min=LINE_MINIMUM, tolerance=TOLERANCE skip = False index = 0 if len(long_phrase) > line_min: - # pylint: disable=unused-variable - for num in range(int(math.floor(len(long_phrase) / line_min))): + + for _ in range(int(math.floor(len(long_phrase) / line_min))): previous = index index += line_min if skip: diff --git a/src/command_modules/azure-cli-shell/azclishell/layout.py b/src/command_modules/azure-cli-shell/azclishell/layout.py index ce425bbe4aa..0938ddbed8a 100644 --- a/src/command_modules/azure-cli-shell/azclishell/layout.py +++ b/src/command_modules/azure-cli-shell/azclishell/layout.py @@ -22,6 +22,7 @@ import azclishell.configuration from azclishell.key_bindings import get_show_default, get_symbols +from azclishell.progress import get_progress_message, DONE_STR MAX_COMPLETION = 16 DEFAULT_COMMAND = "" @@ -64,6 +65,14 @@ def __call__(self, *a, **kw): return get_symbols() +# pylint: disable=too-few-public-methods +class ShowProgress(Filter): + """ toggle showing the progress """ + def __call__(self, *a, **kw): + progress = get_progress_message() + return progress != '' and progress != DONE_STR + + def get_scope(): """" returns the default command """ return DEFAULT_COMMAND @@ -194,6 +203,15 @@ def create_layout(lex, exam_lex, toolbar_lex): ), filter=ShowSymbol() ), + ConditionalContainer( + Window( + content=BufferControl( + buffer_name='progress', + lexer=lexer + ) + ), + filter=ShowProgress() + ), Window( content=BufferControl( buffer_name='bottom_toolbar', diff --git a/src/command_modules/azure-cli-shell/azclishell/progress.py b/src/command_modules/azure-cli-shell/azclishell/progress.py new file mode 100644 index 00000000000..c396736f85a --- /dev/null +++ b/src/command_modules/azure-cli-shell/azclishell/progress.py @@ -0,0 +1,95 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from random import randint + +from prompt_toolkit.document import Document + +from azure.cli.core.commands.progress import ProgressViewBase +from azclishell.util import get_window_dim + + +PROGRESS = '' +PROGRESS_BAR = '' +DONE_STR = 'Finished' +# have 2 down beats to make the odds work out better +HEART_BEAT_VALUES = {0: "__", 1: "/\\", 2: '/^\\', 3: "__"} +HEART_BEAT = '' + + +class ShellProgressView(ProgressViewBase): + """ custom output for progress reporting """ + def __init__(self): + super(ShellProgressView, self).__init__(None, None) + + def write(self, args): + """ writes the progres """ + global PROGRESS, PROGRESS_BAR + message = args.get('message', '') + percent = args.get('percent', None) + if percent: + PROGRESS_BAR = self._format_value(message, percent) + PROGRESS = message + + def _format_value(self, msg, percent=0.0): + _, col = get_window_dim() + bar_len = int(col) - len(msg) - 10 + + completed = int(bar_len * percent) + message = '{}['.format(msg) + for i in range(bar_len): + if i < completed: + message += '#' + else: + message += ' ' + message += '] {:.1%}'.format(percent) + return message + + def flush(self): + """ flushes the message""" + pass + + +def get_progress_message(): + """ gets the progress message """ + return PROGRESS + + +def progress_view(shell): + """ updates the view """ + global HEART_BEAT + _, col = get_window_dim() + col = int(col) + progress = get_progress_message() + buffer_size = col - len(progress) - 4 + + if PROGRESS_BAR: + doc = u'{}:{}'.format(progress, PROGRESS_BAR) + else: + if progress and progress != DONE_STR: + if shell.spin_val >= 0: + beat = HEART_BEAT_VALUES[_get_heart_frequency()] + HEART_BEAT += beat + HEART_BEAT = HEART_BEAT[len(beat):] + len_beat = len(HEART_BEAT) + if len_beat > buffer_size: + HEART_BEAT = HEART_BEAT[len_beat - buffer_size:] + + else: + shell.spin_val = 0 + counter = 0 + while counter < buffer_size: + beat = HEART_BEAT_VALUES[_get_heart_frequency()] + HEART_BEAT += beat + counter += len(beat) + doc = u'{}:{}'.format(progress, HEART_BEAT) + shell.cli.buffers['progress'].reset( + initial_document=Document(doc)) + shell.cli.request_redraw() + if PROGRESS == 'Finished' or PROGRESS == 'Interrupted': + return True + + +def _get_heart_frequency(): + return int(round(randint(0, 3))) diff --git a/src/command_modules/azure-cli-shell/tests/test_completion.py b/src/command_modules/azure-cli-shell/tests/test_completion.py index 5bf3382028b..d8caf459b18 100644 --- a/src/command_modules/azure-cli-shell/tests/test_completion.py +++ b/src/command_modules/azure-cli-shell/tests/test_completion.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +import unittest +from io import StringIO import six import azclishell.command_tree as tree @@ -10,8 +12,6 @@ from prompt_toolkit.document import Document from prompt_toolkit.completion import Completion -import unittest - class _Commands(): """ mock model for testing completer """ @@ -49,7 +49,7 @@ def init1(self): command_tree=com_tree3, descrip=command_description ) - self.completer = AzCompleter(commands, global_params=False) + self.completer = AzCompleter(commands, global_params=False, outstream=StringIO()) def init2(self): """ a variation of initializing """ @@ -79,7 +79,7 @@ def init2(self): param_descript=param_descript, descrip=command_description ) - self.completer = AzCompleter(commands, global_params=False) + self.completer = AzCompleter(commands, global_params=False, outstream=StringIO()) def init3(self): """ a variation of initializing """ @@ -116,7 +116,7 @@ def init3(self): same_param_doubles=same_param_doubles, descrip=command_description ) - self.completer = AzCompleter(commands, global_params=False) + self.completer = AzCompleter(commands, global_params=False, outstream=StringIO()) def init4(self): """ a variation of initializing """ @@ -153,7 +153,7 @@ def init4(self): same_param_doubles=same_param_doubles, descrip=command_description ) - self.completer = AzCompleter(commands, global_params=False) + self.completer = AzCompleter(commands, global_params=False, outstream=StringIO()) def test_command_completion(self): """ tests general command completion """ diff --git a/src/command_modules/azure-cli-storage/azure/cli/command_modules/storage/custom.py b/src/command_modules/azure-cli-storage/azure/cli/command_modules/storage/custom.py index 266f15f070f..5bf5a54a25f 100644 --- a/src/command_modules/azure-cli-storage/azure/cli/command_modules/storage/custom.py +++ b/src/command_modules/azure-cli-storage/azure/cli/command_modules/storage/custom.py @@ -6,7 +6,6 @@ # pylint: disable=no-self-use,too-many-arguments,line-too-long from __future__ import print_function -from sys import stderr from azure.cli.core.decorators import transfer_doc from azure.cli.core.util import CLIError @@ -14,7 +13,9 @@ from azure.cli.command_modules.storage._factory import \ (storage_client_factory, generic_data_service_factory) +from azure.cli.core.application import APPLICATION +from azure.cli.core.commands.progress import get_progress_view Logging, Metrics, CorsRule, \ AccessPolicy, RetentionPolicy = get_sdk(ResourceType.DATA_STORAGE, @@ -39,15 +40,13 @@ 'queue#QueueService') +HOOK = APPLICATION.progress_controller +HOOK.init_progress(get_progress_view(determinant=True)) + + def _update_progress(current, total): if total: - message = 'Percent complete: %' - percent_done = current * 100 / total - message += '{: >5.1f}'.format(percent_done) - print('\b' * len(message) + message, end='', file=stderr) - stderr.flush() - if current == total: - print('', file=stderr) + HOOK.add(message='Alive', value=current, total_val=total) # CUSTOM METHODS