Add BART model tests (#191)
Added tests for the HF BART base and large model variants with a language
modeling head on top.
mrakitaTT authored Jan 27, 2025
1 parent ffce22a commit 35b2f4d
Showing 6 changed files with 123 additions and 0 deletions.
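
For reference, a minimal sketch of how the new tests could be run locally through pytest's Python API; it assumes pytest is installed and that the repository root is the working directory, with the test file paths taken from the diff below.

import pytest

# Hedged sketch: run the BART test modules added in this commit.
pytest.main(
    [
        "tests/jax/models/bart/base/test_bart_base.py",
        "tests/jax/models/bart/large/test_bart_large.py",
    ]
)
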
40 changes: 40 additions & 0 deletions tests/jax/models/bart/base/test_bart_base.py
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from infra import RunMode

from ..tester import FlaxBartForCausalLMTester

MODEL_PATH = "facebook/bart-base"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> FlaxBartForCausalLMTester:
return FlaxBartForCausalLMTester(MODEL_PATH)


@pytest.fixture
def training_tester() -> FlaxBartForCausalLMTester:
return FlaxBartForCausalLMTester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
def test_flax_bart_base_inference(
inference_tester: FlaxBartForCausalLMTester,
):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_bart_base_training(
training_tester: FlaxBartForCausalLMTester,
):
training_tester.test()
40 changes: 40 additions & 0 deletions tests/jax/models/bart/large/test_bart_large.py
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from infra import RunMode

from ..tester import FlaxBartForCausalLMTester

MODEL_PATH = "facebook/bart-large"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> FlaxBartForCausalLMTester:
return FlaxBartForCausalLMTester(MODEL_PATH)


@pytest.fixture
def training_tester() -> FlaxBartForCausalLMTester:
return FlaxBartForCausalLMTester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
def test_flax_bart_large_inference(
inference_tester: FlaxBartForCausalLMTester,
):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_bart_large_training(
training_tester: FlaxBartForCausalLMTester,
):
training_tester.test()
43 changes: 43 additions & 0 deletions tests/jax/models/bart/tester.py
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Sequence

import jax
from flax import linen as nn
from infra import ComparisonConfig, ModelTester, RunMode
from transformers import AutoTokenizer, FlaxBartForCausalLM


class FlaxBartForCausalLMTester(ModelTester):
"""Tester for BART model variants with a language modeling head on top."""

# TODO(mrakita): Add tests for other variants.

def __init__(
self,
model_name: str,
comparison_config: ComparisonConfig = ComparisonConfig(),
run_mode: RunMode = RunMode.INFERENCE,
) -> None:
self._model_name = model_name
super().__init__(comparison_config, run_mode)

# @override
def _get_model(self) -> nn.Module:
return FlaxBartForCausalLM.from_pretrained(self._model_name, from_pt=True)

# @override
def _get_input_activations(self) -> Sequence[jax.Array]:
tokenizer = AutoTokenizer.from_pretrained(self._model_name)
inputs = tokenizer("Hello", return_tensors="np")
return inputs["input_ids"]

# @override
def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
assert hasattr(self._model, "params")
return {
"params": self._model.params,
"input_ids": self._get_input_activations(),
}
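
For context, a minimal sketch of how the tester might be driven outside of the pytest fixtures above. The absolute import path is an assumption based on this commit's file layout, and test() is the entry point the committed tests call.

from infra import RunMode

# Assumed absolute import path for tests/jax/models/bart/tester.py.
from tests.jax.models.bart.tester import FlaxBartForCausalLMTester

# Inference run, mirroring the inference_tester fixture.
inference_tester = FlaxBartForCausalLMTester("facebook/bart-base")
inference_tester.test()

# Training run, mirroring the training_tester fixture; training support is
# not implemented yet, per the skip markers in the test files.
training_tester = FlaxBartForCausalLMTester("facebook/bart-base", RunMode.TRAINING)
training_tester.test()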
