[warnings] fix E721 warnings #32223

Merged: 1 commit, Jul 25, 2024
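E721 is the pycodestyle/flake8 check "do not compare types": it flags `type(x) == SomeClass` and asks for `isinstance()` for instance checks (subclasses included) or `is` / `is not` for exact type identity. A minimal sketch of the three forms, for illustration only (not code from this PR):

```python
# Minimal sketch (illustration only, not code from this PR).
class Base:
    pass


class Child(Base):
    pass


obj = Child()

# Flagged by E721: comparing types with `==`.
print(type(obj) == Base)        # False

# Exact-type check: identity comparison, the form E721 recommends.
print(type(obj) is Child)       # True
print(type(obj) is Base)        # False

# Instance check: also accepts subclasses.
print(isinstance(obj, Base))    # True
```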
2 changes: 1 addition & 1 deletion src/transformers/generation/candidate_generator.py
@@ -162,7 +162,7 @@ def __init__(
         self.generation_config.min_length = 0
         self.generation_config.min_new_tokens = None
         for processor in self.logits_processor:
-            if type(processor) == MinLengthLogitsProcessor:
+            if isinstance(processor, MinLengthLogitsProcessor):
                 raise ValueError(
                     "Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
                     "Please pass in `min_length` into `.generate()` instead"
2 changes: 1 addition & 1 deletion src/transformers/models/bart/modeling_flax_bart.py
@@ -1599,7 +1599,7 @@ def __call__(
         eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)

         # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
-        if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
+        if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer):
             if len(jnp.unique(eos_mask.sum(1))) > 1:
                 raise ValueError("All examples must have the same number of <eos> tokens.")
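For the Flax models, this guard skips the data-dependent `<eos>` check while JAX is tracing the function, since a tracer carries no concrete values to validate; the same change appears in modeling_flax_mbart.py below. A rough sketch of the pattern, under the assumption that a plain `jax.jit` call is enough to reproduce the tracing behaviour (this is not the modeling code itself):

```python
# Rough sketch (assumption: plain jax.jit reproduces the tracing behaviour;
# this is not the modeling_flax_bart code).
import jax
import jax.numpy as jnp
from jax.interpreters import partial_eval  # DynamicJaxprTracer lives here


def check_and_double(eos_mask):
    # Concrete arrays can be inspected; tracers under jit cannot, so skip the check.
    if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer):
        if len(jnp.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
    return eos_mask * 2


mask = jnp.array([[1, 0], [1, 0]])
check_and_double(mask)            # eager call: the validation runs
jax.jit(check_and_double)(mask)   # jitted call: eos_mask is a tracer, validation is skipped
```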

@@ -356,7 +356,7 @@ def test_chunk_size(chunk_size: int) -> bool:
     def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
         consistent = True
         for a1, a2 in zip(ac1, ac2):
-            assert type(ac1) == type(ac2)
+            assert type(ac1) is type(ac2)
             if isinstance(ac1, (list, tuple)):
                 consistent &= self._compare_arg_caches(a1, a2)
             elif isinstance(ac1, dict):
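Where both sides of the comparison are themselves `type(...)` calls, as in this assertion and in the `nested_concat` change further down, the E721-compliant form is an identity check rather than `isinstance`, because there is no literal class to test against and `isinstance` would also accept subclasses. A small illustration (not repository code):

```python
# Small illustration (not repository code): comparing two dynamically obtained types.
a, b = [1, 2], (1, 2)

print(type(a) is type(a))        # True: same concrete class
print(type(a) is type(b))        # False: list vs tuple


class MyList(list):
    pass


# isinstance() is the looser contract: subclasses pass it.
print(isinstance(MyList(), type(a)))   # True
print(type(MyList()) is type(a))       # False: the identity check is exact
```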
2 changes: 1 addition & 1 deletion src/transformers/models/mbart/modeling_flax_mbart.py
@@ -1635,7 +1635,7 @@ def __call__(
         eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)

         # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
-        if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
+        if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer):
             if len(jnp.unique(eos_mask.sum(1))) > 1:
                 raise ValueError("All examples must have the same number of <eos> tokens.")

2 changes: 1 addition & 1 deletion src/transformers/trainer_pt_utils.py
@@ -128,7 +128,7 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
     """
     if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)):
         assert (
-            type(tensors) == type(new_tensors)
+            type(tensors) is type(new_tensors)
         ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
2 changes: 1 addition & 1 deletion src/transformers/utils/chat_template_utils.py
@@ -74,7 +74,7 @@ def _parse_type_hint(hint: str) -> Dict:

     elif origin is Union:
         # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
-        subtypes = [_parse_type_hint(t) for t in args if t != type(None)]
+        subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
         if len(subtypes) == 1:
             # A single non-null type can be expressed directly
             return_dict = subtypes[0]
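`type(None)` is the `NoneType` class, and `typing.get_args` returns it as one of the members of an `Optional`/`Union` hint, so filtering it out with `is not` is an exact-type comparison rather than an equality test. A short sketch of that pattern (a simplified stand-in, not the actual `_parse_type_hint` logic):

```python
# Short sketch (simplified stand-in, not the actual _parse_type_hint logic).
from typing import Optional, get_args

hint = Optional[int]          # equivalent to Union[int, None]
args = get_args(hint)         # (int, NoneType)

# Keep every member of the Union except NoneType.
non_null = [t for t in args if t is not type(None)]
print(non_null)               # [<class 'int'>]
```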
2 changes: 1 addition & 1 deletion src/transformers/utils/generic.py
@@ -214,7 +214,7 @@ def _is_tf_symbolic_tensor(x):
     # the `is_symbolic_tensor` predicate is only available starting with TF 2.14
     if hasattr(tf, "is_symbolic_tensor"):
         return tf.is_symbolic_tensor(x)
-    return type(x) == tf.Tensor
+    return isinstance(x, tf.Tensor)


 def is_tf_symbolic_tensor(x):
4 changes: 2 additions & 2 deletions tests/models/ibert/test_modeling_ibert.py
@@ -684,10 +684,10 @@ def quantize(self, model):
         # Recursively convert all the `quant_mode` attributes as `True`
         if hasattr(model, "quant_mode"):
             model.quant_mode = True
-        elif type(model) == nn.Sequential:
+        elif isinstance(model, nn.Sequential):
             for n, m in model.named_children():
                 self.quantize(m)
-        elif type(model) == nn.ModuleList:
+        elif isinstance(model, nn.ModuleList):
             for n in model:
                 self.quantize(n)
         else: