Sourcery refactored main branch #1
base: main
Conversation
```diff
-        return (sample,)
-
-    return UNet2DConditionOutput(sample=sample)
+    return (sample, ) if not return_dict else UNet2DConditionOutput(sample=sample)
```
Function `unet_forward_XTI` refactored with the following changes:

- Lift code into else after jump in control flow (`reintroduce-else`)
- Replace if statement with if expression (`assign-if-exp`)
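For readers unfamiliar with these Sourcery rules, a minimal sketch of the combined pattern (hypothetical names; a plain dict stands in for the diffusers output class):

```python
def wrap(sample, return_dict):
    # before: a guarded early return followed by a second return
    #   if not return_dict:
    #       return (sample,)
    #   return {"sample": sample}
    # after assign-if-exp: one conditional expression covers both paths
    return (sample,) if not return_dict else {"sample": sample}

print(wrap(1, return_dict=False))  # (1,)
```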
```diff
-    save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
+    save_stable_diffusion_format = args.save_model_as.lower() in [
+        "ckpt",
+        "safetensors",
+    ]
```
Function `train` refactored with the following changes:

- Replace multiple comparisons of same variable with `in` operator (`merge-comparisons`)
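A minimal sketch of `merge-comparisons`, using a hypothetical `fmt` value:

```python
fmt = "safetensors"

# before: the same variable compared twice with `or`
if fmt == "ckpt" or fmt == "safetensors":
    print("stable diffusion format")

# after: a single membership test
if fmt in ("ckpt", "safetensors"):
    print("stable diffusion format")
```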
```diff
-            # self.lib = ct.cdll.LoadLibrary(binary_path)
-            self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
         else:
             print(f"CUDA SETUP: Loading binary {binary_path}...")
-            # self.lib = ct.cdll.LoadLibrary(binary_path)
-            self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
+        # self.lib = ct.cdll.LoadLibrary(binary_path)
+        self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
```
Function `CUDALibrary_Singleton.initialize` refactored with the following changes:

- Hoist repeated code outside conditional statement (`hoist-statement-from-if`)
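A minimal sketch of `hoist-statement-from-if`; `open_library` is a hypothetical stand-in for `ct.cdll.LoadLibrary`:

```python
def open_library(path):
    # stand-in so the sketch runs without a real shared library
    return f"<handle for {path}>"

def initialize(path, verbose):
    # before: the load call was duplicated in both the if and else branches
    # after: only the branch-specific print stays conditional; the shared
    # statement runs once below the conditional
    if verbose:
        print(f"CUDA SETUP: Loading binary {path}...")
    return open_library(str(path))

print(initialize("libbitsandbytes.so", verbose=True))
```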
```diff
-        # TODO: handle different compute capabilities; for now, take the max
-        return ccs[-1]
-    return None
+    return ccs[-1] if ccs is not None else None
```
Function `get_compute_capability` refactored with the following changes:

- Lift code into else after jump in control flow (`reintroduce-else`)
- Replace if statement with if expression (`assign-if-exp`)

This removes the following comments (why?):

# TODO: handle different compute capabilities; for now, take the max
```diff
-    binary_name = "libbitsandbytes_cpu.so"
-    #if not torch.cuda.is_available():
-        #print('No GPU detected. Loading CPU library...')
-        #return binary_name
-
-    cudart_path = determine_cuda_runtime_lib_path()
-    if cudart_path is None:
-        print(
-            "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
-        )
-        return binary_name
-
-    print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
-    cuda = get_cuda_lib_handle()
-    cc = get_compute_capability(cuda)
-    print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
-    cuda_version_string = get_cuda_version(cuda, cudart_path)
-
-    if cc == '':
-        print(
-            "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
-        )
-        return binary_name
-
-    # 7.5 is the minimum CC vor cublaslt
-    has_cublaslt = cc in ["7.5", "8.0", "8.6"]
-
-    # TODO:
-    # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
-    # (2) Multiple CUDA versions installed
-
-    # we use ls -l instead of nvcc to determine the cuda version
-    # since most installations will have the libcudart.so installed, but not the compiler
-    print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
-
-    def get_binary_name():
-        "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
-        bin_base_name = "libbitsandbytes_cuda"
-        if has_cublaslt:
-            return f"{bin_base_name}{cuda_version_string}.so"
-        else:
-            return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
-
-    binary_name = get_binary_name()
-
-    return binary_name
```
Function `evaluate_cuda_setup` refactored with the following changes:

- Remove unreachable code (`remove-unreachable-code`)

This removes the following comments (why?):

# TODO:
#if not torch.cuda.is_available():
# we use ls -l instead of nvcc to determine the cuda version
# (2) Multiple CUDA versions installed
# 7.5 is the minimum CC vor cublaslt
#return binary_name
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
#print('No GPU detected. Loading CPU library...')
# since most installations will have the libcudart.so installed, but not the compiler
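The deletion looks drastic, but the rule is mechanical: any statement after an unconditional `return` can never execute, and its comments go with it. A minimal sketch with a hypothetical hard-coded return value:

```python
def pick_binary():
    # an unconditional return ends the function here...
    return "libbitsandbytes_cpu.so"  # hypothetical value
    # ...so everything below is dead, and remove-unreachable-code deletes it,
    # taking the attached comments along
    print("this line can never run")
```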
```diff
         image_embeds = self.visual_encoder(image)
         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

         text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)

         text.input_ids[:,0] = self.tokenizer.bos_token_id
         decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
         decoder_targets[:,:self.prompt_length] = -100

         decoder_output = self.text_decoder(text.input_ids,
                                            attention_mask = text.attention_mask,
                                            encoder_hidden_states = image_embeds,
                                            encoder_attention_mask = image_atts,
                                            labels = decoder_targets,
                                            return_dict = True,
                                           )
-        loss_lm = decoder_output.loss
-
-        return loss_lm
+        return decoder_output.loss
```
Function `BLIP_Decoder.forward` refactored with the following changes:

- Inline variable that is immediately returned (`inline-immediately-returned-variable`)
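A minimal sketch of `inline-immediately-returned-variable`; `transform` is a hypothetical stand-in for the decoder call:

```python
def transform(x):
    return x * 2

def forward(x):
    # before:
    #   result = transform(x)
    #   return result
    # after: the temporary is folded into the return statement
    return transform(x)

print(forward(3))  # 6
```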
```diff
     state_dict = checkpoint['model']
     state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
     if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
         state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                          model.visual_encoder_m)
     for key in model.state_dict().keys():
         if key in state_dict.keys():
             if state_dict[key].shape!=model.state_dict()[key].shape:
                 del state_dict[key]

     msg = model.load_state_dict(state_dict,strict=False)
-    print('load checkpoint from %s'%url_or_filename)
+    print(f'load checkpoint from {url_or_filename}')
```
Function `load_checkpoint` refactored with the following changes:

- Replace interpolated string formatting with f-string (`replace-interpolation-with-fstring`)
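A minimal sketch of `replace-interpolation-with-fstring` with a hypothetical value:

```python
url_or_filename = "model_base.pth"

# before: percent-style interpolation
print('load checkpoint from %s' % url_or_filename)

# after: an f-string evaluates the value inline and reads more directly
print(f'load checkpoint from {url_or_filename}')
```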
| "heads (%d)" % (config.hidden_size, config.num_attention_heads) | ||
| ) | ||
|
|
Function `BertSelfAttention.__init__` refactored with the following changes:

- Replace multiple comparisons of same variable with `in` operator (`merge-comparisons`)
```diff
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

-        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+        if self.position_embedding_type in ["relative_key", "relative_key_query"]:
```
Function `BertSelfAttention.forward` refactored with the following changes:

- Split conditional into multiple branches (`split-or-ifs`)
- Merge duplicate blocks in conditional (`merge-duplicate-blocks`)
- Remove redundant conditional (`remove-redundant-if`)
- Replace multiple comparisons of same variable with `in` operator (`merge-comparisons`)
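Taken together, these rules rewrite chained equality tests guarding identical logic into one membership check. A minimal sketch with a hypothetical `kind` value:

```python
kind = "relative_key"

# before: the same variable tested twice, each arm guarding the same body
if kind == "relative_key" or kind == "relative_key_query":
    print("using relative position scores")

# after: split-or-ifs separates the tests, merge-duplicate-blocks folds the
# identical bodies back together, and merge-comparisons leaves one check
if kind in ["relative_key", "relative_key_query"]:
    print("using relative position scores")
```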
```diff
         attention_output = self.output(self_outputs[0], hidden_states)
-        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
-        return outputs
+        return (attention_output,) + self_outputs[1:]
```
Function `BertAttention.forward` refactored with the following changes:

- Inline variable that is immediately returned (`inline-immediately-returned-variable`)

This removes the following comments (why?):

# add attentions if we output them
```diff
-        layer_output = self.output(intermediate_output, attention_output)
-        return layer_output
+        return self.output(intermediate_output, attention_output)
```
Function `BertLayer.feed_forward_chunk` refactored with the following changes:

- Inline variable that is immediately returned (`inline-immediately-returned-variable`)
```diff
-        prediction_scores = self.predictions(sequence_output)
-        return prediction_scores
+        return self.predictions(sequence_output)
```
Function `BertOnlyMLMHead.forward` refactored with the following changes:

- Inline variable that is immediately returned (`inline-immediately-returned-variable`)
Function `BertModel.get_extended_attention_mask` refactored with the following changes:

- Replace call to format with f-string (`use-fstring-for-formatting`)
```diff
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
-        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+        mha_prefix = f'{block_prefix}MultiHeadDotProductAttention_1/'
```
Function `_load_weights` refactored with the following changes:

- Use f-string instead of string concatenation (`use-fstring-for-concatenation`)
```diff
     # interpolate position embedding
     embedding_size = pos_embed_checkpoint.shape[-1]
```
Function `interpolate_pos_embed` refactored with the following changes:

- Move assignments closer to their usage (`move-assign`)
- Swap if/else branches (`swap-if-else-branches`)
- Remove unnecessary else after guard condition (`remove-unnecessary-else`)
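A minimal sketch of the guard-style rewrite these rules produce, with hypothetical sizes:

```python
def maybe_interpolate(old_size, new_size):
    # before: the short branch came second, and the main work sat inside an else
    #   if new_size != old_size:
    #       return f"interpolate {old_size} -> {new_size}"
    #   else:
    #       return None
    # after: branches swapped so the trivial case is a guard, and the else is
    # unnecessary once the guard returns
    if new_size == old_size:
        return None
    return f"interpolate {old_size} -> {new_size}"

print(maybe_interpolate(16, 24))
```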
```diff
-        for module in self.unet.modules():
-            if (
-                hasattr(module, "_hf_hook")
-                and hasattr(module._hf_hook, "execution_device")
-                and module._hf_hook.execution_device is not None
-            ):
-                return torch.device(module._hf_hook.execution_device)
-        return self.device
+        return next(
+            (
+                torch.device(module._hf_hook.execution_device)
+                for module in self.unet.modules()
+                if (
+                    hasattr(module, "_hf_hook")
+                    and hasattr(module._hf_hook, "execution_device")
+                    and module._hf_hook.execution_device is not None
+                )
+            ),
+            self.device,
+        )
```
Function `StableDiffusionLongPromptWeightingPipeline._execution_device` refactored with the following changes:

- Use the built-in function `next` instead of a for-loop (`use-next`)
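A minimal sketch of `use-next`, with hypothetical dicts standing in for hooked modules:

```python
modules = [{"device": None}, {"device": "cuda:0"}, {}]

# before: an explicit loop returning on the first match, with a fallback after
# for m in modules:
#     if m.get("device") is not None:
#         return m["device"]
# return "cpu"

# after: next() over a generator; the second argument is the fallback default
found = next((m["device"] for m in modules if m.get("device") is not None), "cpu")
print(found)  # cuda:0
```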
```diff
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        if (
+            callback_steps is None
+            or not isinstance(callback_steps, int)
+            or callback_steps <= 0
```
Function `StableDiffusionLongPromptWeightingPipeline.check_inputs` refactored with the following changes:

- Remove redundant conditional (`remove-redundant-if`)
```diff
-        else:
-            # get the original timestep using init_timestep
-            offset = self.scheduler.config.get("steps_offset", 0)
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
-
-            t_start = max(num_inference_steps - init_timestep + offset, 0)
-            timesteps = self.scheduler.timesteps[t_start:].to(device)
-            return timesteps, num_inference_steps - t_start
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        timesteps = self.scheduler.timesteps[t_start:].to(device)
+        return timesteps, num_inference_steps - t_start
```
Function `StableDiffusionLongPromptWeightingPipeline.get_timesteps` refactored with the following changes:

- Remove unnecessary else after guard condition (`remove-unnecessary-else`)
```diff
-        else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+        elif latents.shape == shape:
             latents = latents.to(device)
+        else:
+            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
```
Function `StableDiffusionLongPromptWeightingPipeline.prepare_latents` refactored with the following changes:

- Merge else clause's nested if statement into elif (`merge-else-if-into-elif`)
- Lift code into else after jump in control flow (`reintroduce-else`)
- Swap if/else branches (`swap-if-else-branches`)
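A minimal sketch of the three rules combined, with hypothetical values in place of latent tensors:

```python
def normalize(shape, expected):
    # before: an else wrapping a nested `if shape != expected: raise ...`
    # after: the nested if merges into an elif, and the branches are swapped
    # so the happy path reads first and the error case is the final else
    if shape is None:
        return expected
    elif shape == expected:
        return shape
    else:
        raise ValueError(f"Unexpected shape, got {shape}, expected {expected}")

print(normalize((4, 64, 64), (4, 64, 64)))
```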
```diff
-        mapping.append({"old": old_item, "new": new_item})
+        mapping.append({"old": new_item, "new": new_item})
```
Function `renew_attention_paths` refactored with the following changes:

- Use previously assigned local variable (`use-assigned-variable`)

This removes the following comments (why?):

# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
```diff
-    # extract state_dict for UNet
-    unet_state_dict = {}
     unet_key = "model.diffusion_model."
     keys = list(checkpoint.keys())
-    for key in keys:
-        if key.startswith(unet_key):
-            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
-
-    new_checkpoint = {}
-
-    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
-    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
-    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
-    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
-
-    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
-    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
-
-    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
-    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
-    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
-    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+    unet_state_dict = {
+        key.replace(unet_key, ""): checkpoint.pop(key)
+        for key in keys
+        if key.startswith(unet_key)
+    }
+    new_checkpoint = {
+        "time_embedding.linear_1.weight": unet_state_dict[
+            "time_embed.0.weight"
+        ],
+        "time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"],
+        "time_embedding.linear_2.weight": unet_state_dict[
+            "time_embed.2.weight"
+        ],
+        "time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"],
+        "conv_in.weight": unet_state_dict["input_blocks.0.0.weight"],
+        "conv_in.bias": unet_state_dict["input_blocks.0.0.bias"],
+        "conv_norm_out.weight": unet_state_dict["out.0.weight"],
+        "conv_norm_out.bias": unet_state_dict["out.0.bias"],
+        "conv_out.weight": unet_state_dict["out.2.weight"],
+        "conv_out.bias": unet_state_dict["out.2.bias"],
+    }
```
Function `convert_ldm_unet_checkpoint` refactored with the following changes:

- Move assignment closer to its usage within a block (`move-assign-in-block`)
- Merge dictionary assignment with declaration [×10] (`merge-dict-assign`)
- Convert for loop into dictionary comprehension (`dict-comprehension`)

This removes the following comments (why?):

# extract state_dict for UNet
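A minimal sketch of the loop-to-comprehension rewrite with a tiny hypothetical checkpoint (the real code pops keys while iterating; `items()` is used here to keep the sketch simple):

```python
checkpoint = {
    "model.diffusion_model.conv_in.weight": "w",
    "model.diffusion_model.conv_in.bias": "b",
    "cond_stage_model.token": "t",
}
unet_key = "model.diffusion_model."

# before: unet_state_dict = {} followed by a filtering for loop, then
# key-by-key assignments into new_checkpoint = {}
# after: each dict is built in a single expression
unet_state_dict = {
    key.replace(unet_key, ""): value
    for key, value in checkpoint.items()
    if key.startswith(unet_key)
}
print(unet_state_dict)  # {'conv_in.weight': 'w', 'conv_in.bias': 'b'}
```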
```diff
-    # extract state dict for VAE
-    vae_state_dict = {}
     vae_key = "first_stage_model."
     keys = list(checkpoint.keys())
-    for key in keys:
-        if key.startswith(vae_key):
-            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    vae_state_dict = {
+        key.replace(vae_key, ""): checkpoint.get(key)
+        for key in keys
+        if key.startswith(vae_key)
+    }
     # if len(vae_state_dict) == 0:
     #     # the checkpoint passed in is the VAE state_dict, not one loaded from a .ckpt
     #     vae_state_dict = checkpoint

-    new_checkpoint = {}
-
-    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
-    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
-    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
-    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
-    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
-    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
-
-    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
-    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
-    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
-    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
-    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
-    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
-
-    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
-    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
-    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
-    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+    new_checkpoint = {
+        "encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
+        "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
+        "encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"],
+        "encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"],
+        "encoder.conv_norm_out.weight": vae_state_dict[
+            "encoder.norm_out.weight"
+        ],
+        "encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"],
+        "decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"],
+        "decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"],
+        "decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"],
+        "decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"],
+        "decoder.conv_norm_out.weight": vae_state_dict[
+            "decoder.norm_out.weight"
+        ],
+        "decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"],
+        "quant_conv.weight": vae_state_dict["quant_conv.weight"],
+        "quant_conv.bias": vae_state_dict["quant_conv.bias"],
+        "post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"],
+        "post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"],
+    }
```
Function `convert_ldm_vae_checkpoint` refactored with the following changes:

- Move assignment closer to its usage within a block (`move-assign-in-block`)
- Merge dictionary assignment with declaration [×16] (`merge-dict-assign`)
- Convert for loop into dictionary comprehension (`dict-comprehension`)

This removes the following comments (why?):

# extract state dict for VAE
```diff
-    for i in range(len(block_out_channels)):
+    for _ in block_out_channels:
```
Function `create_unet_diffusers_config` refactored with the following changes:

- Replace unused for index with underscore [×2] (`for-index-underscore`)
- Replace index in for loop with direct reference (`for-index-replacement`)
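A minimal sketch of both loop-index rules with a hypothetical channel list:

```python
block_out_channels = [128, 256, 512]

# before: for i in range(len(block_out_channels)) with `i` never used
# after: iterate the list directly; `_` marks the value as intentionally unused
down_block_types = []
for _ in block_out_channels:
    down_block_types.append("DownBlock2D")
print(down_block_types)  # ['DownBlock2D', 'DownBlock2D', 'DownBlock2D']
```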
```diff
-    config = dict(
+    return dict(
```
Function `create_vae_diffusers_config` refactored with the following changes:

- Inline variable that is immediately returned (`inline-immediately-returned-variable`)
```diff
-    text_model_dict = {}
-    for key in keys:
-        if key.startswith("cond_stage_model.transformer"):
-            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
-    return text_model_dict
+    return {
+        key[len("cond_stage_model.transformer.") :]: checkpoint[key]
+        for key in keys
+        if key.startswith("cond_stage_model.transformer")
+    }
```
Function `convert_ldm_clip_checkpoint_v1` refactored with the following changes:

- Convert for loop into dictionary comprehension (`dict-comprehension`)
- Inline variable that is immediately returned (`inline-immediately-returned-variable`)
```diff
-        if reso in self.predefined_resos_set:
-            pass
-        else:
+        if reso not in self.predefined_resos_set:
```
Function `BucketManager.select_bucket` refactored with the following changes:

- Swap if/else to remove empty if body (`remove-pass-body`)
```diff
-            self.caption_extension = "." + self.caption_extension
+            self.caption_extension = f".{self.caption_extension}"
```
Function `DreamBoothSubset.__init__` refactored with the following changes:

- Use f-string instead of string concatenation (`use-fstring-for-concatenation`)
```diff
-        if not self.current_epoch == epoch:  # shuffle buckets when the epoch changes
+        if self.current_epoch != epoch:  # shuffle buckets when the epoch changes
```
Function `BaseDataset.set_current_epoch` refactored with the following changes:

- Simplify logical expression using De Morgan identities (`de-morgan`)
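A minimal sketch of `de-morgan` with hypothetical epoch counters:

```python
current_epoch, epoch = 1, 2

# before: `not a == b` forces the reader to invert the test mentally
if not current_epoch == epoch:
    print("epoch changed, reshuffling buckets")

# after: the equivalent direct comparison
if current_epoch != epoch:
    print("epoch changed, reshuffling buckets")
```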
```diff
-                tag = tag.strip()
-                if tag:
+                if tag := tag.strip():
```
Function `BaseDataset.set_tag_frequency` refactored with the following changes:

- Use named expression to simplify assignment and conditional (`use-named-expression`)
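A minimal sketch of the walrus rewrite (requires Python 3.8+) with hypothetical tags:

```python
tag_frequency = {}
for tag in [" cat ", "   ", "dog"]:
    # before: tag = tag.strip() on one line, if tag: on the next
    # after: := assigns and tests in a single expression
    if tag := tag.strip():
        tag_frequency[tag] = tag_frequency.get(tag, 0) + 1
print(tag_frequency)  # {'cat': 1, 'dog': 1}
```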
```diff
-            if type(str_to) == list:
-                caption = random.choice(str_to)
-            else:
-                caption = str_to
+            caption = random.choice(str_to) if type(str_to) == list else str_to
```
Function `BaseDataset.process_caption` refactored with the following changes:

- Replace if statement with if expression (`assign-if-exp`)
Branch `main` refactored by Sourcery. If you're happy with these changes, merge this Pull Request using the Squash and merge strategy.
See our documentation here.
Run Sourcery locally
Reduce the feedback loop during development by using the Sourcery editor plugin.
Review changes via command line
To manually merge these changes, make sure you're on the `main` branch, then run:
Help us improve this pull request!