Skip to content

Commit 73cb4b8

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents 8bd1e5f + be7b47f commit 73cb4b8

File tree

3 files changed

+27
-11
lines changed

3 files changed

+27
-11
lines changed

.github/unittest/linux/scripts/run_all.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,13 @@ pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_contro
208208
if [ "${CU_VERSION:-}" != cpu ] ; then
209209
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
210210
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
211+
--ignore test/llm \
211212
--timeout=120 --mp_fork_if_no_cuda
212213
else
213214
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
214215
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
215216
--ignore test/test_distributed.py \
217+
--ignore test/llm \
216218
--timeout=120 --mp_fork_if_no_cuda
217219
fi
218220

.github/unittest/linux_optdeps/scripts/run_all.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ export BATCHED_PIPE_TIMEOUT=60
159159
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
160160
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
161161
--ignore test/test_distributed.py \
162+
--ignore test/llm \
162163
--timeout=120 --mp_fork_if_no_cuda
163164

164165
coverage combine

torchrl/_utils.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from copy import copy
2121
from functools import wraps
2222
from importlib import import_module
23+
from textwrap import indent
2324
from typing import Any, Callable, cast, TypeVar
2425

2526
import numpy as np
@@ -52,25 +53,37 @@ def strtobool(val: Any) -> bool:
5253
LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
5354
logger = logging.getLogger("torchrl")
5455
logger.setLevel(getattr(logging, LOGGING_LEVEL))
55-
# Disable propagation to the root logger
5656
logger.propagate = False
57-
# Remove all attached handlers
57+
# Clear existing handlers
5858
while logger.hasHandlers():
5959
logger.removeHandler(logger.handlers[0])
6060
stream_handlers = {
6161
"stdout": sys.stdout,
6262
"stderr": sys.stderr,
6363
}
6464
TORCHRL_CONSOLE_STREAM = os.getenv("TORCHRL_CONSOLE_STREAM")
65-
if TORCHRL_CONSOLE_STREAM:
66-
stream_handler = stream_handlers[TORCHRL_CONSOLE_STREAM]
67-
else:
68-
stream_handler = None
69-
console_handler = logging.StreamHandler(stream=stream_handler)
70-
71-
console_handler.setLevel(logging.INFO)
72-
formatter = logging.Formatter("%(asctime)s [%(name)s][%(levelname)s] %(message)s")
73-
console_handler.setFormatter(formatter)
65+
stream_handler = stream_handlers.get(TORCHRL_CONSOLE_STREAM, sys.stdout)
66+
67+
68+
# Create colored handler
69+
class _CustomFormatter(logging.Formatter):
70+
def format(self, record):
71+
# Format the initial part in green
72+
green_format = "\033[92m%(asctime)s [%(name)s][%(levelname)s]\033[0m"
73+
# Format the message part
74+
message_format = "%(message)s"
75+
# End marker in green
76+
end_marker = "\033[92m [END]\033[0m"
77+
# Combine all parts
78+
formatted_message = logging.Formatter(
79+
green_format + indent(message_format, " " * 4) + end_marker
80+
).format(record)
81+
82+
return formatted_message
83+
84+
85+
console_handler = logging.StreamHandler(stream_handler)
86+
console_handler.setFormatter(_CustomFormatter())
7487
logger.addHandler(console_handler)
7588

7689
VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG))))

0 commit comments

Comments
 (0)