Skip to content

Commit

Permalink
initial support for training textual inversion
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Oct 2, 2022
1 parent 84e97a9 commit 820f1dc
Show file tree
Hide file tree
Showing 19 changed files with 828 additions and 315 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ __pycache__
/.idea
notification.mp3
/SwinIR
/textual_inversion
1 change: 1 addition & 0 deletions javascript/progressbar.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
})

function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
Expand Down
8 changes: 8 additions & 0 deletions javascript/textualInversion.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@


function start_training_textual_inversion(){
requestProgress('ti')
gradioApp().querySelector('#ti_error').innerHTML=''

return args_to_array(arguments)
}
3 changes: 1 addition & 2 deletions modules/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ def enable_tf32():

errors.run(enable_tf32, "Enabling TF32")


device = get_optimal_device()
device_codeformer = cpu if has_mps else device

dtype = torch.float16

def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
Expand Down
13 changes: 8 additions & 5 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
self.styles: str = styles
self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
Expand Down Expand Up @@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
}

generation_params.update(p.extra_generation_params)
Expand All @@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:

fix_seed(p)

os.makedirs(p.outpath_samples, exist_ok=True)
os.makedirs(p.outpath_grids, exist_ok=True)
if p.outpath_samples is not None:
os.makedirs(p.outpath_samples, exist_ok=True)

if p.outpath_grids is not None:
os.makedirs(p.outpath_grids, exist_ok=True)

modules.sd_hijack.model_hijack.apply_circular(p.tiling)

Expand All @@ -323,7 +326,7 @@ def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)

if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
model_hijack.embedding_db.load_textual_inversion_embeddings()

infotexts = []
output_images = []
Expand Down
Loading

0 comments on commit 820f1dc

Please sign in to comment.