Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add content-type filter method to folder_paths #4054

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions comfy_extras/nodes_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,10 @@ def INPUT_TYPES(s):
}

class LoadAudio:
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')

@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [
f for f in os.listdir(input_dir)
if (os.path.isfile(os.path.join(input_dir, f))
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
)
]
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}

CATEGORY = "audio"
Expand Down
26 changes: 26 additions & 0 deletions folder_paths.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import time
import mimetypes
import logging
from typing import Set, List, Dict, Tuple

Expand Down Expand Up @@ -43,6 +44,9 @@
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")

filename_list_cache = {}
extension_mimetypes_cache = {
"webp" : "image",
}

if not os.path.exists(input_directory):
try:
Expand Down Expand Up @@ -85,6 +89,28 @@ def get_directory_by_type(type_name):
return get_input_directory()
return None

def filter_files_content_types(files: List[str], content_types: List[str]) -> List[str]:
christian-byrne marked this conversation as resolved.
Show resolved Hide resolved
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
result = []
for file in files:
extension = file.split('.')[-1]
if extension not in extension_mimetypes_cache:
mime_type, _ = mimetypes.guess_type(file, strict=False)
if not mime_type:
continue
content_type = mime_type.split('/')[0]
extension_mimetypes_cache[extension] = content_type
else:
content_type = extension_mimetypes_cache[extension]

if content_type in content_types:
result.append(file)
return result

# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
Expand Down
Empty file.
52 changes: 52 additions & 0 deletions tests-unit/folder_paths_test/filter_by_content_types_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding tests!

import os
import tempfile
from folder_paths import filter_files_content_types

@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}


@pytest.fixture(scope="module")
def mock_dir(file_extensions):
with tempfile.TemporaryDirectory() as directory:
for content_type, extensions in file_extensions.items():
for extension in extensions:
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
f.write(f"Sample {content_type} file in {extension} format")
yield directory


def test_categorizes_all_correctly(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
for extension in extensions:
assert f"sample_{content_type}.{extension}" in filtered_files


def test_categorizes_all_uniquely(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
assert len(filtered_files) == len(extensions)


def test_handles_bad_extensions():
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []


def test_handles_no_extension():
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []


def test_handles_no_files():
files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
Loading