Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions focoos/utils/system.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib.metadata as metadata
import os
import platform
import shutil
import subprocess
import sys
import tarfile
Expand Down Expand Up @@ -266,7 +267,7 @@ def is_inside_sagemaker():
return res


def list_dir(base_directory: Union[str, Path]) -> List[Path]:
def list_directories(base_directory: Union[str, Path]) -> List[Path]:
"""
A function that lists directories within a base directory.

Expand Down Expand Up @@ -310,10 +311,9 @@ def extract_archive(
# Determine the extraction path
t0 = time.time()
base_dir = os.path.dirname(archive_path)
extracted_dir = base_dir
if destination is not None:
extracted_dir = os.path.join(base_dir, destination)
else:
extracted_dir = base_dir

if comm.is_main_process():
logger.info(f"Extracting archive: {archive_path} to {extracted_dir}")
Expand All @@ -340,8 +340,24 @@ def extract_archive(
logger.info(f"[elapsed {t1 - t0:.3f} ] Extracted archive to: {extracted_dir}")

comm.synchronize()
if len(list_dir(extracted_dir)) == 1:
extracted_dir = list_dir(extracted_dir)[0]
# Remove __MACOSX directory
if "__MACOSX" in os.listdir(extracted_dir):
shutil.rmtree(os.path.join(extracted_dir, "__MACOSX"))

if len(list_directories(extracted_dir)) == 1:
extracted_dir = list_directories(extracted_dir)[0]

POSSIBLE_TRAIN_DIRS = ["train", "training"]
POSSIBLE_VAL_DIRS = ["valid", "val", "validation"]
inner_dirs = list_directories(extracted_dir)
if not any(dir.name in POSSIBLE_TRAIN_DIRS for dir in inner_dirs):
raise FileNotFoundError(
f"Train split not found in {extracted_dir}: {[str(x) for x in inner_dirs]}. You should provide a zip dataset with only a root folder or train and val subfolders."
)
if not any(dir.name in POSSIBLE_VAL_DIRS for dir in inner_dirs):
raise FileNotFoundError(
f"Validation split not found in {extracted_dir}: {[str(x) for x in inner_dirs]}. You should provide a zip dataset with only a root folder or train and val subfolders."
)

# Optionally delete the original archive
if delete_original:
Expand Down
Loading