Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ pyvenv.cfg
htmlcov/
.coverage
*supply_chain_*/
examples/supply_chain/docker-compose.yml
examples/supply_chain/docker-compose.yml
examples/rl/config.yml
7 changes: 2 additions & 5 deletions docker_files/dev.df
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.7.12-buster
FROM python:3.7-buster
WORKDIR /maro

# Install Apt packages
Expand All @@ -13,7 +13,7 @@ RUN apt-get install -y python3-dev libpython3.7-dev python-numpy
RUN rm -rf /var/lib/apt/lists/*

# Install Python packages
RUN pip3 install --upgrade pip
RUN pip install --upgrade pip
RUN pip install --no-cache-dir Cython==0.29.14
RUN pip install --no-cache-dir pyaml==20.4.0
RUN pip install --no-cache-dir pyzmq==19.0.2
Expand All @@ -31,9 +31,6 @@ COPY setup.py /maro/
RUN bash /maro/scripts/install_maro.sh
RUN pip cache purge

RUN rm -r /maro/maro/rl
RUN rm -r /maro/maro/simulator/scenarios/supply_chain

ENV PYTHONPATH=/maro

CMD ["/bin/bash"]
29 changes: 0 additions & 29 deletions examples/rl/config.yml

This file was deleted.

64 changes: 64 additions & 0 deletions examples/rl/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os

from maro.rl.training import TrainerManager
from maro.rl.workflows.scenario import Scenario
from maro.utils import Logger


# Settings for this example run. They are plain module-level constants (rather than a
# YAML config file) and are read by the training script below.
SCENARIO_PATH = "cim"
NUM_EPISODES = 50
NUM_STEPS = -1  # forwarded as-is to env_sampler.sample(num_steps=...); presumably -1 means "whole episode" - confirm
CHECKPOINT_PATH = os.path.join(os.getcwd(), "checkpoints")
CHECKPOINT_INTERVAL = 5
EVAL_SCHEDULE = [10, 20, 30, 40, 50]
LOG_PATH = os.path.join(os.getcwd(), "logs", "cim")


def is_eval_episode(ep: int, eval_schedule: list, eval_point_index: int) -> bool:
    """Tell whether episode ``ep`` is the next pending evaluation point.

    Args:
        ep: Index of the episode that has just finished.
        eval_schedule: Episode indices at which an evaluation should be run.
        eval_point_index: Position in ``eval_schedule`` of the next evaluation that has not run yet.

    Returns:
        True if ``eval_schedule[eval_point_index]`` exists and equals ``ep``; False otherwise,
        including when the schedule is empty or has already been exhausted.
    """
    # The bounds check comes first so that an exhausted (or empty) schedule yields False
    # instead of raising IndexError when NUM_EPISODES extends past the last evaluation point.
    return eval_point_index < len(eval_schedule) and ep == eval_schedule[eval_point_index]


if __name__ == "__main__":
    scenario = Scenario(SCENARIO_PATH)
    logger = Logger("MAIN", dump_path=LOG_PATH)

    agent2policy = scenario.agent2policy
    policy_creator = scenario.policy_creator
    trainer_creator = scenario.trainer_creator
    # Build every policy once, then replace each creator with a function that hands back the
    # already-built instance, so the env sampler and the trainer manager receive the same objects.
    policy_dict = {name: get_policy_func(name) for name, get_policy_func in policy_creator.items()}
    policy_creator = {name: lambda name: policy_dict[name] for name in policy_dict}

    # evaluation schedule
    logger.info(f"Policy will be evaluated at the end of episodes {EVAL_SCHEDULE}")
    eval_point_index = 0

    env_sampler = scenario.get_env_sampler(policy_creator)
    trainer_manager = TrainerManager(policy_creator, trainer_creator, agent2policy, logger=logger)

    # main loop
    for ep in range(1, NUM_EPISODES + 1):
        segment, end_of_episode = 1, False
        while not end_of_episode:
            # experience collection
            result = env_sampler.sample(num_steps=NUM_STEPS)
            experiences = result["experiences"]
            end_of_episode = result["end_of_episode"]

            # Optional scenario hook, called with the roll-out info of this segment.
            if scenario.post_collect:
                scenario.post_collect(result["info"], ep, segment)

            logger.info(f"Roll-out completed for episode {ep}. Training started...")
            trainer_manager.record_experiences(experiences)
            trainer_manager.train()
            if CHECKPOINT_PATH and ep % CHECKPOINT_INTERVAL == 0:
                pth = os.path.join(CHECKPOINT_PATH, str(ep))
                trainer_manager.save(pth)
                logger.info(f"All trainer states saved under {pth}")
            segment += 1

        # performance details
        if is_eval_episode(ep, EVAL_SCHEDULE, eval_point_index):
            eval_point_index += 1
            result = env_sampler.eval()
            # Optional scenario hook, called with the evaluation info.
            if scenario.post_evaluate:
                scenario.post_evaluate(result["info"], ep)
14 changes: 6 additions & 8 deletions maro/cli/k8s/aks_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
from maro.cli.utils.azure.resource_group import create_resource_group, delete_resource_group_under_subscription
# from maro.cli.utils.azure.vm import list_vm_sizes
from maro.cli.utils.common import show_log
from maro.cli.utils.config_parser import get_rl_component_env_vars
# from maro.cli.utils.deployment_validator import DeploymentValidator
# from maro.cli.utils.details_reader import DetailsReader
# from maro.cli.utils.details_writer import DetailsWriter
# from maro.cli.utils.name_creator import NameCreator
# from maro.cli.utils.path_convertor import PathConvertor
# from maro.cli.utils.subprocess import Subprocess
# from maro.utils.exception.cli_exception import BadRequestError, FileOperationError
from maro.rl.workflows.config import ConfigParser
from maro.utils.logger import CliLogger
from maro.utils.utils import LOCAL_MARO_ROOT

Expand Down Expand Up @@ -221,8 +221,8 @@ def add_job(conf_path: dict, **kwargs):
logger.error_red(NO_DEPLOYMENT_MSG)
return

with open(conf_path, "r") as fp:
job_conf = yaml.safe_load(fp)
parser = ConfigParser(conf_path)
job_conf = parser.config

job_name = job_conf["job"]
local_job_path = get_local_job_path(job_name)
Expand Down Expand Up @@ -251,8 +251,8 @@ def add_job(conf_path: dict, **kwargs):
for name in ["scenario", "logs", "checkpoints"]
]

if "load_path" in job_conf:
load_dir = job_conf['load_path']
if "load_path" in job_conf["training"]:
load_dir = job_conf["training"]["load_path"]
logger.info(f"Uploading local directory {load_dir}...")
azure_storage_utils.upload_to_fileshare(job_dir, load_dir, name="loadpoint")
volumes.append(
Expand All @@ -269,12 +269,10 @@ def add_job(conf_path: dict, **kwargs):
ADDRESS_REGISTRY_NAME, REDIS_HOST, deployment_conf["name"], ADDRESS_REGISTRY_PORT
), job_name
)
for component_name, env in get_rl_component_env_vars(job_conf, containerized=True).items():
for component_name, env in parser.as_env(containerize=True).items():
container_spec = k8s_manifest_generator.get_container_spec(
get_docker_image_name_in_acr(resource_name["acrName"], DOCKER_IMAGE_NAME),
component_name,
ADDRESS_REGISTRY_NAME,
ADDRESS_REGISTRY_PORT,
env,
volumes
)
Expand Down
6 changes: 3 additions & 3 deletions maro/cli/k8s/lib/modes/aks/create_aks_cluster/template.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
"resources": [
{
"type": "Microsoft.Storage/storageAccounts/fileServices/shares",
"apiVersion": "2020-08-01-preview",
"apiVersion": "2021-04-01",
"name": "[concat(parameters('storageAccountName'), '/default/', parameters('fileShareName'))]",
"dependsOn": [
"[variables('stvmId')]"
Expand All @@ -84,7 +84,7 @@
{
"name": "[parameters('acrName')]",
"type": "Microsoft.ContainerRegistry/registries",
"apiVersion": "2020-11-01-preview",
"apiVersion": "2021-09-01",
"location": "[parameters('location')]",
"sku": {
"name": "[parameters('acrSku')]"
Expand Down Expand Up @@ -126,7 +126,7 @@
},
{
"type": "Microsoft.Storage/storageAccounts",
"apiVersion": "2020-08-01-preview",
"apiVersion": "2021-08-01",
"name": "[parameters('storageAccountName')]",
"location": "[parameters('location')]",
"kind": "StorageV2",
Expand Down
6 changes: 3 additions & 3 deletions maro/cli/k8s/test_template.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
{
"name": "[parameters('acrName')]",
"type": "Microsoft.ContainerRegistry/registries",
"apiVersion": "2020-11-01-preview",
"apiVersion": "2021-09-01",
"location": "[parameters('location')]",
"sku": {
"name": "[parameters('acrSku')]"
Expand Down Expand Up @@ -133,7 +133,7 @@
},
{
"type": "Microsoft.Storage/storageAccounts",
"apiVersion": "2020-08-01-preview",
"apiVersion": "2021-08-01",
"name": "[parameters('storageAccountName')]",
"location": "[parameters('location')]",
"kind": "StorageV2",
Expand All @@ -147,7 +147,7 @@
},
{
"type": "Microsoft.Storage/storageAccounts/fileServices/shares",
"apiVersion": "2020-08-01-preview",
"apiVersion": "2021-04-01",
"name": "[concat(parameters('storageAccountName'), '/default/', parameters('fileShareName'))]",
"dependsOn": [
"[resourceId('Microsoft.Storage/storageAccounts', parameters('storageAccountName'))]"
Expand Down
10 changes: 5 additions & 5 deletions maro/cli/k8s/utils/k8s_manifest_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import List

from maro.cli.utils.config_parser import format_env_vars, get_mnt_path_in_container, get_script_path
from maro.cli.utils.common import format_env_vars


def get_job_manifest(agent_pool_name: str, component_name: str, container_spec: dict, volumes: List[dict]):
Expand Down Expand Up @@ -33,18 +33,18 @@ def get_azurefile_volume_spec(name: str, share_name: str, secret_name: str):
}


def get_container_spec(image_name: str, component_name: str, redis_host: str, redis_port: int, env: dict, volumes):
def get_container_spec(image_name: str, component_name: str, env: dict, volumes):
common_container_spec = {
"image": image_name,
"imagePullPolicy": "Always",
"volumeMounts": [{"name": vol["name"], "mountPath": get_mnt_path_in_container(vol["name"])} for vol in volumes]
"volumeMounts": [{"name": vol["name"], "mountPath": f"/{vol['name']}"} for vol in volumes]
}
return {
**common_container_spec,
**{
"name": component_name,
"command": ["python3", get_script_path(component_name, containerized=True)],
"env": format_env_vars({**env, "REDIS_HOST": redis_host, "REDIS_PORT": str(redis_port)}, mode="k8s")
"command": ["python3", f"/maro/maro/rl/workflows/{component_name.split('-')[0]}.py"],
"env": format_env_vars(env, mode="k8s")
}
}

Expand Down
41 changes: 16 additions & 25 deletions maro/cli/local/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
import yaml

from maro.cli.utils.common import close_by_pid, show_log
from maro.rl.utils.common import get_log_path
from maro.rl.workflows.config import ConfigParser
from maro.utils.logger import CliLogger
from maro.utils.utils import LOCAL_MARO_ROOT

from .utils import (
JobStatus, RedisHashKey, start_redis, start_redis_container, start_rl_job_in_foreground,
start_rl_job_with_docker_compose, stop_redis, stop_redis_container, stop_rl_job_with_docker_compose
JobStatus, RedisHashKey, start_redis, start_rl_job, start_rl_job_with_docker_compose, stop_redis,
stop_rl_job_with_docker_compose
)

# metadata
Expand All @@ -28,7 +28,6 @@
DOCKERFILE_PATH = join(LOCAL_MARO_ROOT, "docker_files", "dev.df")
DOCKER_IMAGE_NAME = "maro-local"
DOCKER_NETWORK = "MAROLOCAL"
REDIS_CONTAINER_NAME = "maro-local-redis"

# display
NO_JOB_MANAGER_MSG = """No job manager found. Run "maro local init" to start the job manager first."""
Expand Down Expand Up @@ -57,19 +56,21 @@ def get_redis_conn(port=None):


# Functions executed on CLI commands
def run(conf_path: str, containerize: bool = False, port: int = 20000, **kwargs):
def run(conf_path: str, containerize: bool = False, **kwargs):
# Load job configuration file
with open(conf_path, "r") as fr:
conf = yaml.safe_load(fr)

parser = ConfigParser(conf_path)
env_by_component = parser.as_env(containerize=containerize)
if containerize:
path_mapping = parser.get_path_mapping(containerize=True)
try:
start_rl_job_with_docker_compose(conf, LOCAL_MARO_ROOT, DOCKERFILE_PATH, DOCKER_IMAGE_NAME)
start_rl_job_with_docker_compose(
parser.config, LOCAL_MARO_ROOT, DOCKERFILE_PATH, DOCKER_IMAGE_NAME, env_by_component, path_mapping
)
except KeyboardInterrupt:
stop_rl_job_with_docker_compose(conf)
stop_rl_job_with_docker_compose(parser.config["job"])
else:
try:
start_rl_job_in_foreground(conf, LOCAL_MARO_ROOT, port=port)
start_rl_job(parser.as_env(), LOCAL_MARO_ROOT)
except KeyboardInterrupt:
sys.exit(1)

Expand All @@ -91,10 +92,7 @@ def init(
)
return

if containerize:
start_redis_container(port, REDIS_CONTAINER_NAME, DOCKER_NETWORK)
else:
start_redis(port)
start_redis(port)

# Start job manager
command = ["python", join(dirname(abspath(__file__)), 'job_manager.py')]
Expand All @@ -109,9 +107,7 @@ def init(
"REDIS_PORT": str(port),
"LOCAL_MARO_ROOT": LOCAL_MARO_ROOT,
"DOCKER_IMAGE_NAME": DOCKER_IMAGE_NAME,
"DOCKERFILE_PATH": DOCKERFILE_PATH,
"DOCKER_NETWORK": DOCKER_NETWORK,
"REDIS_CONTAINER_NAME": REDIS_CONTAINER_NAME
"DOCKERFILE_PATH": DOCKERFILE_PATH
}
)

Expand Down Expand Up @@ -147,10 +143,7 @@ def exit(**kwargs):
close_by_pid(int(session_state["job_manager_pid"]))

# Stop Redis
if session_state["containerized"]:
stop_redis_container(REDIS_CONTAINER_NAME, DOCKER_NETWORK)
else:
stop_redis(session_state["port"])
stop_redis(session_state["port"])

# Remove dump folder.
shutil.rmtree(LOCAL_ROOT, True)
Expand Down Expand Up @@ -226,9 +219,7 @@ def get_job_logs(job_name: str, tail: int = -1, **kwargs):
return

conf = json.loads(redis_conn.hget(RedisHashKey.JOB_CONF, job_name))
if "log_dir" in conf:
log_path = get_log_path(conf['log_dir'], conf["job"])
show_log(log_path, tail=tail)
show_log(conf["log_path"], tail=tail)


def list_jobs(**kwargs):
Expand Down
Loading