-
Notifications
You must be signed in to change notification settings - Fork 456
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Runtime tensor parallelism (#158)
* works on interlm and vicuna * support GQA * remove comment * update readme, add logger, default tp=1 * remove log
- Loading branch information
Showing
12 changed files
with
492 additions
and
191 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import logging | ||
from typing import Optional | ||
|
||
logger_initialized = {} | ||
|
||
|
||
def get_logger(name: str, | ||
log_file: Optional[str] = None, | ||
log_level: int = logging.INFO, | ||
file_mode: str = 'w'): | ||
"""Initialize and get a logger by name. | ||
If the logger has not been initialized, this method will initialize the | ||
logger by adding one or two handlers, otherwise the initialized logger will | ||
be directly returned. During initialization, a StreamHandler will always be | ||
added. If `log_file` is specified, a FileHandler will also be added. | ||
Args: | ||
name (str): Logger name. | ||
log_file (str | None): The log filename. If specified, a FileHandler | ||
will be added to the logger. | ||
log_level (int): The logger level. | ||
file_mode (str): The file mode used in opening log file. | ||
Defaults to 'w'. | ||
Returns: | ||
logging.Logger: The expected logger. | ||
""" | ||
# use logger in mmengine if exists. | ||
try: | ||
from mmengine.logging import MMLogger | ||
if MMLogger.check_instance_created(name): | ||
logger = MMLogger.get_instance(name) | ||
else: | ||
logger = MMLogger.get_instance(name, | ||
logger_name=name, | ||
log_file=log_file, | ||
log_level=log_level, | ||
file_mode=file_mode) | ||
return logger | ||
|
||
except Exception: | ||
pass | ||
|
||
logger = logging.getLogger(name) | ||
if name in logger_initialized: | ||
return logger | ||
# handle hierarchical names | ||
# e.g., logger "a" is initialized, then logger "a.b" will skip the | ||
# initialization since it is a child of "a". | ||
for logger_name in logger_initialized: | ||
if name.startswith(logger_name): | ||
return logger | ||
|
||
# handle duplicate logs to the console | ||
for handler in logger.root.handlers: | ||
if type(handler) is logging.StreamHandler: | ||
handler.setLevel(logging.ERROR) | ||
|
||
stream_handler = logging.StreamHandler() | ||
handlers = [stream_handler] | ||
|
||
if log_file is not None: | ||
# Here, the default behaviour of the official logger is 'a'. Thus, we | ||
# provide an interface to change the file mode to the default | ||
# behaviour. | ||
file_handler = logging.FileHandler(log_file, file_mode) | ||
handlers.append(file_handler) | ||
|
||
formatter = logging.Formatter( | ||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s') | ||
for handler in handlers: | ||
handler.setFormatter(formatter) | ||
handler.setLevel(log_level) | ||
logger.addHandler(handler) | ||
|
||
logger.setLevel(log_level) | ||
logger_initialized[name] = True | ||
|
||
return logger |
Oops, something went wrong.