Skip to content

Commit

Permalink
add inference global
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra committed Apr 2, 2022
1 parent 525e9e7 commit 7e96581
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import torch.distributed as dist
import deepspeed.utils.groups as groups

DS_INFERENCE_ENABLED = False


class InferenceEngine(Module):
inference_mp_group = None
Expand Down Expand Up @@ -59,6 +61,8 @@ def __init__(self,
replace_with_kernel_inject: this flag need to be set to true to inject inference kernels for models such as, Bert, GPT2, GPT-Neo and GPT-J. Otherwise,
the injection_dict provides the names of two linear layers as a tuple: (attention_output projection, transformer output projection)
"""
global DS_INFERENCE_ENABLED
DS_INFERENCE_ENABLED = True

super().__init__()

Expand Down

0 comments on commit 7e96581

Please sign in to comment.