Improve Makefile and code quality #62
GitHub Actions workflow:

```diff
@@ -7,7 +7,27 @@ on:
     branches: [ master ]

 jobs:

+  check_code_quality:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.7"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[dev]
+      - name: Check quality
+        run: |
+          black --check --line-length 119 --target-version py36 tests trl
+          isort --check-only tests trl
+          flake8 tests trl

   tests:
+    needs: check_code_quality
     strategy:
       matrix:
         python-version: [3.7, 3.8, 3.9]
```

Review comment (on the `black --check` line): If I'm not mistaken, you can just run […]

Reply: done
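The new `check_code_quality` job gates the test matrix, so a style violation fails fast before any tests run. For convenience, the same gate can be reproduced locally before pushing; below is a minimal sketch, assuming the dev tools are installed via `pip install .[dev]`. The `check_quality.py` helper name is hypothetical, not part of the repo (the Makefile's `make quality` target, shown in the next file, is the supported way):

```python
# check_quality.py -- hypothetical helper mirroring the CI "Check quality" step.
# Assumes black, isort and flake8 are on PATH (e.g. via `pip install .[dev]`).
import subprocess
import sys

CHECKS = [
    ["black", "--check", "--line-length", "119", "--target-version", "py36", "tests", "trl"],
    ["isort", "--check-only", "tests", "trl"],
    ["flake8", "tests", "trl"],
]

exit_code = 0
for cmd in CHECKS:
    print("$", " ".join(cmd))
    # Each tool exits non-zero when it finds violations; run all three and
    # report a combined failure, like the CI step does.
    if subprocess.run(cmd).returncode != 0:
        exit_code = 1

sys.exit(exit_code)
```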
Makefile:

```diff
@@ -1,33 +1,13 @@
-SRC = $(wildcard ./nbs//*.ipynb)
-
-all: trl docs
-
-trl: $(SRC)
-	nbdev_build_lib
-	touch trl
-
-docs_serve: docs
-	cd docs && bundle exec jekyll serve
-
-docs: $(SRC)
-	nbdev_build_docs
-	touch docs
+.PHONY: quality style test

 test:
-	pytest tests
+	python -m pytest -n auto --dist=loadfile -s -v ./tests/

-format:
-	black --line-length 119 --target-version py36 tests trl examples
-	isort tests trl examples
-
-release: pypi
-	nbdev_bump_version
-
-pypi: dist
-	twine upload --repository pypi dist/*
-
-dist: clean
-	python setup.py sdist bdist_wheel
+quality:
+	black --check --line-length 119 --target-version py36 tests trl
+	isort --check-only tests trl
+	flake8 tests trl

-clean:
-	rm -rf dist
+style:
+	black --line-length 119 --target-version py36 tests trl
+	isort tests trl
```
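The new `test` target relies on pytest-xdist: `-n auto` spawns one worker per CPU, and `--dist=loadfile` keeps every test from a given file on the same worker, so expensive module-level state (such as a loaded model) is set up once per file. A toy sketch of why that grouping matters; this is a hypothetical test module, not from the repo:

```python
# test_toy.py -- hypothetical module illustrating pytest-xdist's --dist=loadfile.
# Under `pytest -n auto --dist=loadfile`, all tests in this file are guaranteed
# to run on the same worker process, so module-level state stays consistent.
_cache = {}


def test_populates_cache():
    _cache["model"] = "loaded once per worker"


def test_sees_same_cache():
    # Safe under --dist=loadfile: runs on the same worker as the test above.
    assert "model" in _cache
```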
setup.cfg (new file):

```diff
@@ -0,0 +1,15 @@
+[metadata]
+license_file = LICENSE
+
+[isort]
+ensure_newline_before_comments = True
+force_grid_wrap = 0
+include_trailing_comma = True
+line_length = 119
+lines_after_imports = 2
+multi_line_output = 3
+use_parentheses = True
+
+[flake8]
+ignore = E203, E501, W503
+max-line-length = 119
```
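For reference, `multi_line_output = 3` combined with `include_trailing_comma`, `use_parentheses`, and `line_length = 119` makes isort emit the "vertical hanging indent" style that black also produces, which is why the two tools agree. A small illustration with hypothetical imports, not taken from the diff:

```python
# Hypothetical import block, formatted the way the [isort] settings above
# would emit it once a line exceeds line_length = 119
# (multi_line_output = 3 is the "vertical hanging indent" style).
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
)


# lines_after_imports = 2: exactly two blank lines before the first definition.
def main():
    pass
```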
Value-head model tests:

```diff
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import unittest
 import tempfile
-import torch
+import unittest

+import torch
 from transformers import AutoModelForCausalLM

 from trl import AutoModelForCausalLMWithValueHead
@@ -30,6 +30,7 @@
     "trl-internal-testing/tiny-random-GPT2LMHeadModel",
 ]

+
 class BaseModelTester:
     all_model_names = None
     trl_model_class = None
@@ -41,17 +42,17 @@ def test_from_save(self):
         for model_name in self.all_model_names:
             torch.manual_seed(0)
             model = self.trl_model_class.from_pretrained(model_name)

             with tempfile.TemporaryDirectory() as tmp_dir:
                 model.save_pretrained(tmp_dir)

             torch.manual_seed(0)
             model_from_save = self.trl_model_class.from_pretrained(tmp_dir)
-            # Check if the weights are the same

+            # Check if the weights are the same
             for key in model_from_save.state_dict():
                 self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]))

     def test_from_save_transformers(self):
         """
         Test if the model can be saved and loaded using transformers and get the same weights
@@ -60,19 +61,25 @@ def test_from_save_transformers(self):
             transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)

             trl_model = self.trl_model_class.from_pretrained(model_name)

             with tempfile.TemporaryDirectory() as tmp_dir:
                 trl_model.save_pretrained(tmp_dir)
                 transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained(tmp_dir)
-            # Check if the weights are the same

+            # Check if the weights are the same
             for key in transformers_model.state_dict():
-                self.assertTrue(torch.allclose(transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]))
+                self.assertTrue(
+                    torch.allclose(
+                        transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
+                    )
+                )


 class VHeadModelTester(BaseModelTester, unittest.TestCase):
     """
-    Testing suite for v-head models.
+    Testing suite for v-head models.
     """

     all_model_names = ALL_CAUSAL_LM_MODELS
     trl_model_class = AutoModelForCausalLMWithValueHead

@@ -83,25 +90,25 @@ def test_vhead(self):
         for model_name in self.all_model_names:
             model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
             self.assertTrue(hasattr(model, "v_head"))

     def test_vhead_nb_classes(self):
         r"""
         Test if the v-head has the correct shape
         """
         for model_name in self.all_model_names:
             model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
             self.assertTrue(model.v_head.summary.weight.shape[0] == 1)

     def test_vhead_init_random(self):
         r"""
-        Test if the v-head has been randomly initialized.
-        We can check that by making sure the bias is different
+        Test if the v-head has been randomly initialized.
+        We can check that by making sure the bias is different
         than zeros by default.
         """
         for model_name in self.all_model_names:
             model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
-            self.assertFalse(torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias)))
+            self.assertFalse(torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias)))

     def test_vhead_not_str(self):
         r"""
         Test if the v-head is added to the model succesfully
@@ -113,8 +120,8 @@ def test_vhead_not_str(self):

     def test_inference(self):
         r"""
-        Test if the model can be used for inference and outputs 3 values
-        - logits, loss, and value states
+        Test if the model can be used for inference and outputs 3 values
+        - logits, loss, and value states
         """
         EXPECTED_OUTPUT_SIZE = 3

@@ -123,10 +130,10 @@
         input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
         outputs = model(input_ids)

-        # Check if the outputs are of the right size - here
+        # Check if the outputs are of the right size - here
         # we always output 3 values - logits, loss, and value states
         self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE)

     def test_dropout_config(self):
         r"""
         Test if we instantiate a model by adding `summary_drop_prob` to the config
@@ -139,7 +146,7 @@ def test_dropout_config(self):

         # Check if v head of the model has the same dropout as the config
         self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob)

     def test_dropout_kwargs(self):
         r"""
         Test if we instantiate a model by adding `summary_drop_prob` to the config
@@ -157,10 +164,10 @@ def test_dropout_kwargs(self):

         # Check if v head of the model has the same dropout as the config
         self.assertEqual(model.v_head.dropout.p, 0.5)

     def test_generate(self):
         r"""
-        Test if `generate` works for every model
+        Test if `generate` works for every model
         """
         for model_name in self.all_model_names:
             model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

@@ -175,4 +182,4 @@ def test_raise_error_not_causallm(self):
         # This should raise a ValueError
         with self.assertRaises(ValueError):
             pretrained_model = AutoModelForCausalLM.from_pretrained(model_id)
-            _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer)
+            _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer)
```

Several of the paired `-`/`+` lines above (docstrings, comments, and the final line of the file) are identical except for removed trailing whitespace or a newline added at the end of the file.

Review comment (on `class VHeadModelTester`): not for this PR, but a small nit that being explicit with e.g. […]; the same nit applies to docstrings using […]

Reply: Addressed in #65 !

Review comment (on `def test_vhead_nb_classes`): would this be better? [suggested change elided]

Reply: I'd address this too in a separate PR.

Reply: Addressed in #65 !
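The shape and bias checks above assume the value head wraps a single-output linear layer behind `summary` and `dropout` attributes. A minimal sketch of that structure follows; the attribute names are taken from the tests, but this is not the actual trl implementation:

```python
import torch.nn as nn


class ValueHead(nn.Module):
    """Sketch of a v-head: maps each hidden state to one scalar value."""

    def __init__(self, hidden_size: int, summary_dropout_prob: float = 0.1):
        super().__init__()
        # test_dropout_config / test_dropout_kwargs compare this module's `p`
        # against the configured summary_dropout_prob.
        self.dropout = nn.Dropout(summary_dropout_prob)
        # out_features=1 is why test_vhead_nb_classes asserts
        # summary.weight.shape[0] == 1. nn.Linear's default initialization
        # draws the bias from a uniform distribution, so it is almost surely
        # non-zero, which is what test_vhead_init_random relies on.
        self.summary = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        return self.summary(self.dropout(hidden_states))
```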
trl/__init__.py:

```diff
@@ -1,3 +1,5 @@
+# flake8: noqa
+
 __version__ = "0.1.1"

-from .models import AutoModelForCausalLMWithValueHead
+from .models import AutoModelForCausalLMWithValueHead
```

(The final `-`/`+` pair only adds a missing newline at the end of the file.)
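The `# flake8: noqa` marker is needed because the import in `__init__.py` is a re-export: the name is never used inside the module, so flake8 would otherwise flag it as F401 (imported but unused). A sketch of the pattern, showing the narrower per-line alternative to the file-level marker:

```python
# trl/__init__.py -- re-export pattern. The import exists only so users can
# write `from trl import AutoModelForCausalLMWithValueHead`; without a noqa
# marker, flake8 reports it as F401 (imported but unused).
__version__ = "0.1.1"

# Per-line suppression, an alternative to the file-level `# flake8: noqa`:
from .models import AutoModelForCausalLMWithValueHead  # noqa: F401
```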
Review comment: see my comment below about considering v3.8 as the minimum version, since this one will be EOL in 6 months.

Reply: sounds good