Skip to content

Commit

Permalink
Merge pull request AUTOMATIC1111#14276 from AUTOMATIC1111/fix-styles
Browse files Browse the repository at this point in the history
Fix styles
  • Loading branch information
AUTOMATIC1111 committed Dec 14, 2023
1 parent b55f09c commit 888b928
Showing 1 changed file with 7 additions and 24 deletions.
31 changes: 7 additions & 24 deletions modules/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,8 @@ def __init__(self, path: str):
self.path = path

folder, file = os.path.split(self.path)
self.default_file = file.split("*")[0] + ".csv"
if self.default_file == ".csv":
self.default_file = "styles.csv"
self.default_path = os.path.join(folder, self.default_file)
filename, _, ext = file.partition('*')
self.default_path = os.path.join(folder, filename + ext)

self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

Expand Down Expand Up @@ -155,10 +153,8 @@ def load_from_csv(self, path: str):
row["name"], prompt, negative_prompt, path
)

def get_style_paths(self) -> list():
"""
Returns a list of all distinct paths, including the default path, of
files that styles are loaded from."""
def get_style_paths(self) -> set:
"""Returns a set of all distinct paths of files that styles are loaded from."""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
Expand All @@ -172,9 +168,9 @@ def get_style_paths(self) -> list():
style_paths.add(style.path)

# Remove any paths for styles that are just list dividers
style_paths.remove("do_not_save")
style_paths.discard("do_not_save")

return list(style_paths)
return style_paths

def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles]
Expand All @@ -196,20 +192,7 @@ def save_styles(self, path: str = None) -> None:
# The path argument is deprecated, but kept for backwards compatibility
_ = path

# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
self.styles[style.name] = style._replace(path=self.default_path)

# Create a list of all distinct paths, including the default path
style_paths = set()
style_paths.add(self.default_path)
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)

# Remove any paths for styles that are just list dividers
style_paths.remove("do_not_save")
style_paths = self.get_style_paths()

csv_names = [os.path.split(path)[1].lower() for path in style_paths]

Expand Down

0 comments on commit 888b928

Please sign in to comment.