fix: dynamic import of transformers in hybrid model
We currently have an import of `transformers` in the hybrid model source
code, but the library is only a development dependency, not a runtime
requirement.

Importing the library dynamically, inside the function that needs it,
should fix this.
fd0r committed Jun 27, 2024
1 parent 7a2eeea commit 829b68b
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/concrete/ml/torch/hybrid_model.py
@@ -18,7 +18,6 @@
 from brevitas.quant_tensor import QuantTensor
 from concrete.fhe import Configuration
 from torch import nn
-from transformers import Conv1D

 from ..common.utils import MAX_BITWIDTH_BACKWARD_COMPATIBLE
 from ..deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
@@ -76,6 +75,11 @@ def convert_conv1d_to_linear(layer_or_module):
         nn.Module or nn.Linear: The updated module with Conv1D layers converted to Linear layers,
             or the Conv1D layer converted to a Linear layer.
     """
+    try:
+        from transformers import Conv1D  # pylint: disable=import-outside-toplevel
+    except ImportError:  # pragma: no cover
+        return layer_or_module
+
     if isinstance(layer_or_module, Conv1D):
         # Get the weight size
         weight_size = layer_or_module.weight.size()
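The deferred-import pattern the patch introduces can be sketched in isolation. Note that `convert_if_conv1d` below is a hypothetical, simplified stand-in for the repository's `convert_conv1d_to_linear`: only the import-and-fallback logic mirrors the commit, and the actual Conv1D-to-Linear conversion is elided.

```python
def convert_if_conv1d(layer_or_module):
    """Convert a Conv1D layer to Linear, or return the input unchanged."""
    try:
        # Import inside the function: transformers is only a development
        # dependency, so it may be absent at runtime.
        from transformers import Conv1D  # pylint: disable=import-outside-toplevel
    except ImportError:
        # Without transformers installed, no Conv1D instance can exist,
        # so there is nothing to convert.
        return layer_or_module

    if isinstance(layer_or_module, Conv1D):
        pass  # conversion to nn.Linear would happen here (elided)
    return layer_or_module
```

Whether or not `transformers` is installed, objects that are not Conv1D layers pass through unchanged, which is why the early `return` on `ImportError` is safe.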
