
Conversation


@sourcery-ai sourcery-ai bot commented Jul 17, 2023

Branch main refactored by Sourcery.

If you're happy with these changes, merge this Pull Request using the Squash and merge strategy.

See our documentation here.

Run Sourcery locally

Reduce the feedback loop during development by using the Sourcery editor plugin:

Review changes via command line

To manually merge these changes, make sure you're on the main branch, then run:

git fetch origin sourcery/main
git merge --ff-only FETCH_HEAD
git reset HEAD^

Help us improve this pull request!

@sourcery-ai sourcery-ai bot requested a review from admariner July 17, 2023 10:11

@sourcery-ai sourcery-ai bot left a comment


Due to GitHub API limits, only the first 60 comments can be shown.

-return (sample,)
-
-return UNet2DConditionOutput(sample=sample)
+return (sample, ) if not return_dict else UNet2DConditionOutput(sample=sample)

Function unet_forward_XTI refactored with the following changes:
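For readers skimming the diff, here is the collapsed-return pattern as a tiny standalone sketch (the dataclass below is a stand-in for the real UNet2DConditionOutput, not the diffusers class itself):

    from dataclasses import dataclass

    @dataclass
    class Output:
        sample: float  # stand-in for UNet2DConditionOutput.sample

    def forward(sample, return_dict=True):
        # Two early returns collapsed into one conditional expression,
        # mirroring the refactor above.
        return (sample,) if not return_dict else Output(sample=sample)

    assert forward(1.0, return_dict=False) == (1.0,)
    assert forward(1.0).sample == 1.0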

Comment on lines -118 to +121
-save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
+save_stable_diffusion_format = args.save_model_as.lower() in [
+    "ckpt",
+    "safetensors",
+]

Function train refactored with the following changes:
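The membership-test idiom in isolation (argument handling simplified; `save_model_as` is the only piece taken from the diff):

    def is_stable_diffusion_format(save_model_as: str) -> bool:
        # x == "ckpt" or x == "safetensors" rewritten as a single membership test
        return save_model_as.lower() in ["ckpt", "safetensors"]

    assert is_stable_diffusion_format("SafeTensors")
    assert not is_stable_diffusion_format("diffusers")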

Comment on lines -28 to +32
-    # self.lib = ct.cdll.LoadLibrary(binary_path)
-    self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
 else:
     print(f"CUDA SETUP: Loading binary {binary_path}...")
-    # self.lib = ct.cdll.LoadLibrary(binary_path)
-    self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
+# self.lib = ct.cdll.LoadLibrary(binary_path)
+self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$

Function CUDALibrary_Singleton.initialize refactored with the following changes:
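A reduced sketch of the hoisting idea, assuming the only difference between the two branches was the extra log line (the ctypes call is replaced by a string so the snippet runs anywhere):

    def initialize(binary_path, verbose: bool):
        if verbose:
            print(f"CUDA SETUP: Loading binary {binary_path}...")
        # The load call that previously appeared in both branches now runs once,
        # after the if/else; stand-in for ct.cdll.LoadLibrary(str(binary_path)).
        return f"loaded:{binary_path}"

    assert initialize("libbitsandbytes_cpu.so", verbose=False) == "loaded:libbitsandbytes_cpu.so"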

-# TODO: handle different compute capabilities; for now, take the max
-return ccs[-1]
-return None
+return ccs[-1] if ccs is not None else None

Function get_compute_capability refactored with the following changes:

This removes the following comments ( why? ):

# TODO: handle different compute capabilities; for now, take the max
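The merged return as a standalone check; as in the original, an empty (but non-None) list would still raise IndexError — only the separate return statements are combined:

    def last_or_none(ccs):
        # conditional expression replacing the two separate return statements
        return ccs[-1] if ccs is not None else None

    assert last_or_none(["7.5", "8.0", "8.6"]) == "8.6"
    assert last_or_none(None) is None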

Comment on lines -119 to -166

binary_name = "libbitsandbytes_cpu.so"
#if not torch.cuda.is_available():
    #print('No GPU detected. Loading CPU library...')
    #return binary_name

cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None:
    print(
        "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
    )
    return binary_name

print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
cuda_version_string = get_cuda_version(cuda, cudart_path)


if cc == '':
    print(
        "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
    )
    return binary_name

# 7.5 is the minimum CC vor cublaslt
has_cublaslt = cc in ["7.5", "8.0", "8.6"]

# TODO:
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
# (2) Multiple CUDA versions installed

# we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler
print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')

def get_binary_name():
    "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
    bin_base_name = "libbitsandbytes_cuda"
    if has_cublaslt:
        return f"{bin_base_name}{cuda_version_string}.so"
    else:
        return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"

binary_name = get_binary_name()

return binary_name

Function evaluate_cuda_setup refactored with the following changes:

This removes the following comments ( why? ):

# TODO:
#if not torch.cuda.is_available():
# we use ls -l instead of nvcc to determine the cuda version
# (2) Multiple CUDA versions installed
# 7.5 is the minimum CC vor cublaslt
#return binary_name
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
#print('No GPU detected. Loading CPU library...')
# since most installations will have the libcudart.so installed, but not the compiler
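For reference, the binary-name selection that the removed block performs, reduced to a pure function (the CUDA probing is stubbed out; the version strings in the asserts are illustrative):

    def pick_binary_name(cc: str, cuda_version_string: str) -> str:
        # 7.5 is the minimum compute capability for cublasLt,
        # so older GPUs fall back to the _nocublaslt build.
        has_cublaslt = cc in ["7.5", "8.0", "8.6"]
        bin_base_name = "libbitsandbytes_cuda"
        if has_cublaslt:
            return f"{bin_base_name}{cuda_version_string}.so"
        return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"

    assert pick_binary_name("8.6", "117") == "libbitsandbytes_cuda117.so"
    assert pick_binary_name("6.1", "117") == "libbitsandbytes_cuda117_nocublaslt.so"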

Comment on lines -109 to +123
image_embeds = self.visual_encoder(image)
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)

text.input_ids[:,0] = self.tokenizer.bos_token_id
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)

decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
decoder_targets[:,:self.prompt_length] = -100

decoder_output = self.text_decoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
labels = decoder_targets,
return_dict = True,
)
-loss_lm = decoder_output.loss
-
-return loss_lm
+return decoder_output.loss

Function BLIP_Decoder.forward refactored with the following changes:
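The final change in that diff is just inlining an immediately-returned variable; in miniature (the class below only mimics the shape of the HuggingFace output object):

    class DecoderOutput:
        def __init__(self, loss):
            self.loss = loss

    def language_modeling_loss(decoder_output):
        # was: loss_lm = decoder_output.loss; return loss_lm
        return decoder_output.loss

    assert language_modeling_loss(DecoderOutput(0.42)) == 0.42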

Comment on lines -225 to +233

state_dict = checkpoint['model']
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)

state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
model.visual_encoder_m)
model.visual_encoder_m)
for key in model.state_dict().keys():
if key in state_dict.keys():
if state_dict[key].shape!=model.state_dict()[key].shape:
del state_dict[key]

msg = model.load_state_dict(state_dict,strict=False)
-print('load checkpoint from %s'%url_or_filename)
+print(f'load checkpoint from {url_or_filename}')

Function load_checkpoint refactored with the following changes:
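The print change is purely cosmetic; %-formatting and the f-string render the same text:

    url_or_filename = "checkpoints/model_base.pth"  # illustrative value
    assert ('load checkpoint from %s' % url_or_filename) == f'load checkpoint from {url_or_filename}'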

"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)


Function BertSelfAttention.__init__ refactored with the following changes:

attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

-if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+if self.position_embedding_type in ["relative_key", "relative_key_query"]:

Function BertSelfAttention.forward refactored with the following changes:

attention_output = self.output(self_outputs[0], hidden_states)
-outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
-return outputs
+return (attention_output,) + self_outputs[1:]

Function BertAttention.forward refactored with the following changes:

This removes the following comments ( why? ):

# add attentions if we output them

Comment on lines -382 to +381
-layer_output = self.output(intermediate_output, attention_output)
-return layer_output
+return self.output(intermediate_output, attention_output)

Function BertLayer.feed_forward_chunk refactored with the following changes:

Comment on lines -544 to +542
-prediction_scores = self.predictions(sequence_output)
-return prediction_scores
+return self.predictions(sequence_output)

Function BertOnlyMLMHead.forward refactored with the following changes:

Comment on lines -640 to +637


Function BertModel.get_extended_attention_mask refactored with the following changes:

for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
-mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+mha_prefix = f'{block_prefix}MultiHeadDotProductAttention_1/'

Function _load_weights refactored with the following changes:

Comment on lines -282 to -283
# interpolate position embedding
embedding_size = pos_embed_checkpoint.shape[-1]

Function interpolate_pos_embed refactored with the following changes:

Comment on lines -520 to +528
-for module in self.unet.modules():
-    if (
-        hasattr(module, "_hf_hook")
-        and hasattr(module._hf_hook, "execution_device")
-        and module._hf_hook.execution_device is not None
-    ):
-        return torch.device(module._hf_hook.execution_device)
-return self.device
+return next(
+    (
+        torch.device(module._hf_hook.execution_device)
+        for module in self.unet.modules()
+        if (
+            hasattr(module, "_hf_hook")
+            and hasattr(module._hf_hook, "execution_device")
+            and module._hf_hook.execution_device is not None
+        )
+    ),
+    self.device,
+)

Function StableDiffusionLongPromptWeightingPipeline._execution_device refactored with the following changes:

  • Use the built-in function next instead of a for-loop (use-next)
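A self-contained sketch of the use-next pattern, with a plain attribute standing in for the accelerate `_hf_hook.execution_device` chain and a list standing in for `self.unet.modules()`:

    class FakeModule:
        def __init__(self, execution_device=None):
            self.execution_device = execution_device

    def execution_device(modules, default="cpu"):
        # next() yields the first matching element of the generator,
        # or the supplied default when no module carries a device.
        return next(
            (m.execution_device for m in modules if m.execution_device is not None),
            default,
        )

    assert execution_device([FakeModule(), FakeModule("cuda:0"), FakeModule()]) == "cuda:0"
    assert execution_device([FakeModule(), FakeModule()]) == "cpu"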

Comment on lines -599 to +603
-if (callback_steps is None) or (
-    callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+if (
+    callback_steps is None
+    or not isinstance(callback_steps, int)
+    or callback_steps <= 0

Function StableDiffusionLongPromptWeightingPipeline.check_inputs refactored with the following changes:
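The simplification rests on the identity `A or (not A and B)` == `A or B`; a quick check over a few representative inputs:

    def old_check(callback_steps):
        return (callback_steps is None) or (
            callback_steps is not None
            and (not isinstance(callback_steps, int) or callback_steps <= 0)
        )

    def new_check(callback_steps):
        return (
            callback_steps is None
            or not isinstance(callback_steps, int)
            or callback_steps <= 0
        )

    for value in (None, -3, 0, 5, 2.5, "ten"):
        assert old_check(value) == new_check(value)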

Comment on lines -609 to +619
-else:
-    # get the original timestep using init_timestep
-    offset = self.scheduler.config.get("steps_offset", 0)
-    init_timestep = int(num_inference_steps * strength) + offset
-    init_timestep = min(init_timestep, num_inference_steps)
-
-    t_start = max(num_inference_steps - init_timestep + offset, 0)
-    timesteps = self.scheduler.timesteps[t_start:].to(device)
-    return timesteps, num_inference_steps - t_start
+# get the original timestep using init_timestep
+offset = self.scheduler.config.get("steps_offset", 0)
+init_timestep = int(num_inference_steps * strength) + offset
+init_timestep = min(init_timestep, num_inference_steps)
+
+t_start = max(num_inference_steps - init_timestep + offset, 0)
+timesteps = self.scheduler.timesteps[t_start:].to(device)
+return timesteps, num_inference_steps - t_start

Function StableDiffusionLongPromptWeightingPipeline.get_timesteps refactored with the following changes:

Comment on lines -667 to +673
-else:
-    if latents.shape != shape:
-        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+elif latents.shape == shape:
     latents = latents.to(device)
+
+else:
+    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")

Function StableDiffusionLongPromptWeightingPipeline.prepare_latents refactored with the following changes:
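A reduced sketch of the reordered branches, with the shape check as the elif and the error in the final else (the first branch is only a placeholder for the latent-sampling path that is not shown in the diff):

    import numpy as np

    def prepare_latents(latents, shape):
        if latents is None:
            return "sample new latents"      # placeholder for the generation branch
        elif latents.shape == shape:
            return "reuse provided latents"  # happy path moved into the elif
        else:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")

    assert prepare_latents(None, (1, 4)) == "sample new latents"
    assert prepare_latents(np.zeros((1, 4)), (1, 4)) == "reuse provided latents"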


mapping.append({"old": old_item, "new": new_item})

mapping.append({"old": new_item, "new": new_item})

Function renew_attention_paths refactored with the following changes:

This removes the following comments ( why? ):

#         new_item = new_item.replace('norm.bias', 'group_norm.bias')
#         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
#         new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
#         new_item = new_item.replace('norm.weight', 'group_norm.weight')
#         new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

Comment on lines -226 to +239
-# extract state_dict for UNet
-unet_state_dict = {}
 unet_key = "model.diffusion_model."
 keys = list(checkpoint.keys())
-for key in keys:
-    if key.startswith(unet_key):
-        unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
-
-new_checkpoint = {}
-
-new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
-new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
-new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
-new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
-
-new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
-new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
-
-new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
-new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
-new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
-new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+unet_state_dict = {
+    key.replace(unet_key, ""): checkpoint.pop(key)
+    for key in keys
+    if key.startswith(unet_key)
+}
+new_checkpoint = {
+    "time_embedding.linear_1.weight": unet_state_dict[
+        "time_embed.0.weight"
+    ],
+    "time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"],
+    "time_embedding.linear_2.weight": unet_state_dict[
+        "time_embed.2.weight"
+    ],
+    "time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"],
+    "conv_in.weight": unet_state_dict["input_blocks.0.0.weight"],
+    "conv_in.bias": unet_state_dict["input_blocks.0.0.bias"],
+    "conv_norm_out.weight": unet_state_dict["out.0.weight"],
+    "conv_norm_out.bias": unet_state_dict["out.0.bias"],
+    "conv_out.weight": unet_state_dict["out.2.weight"],
+    "conv_out.bias": unet_state_dict["out.2.bias"],
+}

Function convert_ldm_unet_checkpoint refactored with the following changes:

This removes the following comments ( why? ):

# extract state_dict for UNet
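The loop-to-comprehension change in isolation; because `keys` is materialised with `list()` before iterating, calling `checkpoint.pop(key)` inside the comprehension is safe (tensor values are replaced by strings here):

    checkpoint = {
        "model.diffusion_model.time_embed.0.weight": "w",
        "model.diffusion_model.time_embed.0.bias": "b",
        "first_stage_model.encoder.conv_in.weight": "vae-w",
    }
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())

    unet_state_dict = {
        key.replace(unet_key, ""): checkpoint.pop(key)
        for key in keys
        if key.startswith(unet_key)
    }

    assert unet_state_dict == {"time_embed.0.weight": "w", "time_embed.0.bias": "b"}
    assert list(checkpoint) == ["first_stage_model.encoder.conv_in.weight"]  # UNet keys were popped out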

Comment on lines -371 to +395
-# extract state dict for VAE
-vae_state_dict = {}
 vae_key = "first_stage_model."
 keys = list(checkpoint.keys())
-for key in keys:
-    if key.startswith(vae_key):
-        vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+vae_state_dict = {
+    key.replace(vae_key, ""): checkpoint.get(key)
+    for key in keys
+    if key.startswith(vae_key)
+}
 # if len(vae_state_dict) == 0:
 #     # the checkpoint passed in is a VAE state_dict, not a checkpoint loaded from a .ckpt
 #     vae_state_dict = checkpoint

-new_checkpoint = {}
-
-new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
-new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
-new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
-new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
-new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
-new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
-
-new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
-new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
-new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
-new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
-new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
-new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
-
-new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
-new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
-new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
-new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+new_checkpoint = {
+    "encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
+    "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
+    "encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"],
+    "encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"],
+    "encoder.conv_norm_out.weight": vae_state_dict[
+        "encoder.norm_out.weight"
+    ],
+    "encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"],
+    "decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"],
+    "decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"],
+    "decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"],
+    "decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"],
+    "decoder.conv_norm_out.weight": vae_state_dict[
+        "decoder.norm_out.weight"
+    ],
+    "decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"],
+    "quant_conv.weight": vae_state_dict["quant_conv.weight"],
+    "quant_conv.bias": vae_state_dict["quant_conv.bias"],
+    "post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"],
+    "post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"],
+}

Function convert_ldm_vae_checkpoint refactored with the following changes:

This removes the following comments ( why? ):

# extract state dict for VAE

Comment on lines -491 to +485
-for i in range(len(block_out_channels)):
+for _ in block_out_channels:

Function create_unet_diffusers_config refactored with the following changes:

Comment on lines -524 to +518
-config = dict(
+return dict(

Function create_vae_diffusers_config refactored with the following changes:

Comment on lines -539 to +536
-text_model_dict = {}
-for key in keys:
-    if key.startswith("cond_stage_model.transformer"):
-        text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
-return text_model_dict
+return {
+    key[len("cond_stage_model.transformer.") :]: checkpoint[key]
+    for key in keys
+    if key.startswith("cond_stage_model.transformer")
+}

Function convert_ldm_clip_checkpoint_v1 refactored with the following changes:

-if reso in self.predefined_resos_set:
-    pass
-else:
+if reso not in self.predefined_resos_set:

Function BucketManager.select_bucket refactored with the following changes:

Comment on lines -352 to +350
self.caption_extension = "." + self.caption_extension
self.caption_extension = f".{self.caption_extension}"

Function DreamBoothSubset.__init__ refactored with the following changes:

Comment on lines -457 to +455
-if not self.current_epoch == epoch:  # shuffle buckets when the epoch changes
+if self.current_epoch != epoch:  # shuffle buckets when the epoch changes

Function BaseDataset.set_current_epoch refactored with the following changes:

  • Simplify logical expression using De Morgan identities (de-morgan)
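The rewrite is the `not (a == b)` to `a != b` identity; for the integer epoch counters used here the two spellings always agree:

    for current_epoch, epoch in [(2, 3), (3, 3)]:
        assert (not current_epoch == epoch) == (current_epoch != epoch)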

Comment on lines -472 to +470
-tag = tag.strip()
-if tag:
+if tag := tag.strip():

Function BaseDataset.set_tag_frequency refactored with the following changes:
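A standalone sketch of the assignment-expression form (Python 3.8+); the surrounding tag-frequency bookkeeping is simplified and the caption value is illustrative:

    def tag_frequencies(caption: str) -> dict:
        freq = {}
        for tag in caption.split(","):
            # assign and test in one step: blank tags are skipped,
            # matching the previous strip-then-check sequence
            if tag := tag.strip():
                freq[tag] = freq.get(tag, 0) + 1
        return freq

    assert tag_frequencies("1girl, solo, , smile") == {"1girl": 1, "solo": 1, "smile": 1}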

Comment on lines -537 to +534
-if type(str_to) == list:
-    caption = random.choice(str_to)
-else:
-    caption = str_to
+caption = random.choice(str_to) if type(str_to) == list else str_to

Function BaseDataset.process_caption refactored with the following changes:
