Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Qwen2 models #746

Merged
merged 11 commits into from
Dec 5, 2024
Prev Previous commit
Next Next commit
feat(decoder): allow export from local class
  • Loading branch information
dacorvo committed Dec 2, 2024
commit b7adfe085b049d1bc17b2bd13692596560a08bd1
15 changes: 10 additions & 5 deletions optimum/exporters/neuron/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,16 @@ def __init__(self, task: str):
raise ModuleNotFoundError(
"The mandatory transformers-neuronx package is missing. Please install optimum[neuronx]."
)
module_name, class_name = self.NEURONX_CLASS.rsplit(".", maxsplit=1)
module = importlib.import_module(f"transformers_neuronx.{module_name}")
self._neuronx_class = getattr(module, class_name, None)
if self._neuronx_class is None:
raise ImportError(f"{class_name} not found in {module_name}. Please check transformers-neuronx version.")
if isinstance(self.NEURONX_CLASS, type):
self._neuronx_class = self.NEURONX_CLASS
else:
module_name, class_name = self.NEURONX_CLASS.rsplit(".", maxsplit=1)
module = importlib.import_module(f"transformers_neuronx.{module_name}")
self._neuronx_class = getattr(module, class_name, None)
if self._neuronx_class is None:
raise ImportError(
f"{class_name} not found in {module_name}. Please check transformers-neuronx version."
)

@property
def neuronx_class(self):
Expand Down