Skip to content

Commit 6eec980

Browse files
committed
Ruff check and format
Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
1 parent 0f29e27 commit 6eec980

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

QEfficient/transformers/embeddings/embedding_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def forward(
5757
output = self.base_model(input_ids, attention_mask, **kwargs)
5858
return self.pooling_fn(output[0], attention_mask)
5959

60+
6061
def validate_user_pooling_function(user_function):
6162
if not callable(user_function):
6263
raise TypeError("Provided pooling function is not callable.")
@@ -65,4 +66,4 @@ def validate_user_pooling_function(user_function):
6566
required_args = {"last_hidden_states", "attention_mask"}
6667
if not required_args.issubset(sig.parameters.keys()):
6768
raise ValueError(f"Pooling function must accept arguments: {required_args}")
68-
return user_function
69+
return user_function

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,9 @@ def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]:
502502
transformed = False
503503
if kwargs.get("pooling") is not None:
504504
pooling = kwargs["pooling"]
505-
pooling_method = POOLING_MAP[pooling] if isinstance(pooling,str) else validate_user_pooling_function(pooling)
505+
pooling_method = (
506+
POOLING_MAP[pooling] if isinstance(pooling, str) else validate_user_pooling_function(pooling)
507+
)
506508
model = PooledModel(model, pooling_method)
507509
warnings.warn(f"Pooling method {pooling.__name__} is applied to the model.")
508510
return model, transformed

examples/embedding_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313

1414
from QEfficient import QEFFAutoModel as AutoModel
1515

16+
1617
def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
1718
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
1819
last_hidden_states[input_mask_expanded == 0] = -1e9
1920
return torch.max(last_hidden_states, 1)[0]
2021

22+
2123
# Sentences we want sentence embeddings for
2224
sentences = "This is an example sentence"
2325

@@ -28,7 +30,7 @@ def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor)
2830
qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling=max_pooling)
2931

3032
# Here seq_len can be list seq_len or single int
31-
qeff_model.compile(num_cores=16, seq_len=[32,64])
33+
qeff_model.compile(num_cores=16, seq_len=[32, 64])
3234

3335
# Tokenize sentences
3436
encoded_input = tokenizer(sentences, return_tensors="pt")

0 commit comments

Comments
 (0)