Replace all deprecated `jax.ops` operations with jnp's `at` (huggingface#16078)

* Replace all deprecated `jax.ops` operations with jnp's `at`

* np to jnp scores

* suggested changes
sanchit-gandhi authored Mar 16, 2022
1 parent c2dc89b commit ee27b3d
Showing 13 changed files with 28 additions and 34 deletions.
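
For context, every change in this commit follows the same migration pattern: `jax.ops.index_update` (and the `jax.ops.index` helper) were deprecated in JAX and later removed in favor of the functional `.at[...]` property. A minimal standalone sketch of the before/after:

import jax.numpy as jnp

x = jnp.zeros((2, 3))
# before (deprecated, removed in modern JAX):
#   x = jax.ops.index_update(x, jax.ops.index[:, 0], 1.0)
x = x.at[:, 0].set(1.0)  # functional update: returns a new array, x is unchanged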
@@ -35,7 +35,7 @@ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")

emb = jnp.zeros((50264, model.config.hidden_size))
# update the first 50257 weights using pre-trained weights
-emb = jax.ops.index_update(emb, jax.ops.index[:50257, :], model.params["transformer"]["wte"]["embedding"])
+emb = emb.at[:50257, :].set(model.params["transformer"]["wte"]["embedding"])
params = model.params
params["transformer"]["wte"]["embedding"] = emb

19 changes: 7 additions & 12 deletions src/transformers/generation_flax_logits_process.py
Original file line number Diff line number Diff line change
@@ -142,10 +142,11 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) ->
score_mask = cumulative_probs < self.top_p

# include the token that is higher than top_p as well
-score_mask |= jax.ops.index_update(jnp.roll(score_mask, 1), jax.ops.index[:, 0], True)
+score_mask = jnp.roll(score_mask, 1)
+score_mask |= score_mask.at[:, 0].set(True)

# min tokens to keep
-score_mask = jax.ops.index_update(score_mask, jax.ops.index[:, : self.min_tokens_to_keep], True)
+score_mask = score_mask.at[:, : self.min_tokens_to_keep].set(True)

topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
@@ -184,7 +185,7 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) ->
topk_scores_flat = topk_scores.flatten()
topk_indices_flat = topk_indices.flatten() + shift

-next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
+next_scores_flat = next_scores_flat.at[topk_indices_flat].set(topk_scores_flat)
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
return next_scores

@@ -206,9 +207,7 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) ->

apply_penalty = 1 - jnp.bool_(cur_len - 1)

-scores = jnp.where(
-    apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.bos_token_id], 0), scores
-)
+scores = jnp.where(apply_penalty, new_scores.at[:, self.bos_token_id].set(0), scores)

return scores

@@ -233,9 +232,7 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) ->

apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)

-scores = jnp.where(
-    apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.eos_token_id], 0), scores
-)
+scores = jnp.where(apply_penalty, new_scores.at[:, self.eos_token_id].set(0), scores)

return scores

@@ -266,8 +263,6 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) ->
# create boolean flag to decide if min length penalty should be applied
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

-scores = jnp.where(
-    apply_penalty, jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf")), scores
-)
+scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)

return scores
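
A note on the `jnp.where(apply_penalty, ...)` pattern in the hunks above: these processors must stay traceable (generation can run under `jax.jit`), so `cur_len` is a traced value and a Python `if` cannot branch on it; the penalized scores are computed unconditionally and selected with `jnp.where`. A condensed sketch of what `FlaxForcedBOSTokenLogitsProcessor.__call__` does, written as a free function for illustration:

import jax.numpy as jnp

def force_bos_token(scores, cur_len, bos_token_id):
    # -inf everywhere except the forced token, which gets log-prob 0
    new_scores = jnp.full(scores.shape, -float("inf"))
    new_scores = new_scores.at[:, bos_token_id].set(0)
    # 1 at the first generation step (cur_len == 1), 0 afterwards;
    # arithmetic on traced values replaces a Python `if` under jit
    apply_penalty = 1 - jnp.bool_(cur_len - 1)
    return jnp.where(apply_penalty, new_scores, scores)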
2 changes: 1 addition & 1 deletion src/transformers/models/bart/modeling_flax_bart.py
@@ -922,7 +922,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxBartForSequenceClassificationModule
-input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
+input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = input_ids
decoder_attention_mask = jnp.ones_like(input_ids)
2 changes: 1 addition & 1 deletion src/transformers/models/big_bird/modeling_flax_big_bird.py
@@ -2124,7 +2124,7 @@ def __call__(
if token_type_ids is None:
token_type_ids = (~logits_mask).astype("i4")
logits_mask = jnp.expand_dims(logits_mask, axis=2)
-logits_mask = jax.ops.index_update(logits_mask, jax.ops.index[:, 0], False)
+logits_mask = logits_mask.at[:, 0].set(False)

# init input tensors if not passed
if token_type_ids is None:
2 changes: 1 addition & 1 deletion src/transformers/models/blenderbot/modeling_flax_blenderbot.py
@@ -897,7 +897,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule
-input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
+input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = input_ids
decoder_attention_mask = jnp.ones_like(input_ids)
2 changes: 1 addition & 1 deletion src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
@@ -895,7 +895,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule
-input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
+input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = input_ids
decoder_attention_mask = jnp.ones_like(input_ids)
4 changes: 2 additions & 2 deletions src/transformers/models/marian/modeling_flax_marian.py
@@ -892,7 +892,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxMarianForSequenceClassificationModule
-input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
+input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = input_ids
decoder_attention_mask = jnp.ones_like(input_ids)
@@ -1422,7 +1422,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_

def _adapt_logits_for_beam_search(self, logits):
"""This function enforces the padding token never to be generated."""
-logits = jax.ops.index_update(logits, jax.ops.index[:, :, self.config.pad_token_id], float("-inf"))
+logits = logits.at[:, :, self.config.pad_token_id].set(float("-inf"))
return logits

def prepare_inputs_for_generation(
2 changes: 1 addition & 1 deletion src/transformers/models/mbart/modeling_flax_mbart.py
@@ -960,7 +960,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxMBartForSequenceClassificationModule
-input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
+input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = input_ids
decoder_attention_mask = jnp.ones_like(input_ids)
8 changes: 4 additions & 4 deletions src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -964,9 +964,8 @@ def __call__(

# these two operations makes sure that all values
# before the output lengths indices are attended to
-attention_mask = jax.ops.index_update(
-    attention_mask, jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1], 1
-)
+idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1)
+attention_mask = attention_mask.at[idx].set(1)
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")

hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
@@ -1038,7 +1037,8 @@ def _get_feature_vector_attention_mask(

attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
# these two operations makes sure that all values before the output lengths idxs are attended to
-attention_mask = attention_mask.at[(jnp.arange(attention_mask.shape[0]), output_lengths - 1)].set(1)
+idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1)
+attention_mask = attention_mask.at[idx].set(1)
attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)

attention_mask = jnp.array(attention_mask, dtype=bool)
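
The two-step mask construction in these hunks is worth unpacking: setting a single 1 at index `output_length - 1` and then applying flip → cumsum → flip turns that one-hot marker into a prefix mask covering every position up to and including the last valid one. A toy run, with batch size, feature length, and lengths invented for illustration:

import jax.numpy as jnp

output_lengths = jnp.array([3, 5])
mask = jnp.zeros((2, 5))
mask = mask.at[jnp.arange(2), output_lengths - 1].set(1)  # one-hot at the last valid index
mask = jnp.flip(jnp.flip(mask, -1).cumsum(-1), -1).astype(bool)
# mask: [[ True,  True,  True, False, False],
#        [ True,  True,  True,  True,  True]]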
2 changes: 1 addition & 1 deletion src/transformers/models/xglm/modeling_flax_xglm.py
@@ -131,7 +131,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
+shifted_input_ids = shifted_input_ids.at[(..., 0)].set(decoder_start_token_id)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)

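
For reference, `shift_tokens_right` as patched above builds decoder inputs from labels: roll right, write the decoder start token into position 0, and map any -100 label-masking values back to the pad id. A quick trace of the same three steps, with made-up token ids and assumed pad id 1 and decoder start id 2:

import jax.numpy as jnp

labels = jnp.array([[319, 42, -100, -100]])
shifted = jnp.roll(labels, 1, axis=-1)            # [[-100, 319, 42, -100]]
shifted = shifted.at[(..., 0)].set(2)             # [[2, 319, 42, -100]]
shifted = jnp.where(shifted == -100, 1, shifted)  # [[2, 319, 42, 1]]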
4 changes: 2 additions & 2 deletions templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
@@ -1330,7 +1330,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
+shifted_input_ids = shifted_input_ids.at[(..., 0)].set(decoder_start_token_id)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)

@@ -2040,7 +2040,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
-input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
+input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = input_ids
decoder_attention_mask = jnp.ones_like(input_ids)
6 changes: 3 additions & 3 deletions tests/generation/test_generation_flax_logits_process.py
@@ -41,7 +41,7 @@
@require_flax
class LogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
-scores = np.ones((batch_size, length)) / length
+scores = jnp.ones((batch_size, length)) / length
return scores

def test_temperature_dist_warper(self):
@@ -51,8 +51,8 @@ def test_temperature_dist_warper(self):
scores = self._get_uniform_logits(batch_size=2, length=length)

# tweak scores to not be uniform anymore
-scores[1, 5] = (1 / length) + 0.1  # peak, 1st batch
-scores[1, 10] = (1 / length) - 0.4  # valley, 1st batch
+scores = scores.at[1, 5].set((1 / length) + 0.1)  # peak, 1st batch
+scores = scores.at[1, 10].set((1 / length) - 0.4)  # valley, 1st batch

# compute softmax
probs = jax.nn.softmax(scores, axis=-1)
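
The test change above is needed because `_get_uniform_logits` now returns a JAX array rather than NumPy: JAX arrays are immutable, so item assignment raises a TypeError and the functional form is required. For instance:

import jax.numpy as jnp

scores = jnp.ones((2, 20)) / 20
# scores[1, 5] = 0.15  -> TypeError: JAX arrays do not support item assignment
scores = scores.at[1, 5].set(0.15)  # functional update returns a new array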
7 changes: 3 additions & 4 deletions tests/generation/test_generation_flax_utils.py
@@ -24,7 +24,6 @@
if is_flax_available():
import os

-import jax
import jax.numpy as jnp
from jax import jit
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
@@ -219,7 +218,7 @@ def test_greedy_generate_attn_mask(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

# pad attention mask on the left
-attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
+attention_mask = attention_mask.at[(0, 0)].set(0)

config.do_sample = False
config.max_length = max_length
Expand All @@ -239,7 +238,7 @@ def test_sample_generate_attn_mask(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

# pad attention mask on the left
-attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
+attention_mask = attention_mask.at[(0, 0)].set(0)

config.do_sample = True
config.max_length = max_length
Expand All @@ -259,7 +258,7 @@ def test_beam_search_generate_attn_mask(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

# pad attention mask on the left
-attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
+attention_mask = attention_mask.at[(0, 0)].set(0)

config.num_beams = 2
config.max_length = max_length
