Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add dataframe as input #204

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions pytorchvideo/data/labeled_video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import gc
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

from typing import Any, Callable, Dict, List, Optional, Tuple, Type,Union
import pandas as pd
import torch.utils.data
from pytorchvideo.data.clip_sampling import ClipSampler
from pytorchvideo.data.video import VideoPathHandler
Expand Down Expand Up @@ -145,6 +145,7 @@ def __next__(self) -> dict:
)
self._loaded_video_label = (video, info_dict, video_index)
except Exception as e:
print('error is',e)#necessary to print error
logger.debug(
"Failed to load video with error: {}; trial {}".format(
e,
Expand Down Expand Up @@ -251,7 +252,7 @@ def __iter__(self):


def labeled_video_dataset(
data_path: str,
data:Union[str, pd.DataFrame],
clip_sampler: ClipSampler,
video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
Expand Down Expand Up @@ -292,8 +293,12 @@ def labeled_video_dataset(
decoder (str): Defines what type of decoder used to decode a video.

"""
labeled_video_paths = LabeledVideoPaths.from_path(data_path)
labeled_video_paths.path_prefix = video_path_prefix
if isinstance(data,pd.DataFrame):
labeled_video_paths= LabeledVideoPaths.from_df(data)
elif isinstance(data,str):
labeled_video_paths = LabeledVideoPaths.from_path(data)
labeled_video_paths.path_prefix = video_path_prefix

dataset = LabeledVideoDataset(
labeled_video_paths,
clip_sampler,
Expand Down
36 changes: 35 additions & 1 deletion pytorchvideo/data/labeled_video_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from iopath.common.file_io import g_pathmgr
from torchvision.datasets.folder import make_dataset

import pandas as pd

class LabeledVideoPaths:
"""
Expand All @@ -25,11 +25,14 @@ def from_path(cls, data_path: str) -> LabeledVideoPaths:
Args:
file_path (str): The path to the file to be read.
"""


if g_pathmgr.isfile(data_path):
return LabeledVideoPaths.from_csv(data_path)
elif g_pathmgr.isdir(data_path):
return LabeledVideoPaths.from_directory(data_path)


else:
raise FileNotFoundError(f"{data_path} not found.")

Expand Down Expand Up @@ -67,6 +70,37 @@ def from_csv(cls, file_path: str) -> LabeledVideoPaths:
), f"Failed to load dataset from {file_path}."
return cls(video_paths_and_label)


@classmethod
def from_df(cls, df: pd.DataFrame) -> LabeledVideoPaths:
    """
    Factory function that creates a LabeledVideoPaths object from a dataframe.

    Columns are read by position, not by name: the first column holds the
    video path and every remaining column is treated as a label value. The
    label values of each row are converted to a 1-D float array, so they
    must be numeric (or convertible to float).

    Sample dataframe::

        df = pd.DataFrame(
            {
                "path": ["path_to_video_1", "path_to_video_2", "path_to_video_3"],
                "label": [0, 1, 2],
            }
        )

    Args:
        df (pd.DataFrame): One row per video; first column is the path,
            remaining columns are the label value(s).

    Returns:
        The object built by calling ``cls`` with a list of
        ``(path, label)`` tuples, one per dataframe row, where ``label`` is
        a float array with one entry per label column.

    Raises:
        AssertionError: If the dataframe has no rows.
        ValueError: If a label value cannot be converted to float.
    """
    # Check for rows first so that an empty dataframe (even one without any
    # columns) reports the same failure as a dataframe that yields no entries.
    assert (
        len(df) > 0
    ), "Failed to load dataset from df."

    # Column-wise extraction avoids the per-row Series that iterrows() builds.
    paths = df.iloc[:, 0].tolist()
    labels = df.iloc[:, 1:].to_numpy(dtype=float)

    # Iterating the 2-D label matrix yields one 1-D float array per row.
    video_paths_and_label = list(zip(paths, labels))
    return cls(video_paths_and_label)





@classmethod
def from_directory(cls, dir_path: str) -> LabeledVideoPaths:
"""
Expand Down