Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/transformers/utils/kernel_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import PushToHubMixin, is_kernels_available, is_torch_available
from ..utils import PushToHubMixin, is_torch_available


if is_kernels_available():
from kernels import LayerRepository, Mode

if is_torch_available():
import torch

Expand Down Expand Up @@ -58,6 +55,8 @@ def infer_device(model):


def add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping):
from kernels import LayerRepository

if device not in ["cuda", "rocm", "xpu"]:
raise ValueError(f"Only cuda, rocm, and xpu devices supported, got: {device}")
repo_layer_name = repo_name.split(":")[1]
Expand All @@ -82,6 +81,8 @@ def __init__(self, kernel_mapping={}):
self.registered_layer_names = {}

def update_kernel(self, repo_id, registered_name, layer_name, device, mode, revision=None):
from kernels import LayerRepository

self.kernel_mapping[registered_name] = {
device: {
mode: LayerRepository(
Expand Down Expand Up @@ -204,6 +205,8 @@ def create_compatible_mapping(self, model, compile=False):
The device is inferred from the model's parameters if not provided.
The Mode is inferred from the model's training state.
"""
from kernels import Mode

compatible_mapping = {}
for layer_name, kernel in self.kernel_mapping.items():
# Infer Mode: use Mode.TRAINING if model is training, else use Mode.INFERENCE
Expand Down