
Commit: fix code style
plusbang committed Aug 12, 2024
1 parent 7121f3c commit df4be6f
Showing 4 changed files with 30 additions and 34 deletions.
python/llm/src/ipex_llm/transformers/npu_model.py (20 changes: 9 additions & 11 deletions)
@@ -162,18 +162,16 @@ def from_pretrained(cls,
ggml_tensor_qtype, FP4Params

if isinstance(model.lm_head, torch.nn.Linear):
-    new_linear = LowBitLinear(
-        model.lm_head.in_features,
-        model.lm_head.out_features,
-        ggml_tensor_qtype["sym_int4"],
-        False
-    )
+new_linear = LowBitLinear(model.lm_head.in_features,
+                          model.lm_head.out_features,
+                          ggml_tensor_qtype["sym_int4"],
+                          False)
paramsLowBit = FP4Params(data=model.lm_head.weight.data,
-    requires_grad=False,
-    quantized=False,
-    _shape=None,
-    qtype=ggml_tensor_qtype["sym_int4"],
-    in_features=model.lm_head.in_features).to("cpu")
+                         requires_grad=False,
+                         quantized=False,
+                         _shape=None,
+                         qtype=ggml_tensor_qtype["sym_int4"],
+                         in_features=model.lm_head.in_features).to("cpu")
new_linear._parameters['weight'] = paramsLowBit
model.lm_head = new_linear

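What this hunk re-indents is the low-bit lm_head swap: when model.lm_head is a plain nn.Linear, it is rebuilt as a LowBitLinear of the same shape, the floating-point weight is converted to sym_int4 FP4Params on CPU, and the attribute is reassigned. As a rough illustration of that swap pattern only, here is a self-contained PyTorch sketch; Int8Head and swap_lm_head are hypothetical stand-ins for this note, not ipex-llm's LowBitLinear/FP4Params API.

import torch
import torch.nn as nn

class Int8Head(nn.Module):
    """Toy low-bit stand-in: per-row symmetric int8 weight plus a float scale."""
    def __init__(self, weight):
        super().__init__()
        scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        self.register_buffer("qweight", torch.round(weight / scale).to(torch.int8))
        self.register_buffer("scale", scale)

    def forward(self, x):
        # Dequantize on the fly; real kernels keep the weight packed.
        return x @ (self.qweight.float() * self.scale).t()

def swap_lm_head(model):
    # Mirror the diff's flow: build the replacement from the old layer's
    # weight, then reassign model.lm_head (bias-free, as in the hunk above).
    if isinstance(model.lm_head, nn.Linear):
        model.lm_head = Int8Head(model.lm_head.weight.data)
    return model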
python/llm/src/ipex_llm/transformers/npu_models/kv.py (14 changes: 8 additions & 6 deletions)
@@ -25,8 +25,8 @@ def init_fused_kv_cache(batch_size, num_heads, head_dim, current_length, max_len
max_length, head_dim,
dtype=dtype, device=device)
value_cache_storage = torch.zeros(batch_size, num_heads,
-                               max_length, head_dim,
-                               dtype=dtype, device=device)
+                                 max_length, head_dim,
+                                 dtype=dtype, device=device)

key_cache = key_cache_storage.as_strided((batch_size, num_heads,
current_length, head_dim),
@@ -57,9 +57,9 @@ class DynamicFusedNormalCache(DynamicCache):
KV_ALLOC_BLOCK_LENGTH = 256

def __init__(self) -> None:
-self.key_cache: Dict[int, torch.Tensor] = {}
+self.key_cache: Dict[int, torch.Tensor] = {}
self.value_cache: Dict[int, torch.Tensor] = {}
-self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+self._seen_tokens = 0   # Used in `generate` to keep how many tokens the cache has seen

def update(
self,
@@ -85,7 +85,8 @@ def update(
# Update the cache
# if len(self.key_cache) <= layer_idx:
if layer_idx not in self.key_cache:
-max_len = max_seq_length if max_seq_length is not None else key_states.size(2) + self.KV_ALLOC_BLOCK_LENGTH
+max_len = max_seq_length if max_seq_length is not None else key_states.size(2) + \
+    self.KV_ALLOC_BLOCK_LENGTH
k_cache, v_cache = init_fused_kv_cache(
batch_size, num_heads, head_dim,
0, max_len,
Expand All @@ -107,7 +108,8 @@ def update(
return self.key_cache[layer_idx], self.value_cache[layer_idx]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
"""Returns the sequence length of the cached states.
A layer index can be optionally passed."""

for idx, layer in self.key_cache.items():
return layer.shape[-2]
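For orientation, the functions touched here implement a pre-allocated KV cache: init_fused_kv_cache reserves storage for max_length tokens up front, and DynamicFusedNormalCache.update creates that storage per layer (sized by KV_ALLOC_BLOCK_LENGTH) and appends new key/value states into it instead of concatenating tensors every step. A minimal sketch of the same idea, assuming the pre-allocated block never overflows and skipping the as_strided view bookkeeping of the real code:

import torch

class ToyFusedKVCache:
    KV_ALLOC_BLOCK_LENGTH = 256  # same growth block the diff uses for max_len

    def __init__(self):
        self.key_cache = {}    # layer_idx -> pre-allocated storage
        self.value_cache = {}
        self.seen = {}         # layer_idx -> tokens written so far

    def update(self, key_states, value_states, layer_idx):
        bsz, heads, new_len, head_dim = key_states.shape
        if layer_idx not in self.key_cache:
            max_len = new_len + self.KV_ALLOC_BLOCK_LENGTH
            shape = (bsz, heads, max_len, head_dim)
            self.key_cache[layer_idx] = key_states.new_zeros(shape)
            self.value_cache[layer_idx] = value_states.new_zeros(shape)
            self.seen[layer_idx] = 0
        start = self.seen[layer_idx]
        self.key_cache[layer_idx][:, :, start:start + new_len] = key_states
        self.value_cache[layer_idx][:, :, start:start + new_len] = value_states
        self.seen[layer_idx] = start + new_len
        end = self.seen[layer_idx]
        return (self.key_cache[layer_idx][:, :, :end],
                self.value_cache[layer_idx][:, :, :end])

    def get_seq_length(self, layer_idx=0):
        return self.seen.get(layer_idx, 0)

Each update call writes one step's key/value tensors for a layer and returns the filled prefix, which is what the attention computation consumes.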
python/llm/src/ipex_llm/transformers/npu_models/llama.py (18 changes: 7 additions & 11 deletions)
@@ -232,7 +232,7 @@ def llama_fused_model_forward(

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
cache_position, past_seen_tokens)

@@ -247,21 +247,17 @@ def llama_fused_model_forward(
seq_len = hidden_states.size(1)

if seq_len == 1:
-    # assert hasattr(self, "multi_decoder")
-    # multi_decoder = self.layers[(self.layer_end + 1) % num_layers]
layer_outputs = self.multi_decoder(hidden_states,
-    attention_mask=causal_mask,
-    position_ids=position_ids,
-    past_key_value=past_key_values,
-    output_attentions=output_attentions,
-    use_cache=use_cache,
-    cache_position=cache_position,)
+                                   attention_mask=causal_mask,
+                                   position_ids=position_ids,
+                                   past_key_value=past_key_values,
+                                   output_attentions=output_attentions,
+                                   use_cache=use_cache,
+                                   cache_position=cache_position,)
hidden_states = layer_outputs[0]
-
assert use_cache
next_decoder_cache = layer_outputs[1]
-
assert not output_attentions
else:
for decoder_layer in self.layers:
if output_hidden_states:
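This hunk is again indentation plus removal of commented-out lines, but the surrounding control flow is the interesting part: single-token decode steps (seq_len == 1) are routed through one fused self.multi_decoder call, while prefill still loops over self.layers. A condensed sketch of that dispatch, with the mask and cache plumbing of the real llama_fused_model_forward left out; run_decoder_stack is a name invented for this note.

import torch
import torch.nn as nn

def run_decoder_stack(hidden_states: torch.Tensor,
                      layers: nn.ModuleList,
                      multi_decoder: nn.Module) -> torch.Tensor:
    if hidden_states.size(1) == 1:
        # Decode step: one fused module covers every layer of this stage.
        hidden_states = multi_decoder(hidden_states)
    else:
        # Prefill: run the ordinary per-layer loop.
        for decoder_layer in layers:
            hidden_states = decoder_layer(hidden_states)
    return hidden_states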
@@ -276,7 +276,7 @@ def pipeline_parallel_generate(self,
bs = inputs_tensor.shape[0]
if model_kwargs.get("attention_mask", None) is None:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-    inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id)
+        inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id)
if self.config.is_encoder_decoder:
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=bs,
@@ -289,7 +289,7 @@ def pipeline_parallel_generate(self,
else:
input_ids = inputs_tensor if model_input_name == "input_ids" \
else model_kwargs.pop("input_ids")

local_rank = dist.get_rank()
pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
next_rank = (local_rank + 1) % self.pipeline_parallel_stages
@@ -325,7 +325,7 @@ def pipeline_parallel_generate(self,

if _input_ids is None:
_input_ids = input_ids

model_inputs = self.prepare_inputs_for_generation(output_ids, **model_kwargs)

tic = time.time()
@@ -360,8 +360,8 @@ def pipeline_parallel_generate(self,
output_ids = torch.cat([output_ids, next_ids], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
-    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
-)
+        outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+    )

# finished sentences should have their next token be a padding token
next_ids = next_ids.squeeze()
@@ -602,7 +602,7 @@ def glm4_conditional_generation_forward_lowmem(
hidden_states = transformer_outputs[0]
if return_last_logit:
hidden_states = hidden_states[:, -1:]

device = hidden_states.device
# ipex-llm change starts
if device.type == "xpu":
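A small aside on the pipeline_parallel_generate context above: pre_rank and next_rank form a ring over the pipeline stages, so the first stage receives from the last and the last sends back to the first. A tiny worked example, assuming four stages:

# With pipeline_parallel_stages = 4, each rank receives hidden states from
# pre_rank and sends its own output to next_rank, wrapping around at the ends.
pipeline_parallel_stages = 4  # assumed stage count for the example
for local_rank in range(pipeline_parallel_stages):
    pre_rank = (local_rank - 1) % pipeline_parallel_stages
    next_rank = (local_rank + 1) % pipeline_parallel_stages
    print(f"rank {local_rank}: recv from {pre_rank}, send to {next_rank}")
# rank 0: recv from 3, send to 1
# rank 1: recv from 0, send to 2
# rank 2: recv from 1, send to 3
# rank 3: recv from 2, send to 0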
