Skip to content

Commit f5e6868

Browse files
committed
fix doc
1 parent c6bee00 commit f5e6868

File tree

5 files changed

+139
-15
lines changed

5 files changed

+139
-15
lines changed

keras_hub/src/layers/preprocessing/multi_segment_packer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,10 @@ def call(
281281
sequence_length=None,
282282
add_start_value=True,
283283
add_end_value=True,
284+
padding_side=None,
284285
):
285286
inputs, unbatched = self._sanitize_inputs(inputs)
286-
287+
padding_side = padding_side or self.padding_side
287288
segments = self._trim_inputs(inputs)
288289
token_ids, segment_ids = self._combine_inputs(
289290
segments,
@@ -296,13 +297,13 @@ def call(
296297
token_ids = pad(
297298
token_ids,
298299
shape=shape,
299-
padding_side=self.padding_side,
300+
padding_side=padding_side,
300301
pad_value=self.pad_value,
301302
)
302303
segment_ids = pad(
303304
segment_ids,
304305
shape=shape,
305-
padding_side=self.padding_side,
306+
padding_side=padding_side,
306307
pad_value=0,
307308
)
308309
# Remove the batch dim if added.

keras_hub/src/layers/preprocessing/start_end_packer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,12 @@ def call(
152152
sequence_length=None,
153153
add_start_value=True,
154154
add_end_value=True,
155+
padding_side=None,
155156
):
156157
inputs, unbatched, rectangular = convert_to_ragged_batch(inputs)
157158
x = inputs # Intermediate result.
158-
159159
batch_size = tf.shape(x)[0]
160+
padding_side = padding_side or self.padding_side
160161
sequence_length = sequence_length or self.sequence_length
161162
dtype = inputs.dtype
162163
# Truncate.
@@ -185,7 +186,7 @@ def call(
185186
outputs = pad(
186187
x,
187188
pad_value=self.pad_value,
188-
padding_side=self.padding_side,
189+
padding_side=padding_side,
189190
shape=(batch_size, sequence_length),
190191
)
191192
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
@@ -196,7 +197,7 @@ def call(
196197
mask = pad(
197198
mask,
198199
pad_value=False,
199-
padding_side=self.padding_side,
200+
padding_side=padding_side,
200201
shape=(batch_size, sequence_length),
201202
)
202203
mask = tf.squeeze(mask, axis=0) if unbatched else mask

keras_hub/src/models/causal_lm_preprocessor.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
from keras_hub.src.utils.tensor_utils import preprocessing_function
77
from keras_hub.src.utils.tensor_utils import strip_to_ragged
88

9+
try:
10+
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
11+
except ImportError:
12+
dynamic_update_slice = None
13+
914

1015
@keras_hub_export("keras_hub.models.CausalLMPreprocessor")
1116
class CausalLMPreprocessor(Preprocessor):
@@ -64,6 +69,7 @@ def __init__(
6469
sequence_length=1024,
6570
add_start_token=True,
6671
add_end_token=True,
72+
padding_side="right",
6773
**kwargs,
6874
):
6975
super().__init__(**kwargs)
@@ -72,6 +78,8 @@ def __init__(
7278
self.sequence_length = sequence_length
7379
self.add_start_token = add_start_token
7480
self.add_end_token = add_end_token
81+
assert padding_side in ["right", "left"]
82+
self.padding_side = padding_side
7583

7684
def build(self, input_shape):
7785
# Defer packer creation to `build()` so that we can be sure tokenizer
@@ -82,6 +90,7 @@ def build(self, input_shape):
8290
pad_value=self.tokenizer.pad_token_id,
8391
sequence_length=self.sequence_length,
8492
return_padding_mask=True,
93+
padding_side=self.padding_side,
8594
)
8695
self.built = True
8796

@@ -92,16 +101,38 @@ def call(
92101
y=None,
93102
sample_weight=None,
94103
sequence_length=None,
104+
padding_side=None,
95105
):
96-
sequence_length = sequence_length or self.sequence_length
97106
x = self.tokenizer(x)
98-
# Pad with one extra token to account for the truncation below.
99-
token_ids, padding_mask = self.packer(
100-
x,
101-
sequence_length=sequence_length + 1,
102-
add_start_value=self.add_start_token,
103-
add_end_value=self.add_end_token,
104-
)
107+
padding_side = padding_side or self.padding_side
108+
sequence_length = sequence_length or self.sequence_length
109+
if padding_side == "left":
110+
token_ids, padding_mask = self.packer(
111+
x,
112+
sequence_length=x.to_tensor().shape[-1] + 2,
113+
add_start_value=self.add_start_token,
114+
add_end_value=self.add_end_token,
115+
padding_side=padding_side,
116+
)
117+
token_ids, all_padding_mask = self.packer(
118+
token_ids,
119+
sequence_length=sequence_length + 1,
120+
add_start_value=False,
121+
add_end_value=False,
122+
padding_side="right",
123+
)
124+
padding_mask = dynamic_update_slice(
125+
all_padding_mask, padding_mask, [0] * len(padding_mask.shape)
126+
)
127+
else:
128+
# Pad with one extra token to account for the truncation below.
129+
token_ids, padding_mask = self.packer(
130+
x,
131+
sequence_length=sequence_length + 1,
132+
add_start_value=self.add_start_token,
133+
add_end_value=self.add_end_token,
134+
padding_side=padding_side,
135+
)
105136
# The last token does not have a next token, so we truncate it out.
106137
x = {
107138
"token_ids": token_ids[..., :-1],
@@ -166,6 +197,7 @@ def get_config(self):
166197
"sequence_length": self.sequence_length,
167198
"add_start_token": self.add_start_token,
168199
"add_end_token": self.add_end_token,
200+
"padding_side": self.padding_side,
169201
}
170202
)
171203
return config

keras_hub/src/models/causal_lm_preprocessor_test.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,97 @@ def test_preset_accessors(self):
1717
self.assertTrue(bert_presets.isdisjoint(all_presets))
1818
self.assertTrue(gpt2_presets.issubset(all_presets))
1919

20+
def test_padding_side_call(self):
21+
preprocessor = CausalLMPreprocessor.from_preset(
22+
"gpt2_base_en", sequence_length=7
23+
)
24+
# left pad
25+
outputs = preprocessor(
26+
["i love you", "this is keras hub"], padding_side="left"
27+
)
28+
self.assertAllClose(
29+
outputs,
30+
(
31+
{
32+
"token_ids": [
33+
[0, 0, 50256, 72, 1842, 345, 50256],
34+
[50256, 5661, 318, 41927, 292, 12575, 50256],
35+
],
36+
"padding_mask": [
37+
[
38+
False,
39+
False,
40+
True,
41+
True,
42+
True,
43+
True,
44+
True,
45+
],
46+
[
47+
True,
48+
True,
49+
True,
50+
True,
51+
True,
52+
True,
53+
True,
54+
],
55+
],
56+
},
57+
[
58+
[0, 50256, 72, 1842, 345, 50256, 0],
59+
[5661, 318, 41927, 292, 12575, 50256, 0],
60+
],
61+
[
62+
[False, True, True, True, True, True, False],
63+
[True, True, True, True, True, True, False],
64+
],
65+
),
66+
)
67+
# right pad
68+
outputs = preprocessor(
69+
["i love you", "this is keras hub"], padding_side="right"
70+
)
71+
self.assertAllClose(
72+
outputs,
73+
(
74+
{
75+
"token_ids": [
76+
[50256, 72, 1842, 345, 50256, 0, 0],
77+
[50256, 5661, 318, 41927, 292, 12575, 50256],
78+
],
79+
"padding_mask": [
80+
[
81+
True,
82+
True,
83+
True,
84+
True,
85+
True,
86+
False,
87+
False,
88+
],
89+
[
90+
True,
91+
True,
92+
True,
93+
True,
94+
True,
95+
True,
96+
True,
97+
],
98+
],
99+
},
100+
[
101+
[72, 1842, 345, 50256, 0, 0, 0],
102+
[5661, 318, 41927, 292, 12575, 50256, 0],
103+
],
104+
[
105+
[True, True, True, True, False, False, False],
106+
[True, True, True, True, True, True, False],
107+
],
108+
),
109+
)
110+
20111
@pytest.mark.large
21112
def test_from_preset(self):
22113
self.assertIsInstance(

keras_hub/src/models/esm/esm_backbone.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ class ESMBackbone(Backbone):
7676
num_heads=4,
7777
hidden_dim=256,
7878
intermediate_dim=512,
79-
num_heads = 4,
8079
)
8180
model(input_data)
8281
```

0 commit comments

Comments (0)