From 92c3bafb05ed57f0195a37816dcaf6093a1461a2 Mon Sep 17 00:00:00 2001 From: star <15031259256@163.com> Date: Sun, 13 Nov 2022 18:43:03 +0800 Subject: [PATCH] feat(example): add video type & example (#1473) add video example --- client/starwhale/__init__.py | 2 + .../starwhale/api/_impl/dataset/__init__.py | 2 + client/starwhale/api/dataset.py | 2 + client/starwhale/core/dataset/type.py | 27 + client/tests/sdk/test_dataset.py | 10 + example/ucf101/.gitignore | 2 + example/ucf101/.swignore | 5 + example/ucf101/Makefile | 17 + example/ucf101/dataset.yaml | 9 + example/ucf101/generate_data.sh | 20 + example/ucf101/model.yaml | 10 + example/ucf101/requirements-sw-lock.txt | 54 ++ example/ucf101/runtime.yaml | 13 + example/ucf101/transform_video.sh | 16 + example/ucf101/ucf101/__init__.py | 0 example/ucf101/ucf101/dataset.py | 29 + example/ucf101/ucf101/evaluator.py | 133 ++++ example/ucf101/ucf101/lr_scheduler.py | 65 ++ example/ucf101/ucf101/metric.py | 125 ++++ example/ucf101/ucf101/model.py | 401 +++++++++++ example/ucf101/ucf101/sampler.py | 60 ++ example/ucf101/ucf101/train.py | 673 ++++++++++++++++++ example/ucf101/ucf101/transform.py | 223 ++++++ example/ucf101/ucf101/video_iterator.py | 352 +++++++++ 24 files changed, 2250 insertions(+) create mode 100644 example/ucf101/.gitignore create mode 100644 example/ucf101/.swignore create mode 100644 example/ucf101/Makefile create mode 100644 example/ucf101/dataset.yaml create mode 100644 example/ucf101/generate_data.sh create mode 100644 example/ucf101/model.yaml create mode 100644 example/ucf101/requirements-sw-lock.txt create mode 100644 example/ucf101/runtime.yaml create mode 100644 example/ucf101/transform_video.sh create mode 100644 example/ucf101/ucf101/__init__.py create mode 100644 example/ucf101/ucf101/dataset.py create mode 100644 example/ucf101/ucf101/evaluator.py create mode 100644 example/ucf101/ucf101/lr_scheduler.py create mode 100644 example/ucf101/ucf101/metric.py create mode 100644 example/ucf101/ucf101/model.py create mode 100644 example/ucf101/ucf101/sampler.py create mode 100644 example/ucf101/ucf101/train.py create mode 100644 example/ucf101/ucf101/transform.py create mode 100644 example/ucf101/ucf101/video_iterator.py diff --git a/client/starwhale/__init__.py b/client/starwhale/__init__.py index 57b16d0375..896c4ea0be 100644 --- a/client/starwhale/__init__.py +++ b/client/starwhale/__init__.py @@ -8,6 +8,7 @@ Text, Audio, Image, + Video, Binary, LinkType, MIMEType, @@ -49,6 +50,7 @@ "Binary", "Text", "Audio", + "Video", "Image", "ClassLabel", "BoundingBox", diff --git a/client/starwhale/api/_impl/dataset/__init__.py b/client/starwhale/api/_impl/dataset/__init__.py index 74e690b771..74dccb6838 100644 --- a/client/starwhale/api/_impl/dataset/__init__.py +++ b/client/starwhale/api/_impl/dataset/__init__.py @@ -3,6 +3,7 @@ Text, Audio, Image, + Video, Binary, LinkType, MIMEType, @@ -34,6 +35,7 @@ "Binary", "Text", "Audio", + "Video", "Image", "ClassLabel", "BoundingBox", diff --git a/client/starwhale/api/dataset.py b/client/starwhale/api/dataset.py index 6b7e242b31..e8cd7263e3 100644 --- a/client/starwhale/api/dataset.py +++ b/client/starwhale/api/dataset.py @@ -3,6 +3,7 @@ Text, Audio, Image, + Video, Binary, LinkType, MIMEType, @@ -40,6 +41,7 @@ "Binary", "Text", "Audio", + "Video", "Image", "ClassLabel", "BoundingBox", diff --git a/client/starwhale/core/dataset/type.py b/client/starwhale/core/dataset/type.py index 5f7f63f5ed..bd08076ffa 100644 --- a/client/starwhale/core/dataset/type.py +++ 
b/client/starwhale/core/dataset/type.py @@ -117,6 +117,7 @@ class MIMEType(Enum): AVIF = "image/avif" MP4 = "video/mp4" AVI = "video/avi" + WEBM = "video/webm" WAV = "audio/wav" MP3 = "audio/mp3" PLAIN = "text/plain" @@ -145,6 +146,7 @@ def create_by_file_suffix(cls, name: str) -> MIMEType: ".mp4": cls.MP4, ".avif": cls.AVIF, ".avi": cls.AVI, + ".webm": cls.WEBM, ".wav": cls.WAV, ".csv": cls.CSV, ".txt": cls.PLAIN, @@ -221,6 +223,10 @@ def reflect(cls, raw_data: bytes, data_type: t.Dict[str, t.Any]) -> BaseArtifact return Audio( raw_data, mime_type=mime_type, shape=shape, display_name=display_name ) + elif dtype == ArtifactType.Video.value: + return Video( + raw_data, mime_type=mime_type, shape=shape, display_name=display_name + ) elif not dtype or dtype == ArtifactType.Binary.value: return Binary(raw_data) elif dtype == ArtifactType.Link.value: @@ -350,6 +356,27 @@ def _do_validate(self) -> None: raise NoSupportError(f"Audio type: {self.mime_type}") +class Video(BaseArtifact): + def __init__( + self, + fp: _TArtifactFP = "", + display_name: str = "", + shape: t.Optional[_TShape] = None, + mime_type: t.Optional[MIMEType] = None, + ) -> None: + shape = shape or (None,) + super().__init__(fp, ArtifactType.Video, display_name, shape, mime_type) + + def _do_validate(self) -> None: + if self.mime_type not in ( + MIMEType.MP4, + MIMEType.AVI, + MIMEType.WEBM, + MIMEType.UNDEFINED, + ): + raise NoSupportError(f"Video type: {self.mime_type}") + + class ClassLabel(ASDictMixin): def __init__(self, names: t.List[t.Union[int, float, str]]) -> None: self.type = "class_label" diff --git a/client/tests/sdk/test_dataset.py index 1e64b3631b..9879dc94bc 100644 --- a/client/tests/sdk/test_dataset.py +++ b/client/tests/sdk/test_dataset.py @@ -32,6 +32,7 @@ Text, Audio, Image, + Video, Binary, ClassLabel, BoundingBox, @@ -482,6 +483,15 @@ def test_audio(self) -> None: assert _asdict["type"] == "audio" assert audio.to_bytes() == b"test" + def test_video(self) -> None: + fp = "/test/1.avi" + self.fs.create_file(fp, contents="test") + video = Video(fp) + _asdict = json.loads(json.dumps(video.asdict())) + assert _asdict["mime_type"] == MIMEType.AVI.value + assert _asdict["type"] == "video" + assert video.to_bytes() == b"test" + def test_bbox(self) -> None: bbox = BoundingBox(1, 2, 3, 4) assert bbox.to_list() == [1, 2, 3, 4] diff --git a/example/ucf101/.gitignore new file mode 100644 index 0000000000..a5b713f6a7 --- /dev/null +++ b/example/ucf101/.gitignore @@ -0,0 +1,2 @@ +data/ +models/ diff --git a/example/ucf101/.swignore new file mode 100644 index 0000000000..92c3d73c44 --- /dev/null +++ b/example/ucf101/.swignore @@ -0,0 +1,5 @@ +venv +.git +.history +.vscode +.venv diff --git a/example/ucf101/Makefile new file mode 100644 index 0000000000..2f56224c11 --- /dev/null +++ b/example/ucf101/Makefile @@ -0,0 +1,17 @@ +.PHONY: train +train: + mkdir -p models + python3 ucf101/train.py + +.PHONY: download-data +download-data: + rm -rf data + mkdir -p data + wget http://www.crcv.ucf.edu/data/UCF101/UCF101.rar --no-check-certificate -P data + unrar x data/UCF101.rar data + rm -rf data/UCF101.rar + rm -f data/all_list.txt + bash generate_data.sh + shuf data/all_list.txt -n 9000 -o data/train_list.txt + shuf data/all_list.txt -n 1000 -o data/validation_list.txt + shuf data/all_list.txt -n 200 -o data/test_list.txt diff --git a/example/ucf101/dataset.yaml new file mode 100644
index 0000000000..35e98bdf61 --- /dev/null +++ b/example/ucf101/dataset.yaml @@ -0,0 +1,9 @@ +name: ucf101 + +handler: ucf101.dataset:UCFDatasetBuildExecutor + +desc: ucf101 data and label test dataset + +attr: + alignment_size: 128 + volume_size: 10M diff --git a/example/ucf101/generate_data.sh b/example/ucf101/generate_data.sh new file mode 100644 index 0000000000..d0c4b81267 --- /dev/null +++ b/example/ucf101/generate_data.sh @@ -0,0 +1,20 @@ +#! /bin/bash + +global_index=0 +label_index=0 + +read_dir(){ + for file in `ls $1` + do + if [ -d $1"/"$file ] + then + read_dir $1"/"$file $label_index + let label_index++ + else + echo $global_index $2 ${1:13}"/"$file >> "data/"all_list.txt + let global_index++ + fi + done +} + +read_dir data/UCF-101 diff --git a/example/ucf101/model.yaml b/example/ucf101/model.yaml new file mode 100644 index 0000000000..fa9b40b7db --- /dev/null +++ b/example/ucf101/model.yaml @@ -0,0 +1,10 @@ +version: 1.0 +name: ucf101 + +model: + - models/PyTorch-MFNet_ep-0000.pth + +run: + handler: ucf101.evaluator:UCF101PipelineHandler + +desc: ucf101 by pytorch diff --git a/example/ucf101/requirements-sw-lock.txt b/example/ucf101/requirements-sw-lock.txt new file mode 100644 index 0000000000..7ff3539fd9 --- /dev/null +++ b/example/ucf101/requirements-sw-lock.txt @@ -0,0 +1,54 @@ +# Generated by Starwhale(0.3.1) Runtime Lock +--index-url 'https://pypi.doubanio.com/simple/' +--extra-index-url 'https://mirrors.bfsu.edu.cn/pypi/web/simple/' +--trusted-host 'mirrors.bfsu.edu.cn\npypi.doubanio.com' +appdirs==1.4.4 +attrs==21.4.0 +boto3==1.21.0 +botocore==1.24.46 +cattrs==1.7.1 +certifi==2022.9.24 +charset-normalizer==2.1.1 +click==8.1.3 +click-option-group==0.5.5 +commonmark==0.9.1 +conda-pack==0.6.0 +dill==0.3.5.1 +distlib==0.3.6 +filelock==3.8.0 +fs==2.4.16 +idna==3.4 +Jinja2==3.1.2 +jmespath==0.10.0 +joblib==1.2.0 +jsonlines==3.0.0 +loguru==0.6.0 +MarkupSafe==2.1.1 +numpy==1.23.4 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +opencv-python==4.6.0.66 +packaging==21.3 +platformdirs==2.5.3 +pyarrow==10.0.0 +Pygments==2.13.0 +pyparsing==3.0.9 +python-dateutil==2.8.2 +PyYAML==6.0 +requests==2.28.1 +requests-toolbelt==0.10.1 +rich==12.6.0 +s3transfer==0.5.2 +scikit-learn==1.1.3 +scipy==1.9.3 +shellingham==1.5.0 +six==1.16.0 +tenacity==8.1.0 +textual==0.1.18 +threadpoolctl==3.1.0 +torch==1.13.0 +typing_extensions==4.4.0 +urllib3==1.26.12 +virtualenv==20.16.6 diff --git a/example/ucf101/runtime.yaml b/example/ucf101/runtime.yaml new file mode 100644 index 0000000000..5a04ee306f --- /dev/null +++ b/example/ucf101/runtime.yaml @@ -0,0 +1,13 @@ +api_version: '1.1' +dependencies: +- requirements-sw-lock.txt +- pip: + - starwhale==0.3.1 +- wheels: + - starwhale-0.0.0.dev0-py3-none-any.whl +environment: + arch: noarch + os: ubuntu:20.04 + python: '3.9' +mode: venv +name: ucf101 diff --git a/example/ucf101/transform_video.sh b/example/ucf101/transform_video.sh new file mode 100644 index 0000000000..e9d3d6338e --- /dev/null +++ b/example/ucf101/transform_video.sh @@ -0,0 +1,16 @@ +#! 
/bin/bash + read_dir(){ + for file in `ls $1` + do + if [ -d $1"/"$file ] + then + mkdir -p data/UCF-101-WEBM/${1:13}"/"$file + read_dir $1"/"$file + else + ffmpeg -i $1"/"$file -y data/UCF-101-WEBM/${1:13}"/"${file%.*}".webm" + fi + done +} + read_dir data/UCF-101 diff --git a/example/ucf101/ucf101/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ucf101/ucf101/dataset.py new file mode 100644 index 0000000000..94dc11353b --- /dev/null +++ b/example/ucf101/ucf101/dataset.py @@ -0,0 +1,29 @@ +import typing as t +from pathlib import Path + +from starwhale import Video, MIMEType, BuildExecutor + +root_dir = Path(__file__).parent.parent +dataset_dir = root_dir / "data" / "UCF-101" +test_ds_path = [root_dir / "data" / "test_list.txt"] + + +class UCFDatasetBuildExecutor(BuildExecutor): + def iter_item(self) -> t.Generator[t.Tuple, None, None]: + for path in test_ds_path: + with path.open() as f: + for line in f.readlines(): + v_id, label, video_sub_path = line.split() + + data_path = dataset_dir / video_sub_path + data = Video( + data_path, + display_name=video_sub_path, + shape=(1,), + mime_type=MIMEType.AVI, + ) + + annotations = { + "label": label, + } + yield f"{label}_{video_sub_path}", data, annotations diff --git a/example/ucf101/ucf101/evaluator.py new file mode 100644 index 0000000000..19212701a5 --- /dev/null +++ b/example/ucf101/ucf101/evaluator.py @@ -0,0 +1,133 @@ +import typing as t +import logging +import tempfile +from pathlib import Path + +import cv2 +import numpy as np +import torch + +from starwhale import Video, PipelineHandler, PPLResultIterator, multi_classification + +from .model import MFNET_3D +from .sampler import RandomSampling +from .transform import Resize, Compose, ToTensor, Normalize, RandomCrop + +root_dir = Path(__file__).parent.parent + + +def ppl_post(output: torch.Tensor) -> t.Tuple[t.List[str], t.List[float]]: + output = output.squeeze() + pred_value = output.argmax(-1).flatten().tolist() + probability_matrix = np.exp(output.tolist()).tolist() + return [str(p) for p in pred_value], probability_matrix + + +def load_model(device): + model = MFNET_3D(num_classes=101) + # network + if torch.cuda.is_available(): + model = torch.nn.DataParallel(model).cuda() + else: + model = torch.nn.DataParallel(model) + + checkpoint = torch.load( + str(root_dir / "models/PyTorch-MFNet_ep-0000.pth"), map_location=device + ) + + # customized partial load function + net_state_keys = list(model.state_dict().keys()) + for name, param in checkpoint["state_dict"].items(): + if name in model.state_dict().keys(): + dst_param_shape = model.state_dict()[name].shape + if param.shape == dst_param_shape: + model.state_dict()[name].copy_(param.view(dst_param_shape)) + net_state_keys.remove(name) + # indicate missing keys + if net_state_keys: + logging.error(f">> Failed to load: {net_state_keys}") + raise RuntimeError(f">> Failed to load: {net_state_keys}") + + model.to(device) + model.eval() + print("ucf101 model loaded, starting inference...") + return model + + +def ppl_pre(video: Video, sampler, transforms) -> torch.Tensor: + with tempfile.NamedTemporaryFile() as f: + f.write(video.to_bytes()) + f.flush() + cap = cv2.VideoCapture(f.name) + ids = sampler.sampling(range_max=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))) + frames = [] + pre_idx = max(ids) + for idx in ids: + if pre_idx != (idx - 1): + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + res, frame =
cap.read() # in BGR/GRAY format + pre_idx = idx + if len(frame.shape) < 3: + # Convert Gray to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) + else: + # Convert BGR to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + clip_input = np.concatenate(frames, axis=2) + trans_result = transforms(clip_input) + # mock batch + trans_result = trans_result[None, :] + return ( + trans_result.float().cuda() + if torch.cuda.is_available() + else trans_result.float().cpu() + ) + + +class UCF101PipelineHandler(PipelineHandler): + def __init__(self): + super().__init__(ignore_error=False) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = load_model(self.device) + self.sampler = RandomSampling() + self.transforms = Compose( + [ + Resize((256, 256)), + RandomCrop((224, 224)), + ToTensor(), + Normalize( + mean=[124 / 255, 117 / 255, 104 / 255], std=[1 / (0.0167 * 255)] * 3 + ), + ] + ) + + @torch.no_grad() + def ppl(self, video: Video, annotations, index, **kw: t.Any) -> t.Any: + _frames_tensor = ppl_pre( + video=video, sampler=self.sampler, transforms=self.transforms + ) + output = self.model(_frames_tensor) + + # recording + probs = torch.nn.Softmax(dim=1)(output) + label = torch.max(probs, 1)[1].detach().cpu().numpy()[0] + print( + f"id: {index}, real label: {annotations}, predicted label: {label}, probability: {probs[0][label]}" + ) + + return ppl_post(output) + + @multi_classification( + confusion_matrix_normalize="all", + show_hamming_loss=True, + show_cohen_kappa_score=True, + show_roc_auc=True, + ) + def cmp(self, ppl_result: PPLResultIterator) -> t.Any: + result, label, pr = [], [], [] + for _data in ppl_result: + label.append(_data["annotations"]["label"]) + pr.append(_data["result"][1]) + result.append(_data["result"][0][0]) + return label, result, pr diff --git a/example/ucf101/ucf101/lr_scheduler.py new file mode 100644 index 0000000000..b8a98b547a --- /dev/null +++ b/example/ucf101/ucf101/lr_scheduler.py @@ -0,0 +1,65 @@ +import logging + + +class LRScheduler(object): + def __init__(self, step_counter=0, base_lr=0.01): + self.lr = None + self.step_counter = step_counter + self.base_lr = base_lr + + def update(self): + raise NotImplementedError("must override this") + + def get_lr(self): + return self.lr + + +class MultiFactorScheduler(LRScheduler): + def __init__(self, steps, base_lr=0.01, factor=0.1, step_counter=0): + super(MultiFactorScheduler, self).__init__(step_counter, base_lr) + assert isinstance(steps, list) and len(steps) > 0 + for i, _step in enumerate(steps): + if i != 0 and steps[i] <= steps[i - 1]: + raise ValueError("Schedule step must be an increasing integer list") + if _step < 1: + raise ValueError("Schedule step must be greater than or equal to 1") + if factor > 1.0: + raise ValueError("Factor must be no more than 1 to make lr reduce") + + logging.info( + "Iter %d: start with learning rate: %0.5e (next lr step: %d)" + % (self.step_counter, self.base_lr, steps[0]) + ) + self.steps = steps + self.factor = factor + self.lr = self.base_lr + self.cursor = 0 + + def update(self): + self.step_counter += 1 + + if self.cursor >= len(self.steps): + return self.lr + while self.steps[self.cursor] < self.step_counter: + self.lr *= self.factor + self.cursor += 1 + # message + if self.cursor >= len(self.steps): + logging.info( + "Iter: %d, change learning rate to %0.5e for step [%d:Inf)" + % (self.step_counter - 1, self.lr, self.step_counter - 1) + ) + return self.lr + 
else: + logging.info( + "Iter: %d, change learning rate to %0.5e for step [%d:%d)" + % ( + self.step_counter - 1, + self.lr, + self.step_counter - 1, + self.steps[self.cursor], + ) + ) + return self.lr diff --git a/example/ucf101/ucf101/metric.py new file mode 100644 index 0000000000..568f0bddb4 --- /dev/null +++ b/example/ucf101/ucf101/metric.py @@ -0,0 +1,125 @@ +import logging + + +class EvalMetric(object): + def __init__(self, name, **kwargs): + self.sum_metric = None + self.num_inst = None + self.name = str(name) + self.reset() + + def update(self, preds, labels, losses): + raise NotImplementedError() + + def reset(self): + self.num_inst = 0 + self.sum_metric = 0.0 + + def get(self): + if self.num_inst == 0: + return self.name, float("nan") + else: + return self.name, self.sum_metric / self.num_inst + + def get_name_value(self): + name, value = self.get() + if not isinstance(name, list): + name = [name] + if not isinstance(value, list): + value = [value] + return list(zip(name, value)) + + def check_label_shapes(self, preds, labels): + # raise if the shape is inconsistent + if (type(labels) is list) and (type(preds) is list): + label_shape, pred_shape = len(labels), len(preds) + else: + label_shape, pred_shape = labels.shape[0], preds.shape[0] + + if label_shape != pred_shape: + raise NotImplementedError("") + + +class MetricList(EvalMetric): + """Handle multiple evaluation metrics""" + + def __init__(self, *args, name="metric_list"): + assert all( + [issubclass(type(x), EvalMetric) for x in args] + ), f"MetricList input is illegal: {args}" + self.metrics = [metric for metric in args] + super(MetricList, self).__init__(name=name) + + def update(self, preds, labels, losses=None): + preds = [preds] if type(preds) is not list else preds + labels = [labels] if type(labels) is not list else labels + losses = [losses] if type(losses) is not list else losses + + for metric in self.metrics: + metric.update(preds, labels, losses) + + def reset(self): + if hasattr(self, "metrics"): + for metric in self.metrics: + metric.reset() + else: + logging.warning("No metric defined.") + + def get(self): + outputs = [] + for metric in self.metrics: + outputs.append(metric.get()) + return outputs + + def get_name_value(self): + outputs = [] + for metric in self.metrics: + outputs.append(metric.get_name_value()) + return outputs + + +#################### +# COMMON METRICS +#################### + + +class Accuracy(EvalMetric): + """Computes accuracy classification score.""" + + def __init__(self, name="accuracy", topk=1): + super(Accuracy, self).__init__(name) + self.topk = topk + + def update(self, preds, labels, losses): + preds = [preds] if type(preds) is not list else preds + labels = [labels] if type(labels) is not list else labels + + self.check_label_shapes(preds, labels) + for pred, label in zip(preds, labels): + assert ( + self.topk <= pred.shape[1] + ), f"topk({self.topk}) should be no larger than the pred dim({pred.shape[1]})" + _, pred_topk = pred.topk(self.topk, 1, True, True) + + pred_topk = pred_topk.t() + correct = pred_topk.eq(label.view(1, -1).expand_as(pred_topk)) + + self.sum_metric += float( + correct.contiguous().view(-1).float().sum(0, keepdim=True).numpy() + ) + self.num_inst += label.shape[0] + + +class Loss(EvalMetric): + """Dummy metric for directly printing loss.""" + + def __init__(self, name="loss"): + super(Loss, self).__init__(name) + + def update(self, preds, labels, losses): + assert losses is not
None, "Loss undefined." + for loss in losses: + # print(f"loss is:{loss}, type is:{type(loss)}, metric:{float(loss.numpy().sum())}, + # shape is:{loss.shape}, type is {type(loss.shape)}") + self.sum_metric += float(loss.numpy().sum()) + self.num_inst += 1 # loss.shape[0] diff --git a/example/ucf101/ucf101/model.py b/example/ucf101/ucf101/model.py new file mode 100644 index 0000000000..1c1dccd4f9 --- /dev/null +++ b/example/ucf101/ucf101/model.py @@ -0,0 +1,401 @@ +import json +import logging +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn + + +def xavier(net): + def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1 and hasattr(m, "weight"): + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + if m.bias is not None: + m.bias.data.zero_() + elif classname.find("BatchNorm") != -1: + m.weight.data.fill_(1.0) + if m.bias is not None: + m.bias.data.zero_() + elif classname.find("Linear") != -1: + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + if m.bias is not None: + m.bias.data.zero_() + elif ( + classname + in [ + "Sequential", + "AvgPool3d", + "MaxPool3d", + "Dropout", + "ReLU", + "Softmax", + "BnActConv3d", + ] + or "Block" in classname + ): + pass + else: + if classname != classname.upper(): + logging.warning(f"Initializer:: '{classname}' is uninitialized.") + + net.apply(weights_init) + + +def init_3d_from_2d_dict(net, state_dict, method="inflation"): + logging.debug( + f"Initializer:: loading from 2D neural network, filling method: `{method}' ..." + ) + + # filling method + def filling_kernel(src, dshape, method): + assert method in [ + "inflation", + "random", + ], f"filling method: {method} is unknown!" + + if method == "inflation": + dst = torch.FloatTensor(dshape) + # normalize + src = src / float(dshape[2]) + src = src.view(dshape[0], dshape[1], 1, dshape[3], dshape[4]) + dst.copy_(src) + elif method == "random": + dst = torch.FloatTensor(dshape) + tmp = torch.FloatTensor(src.shape) + # normalize + src = src / float(dshape[2]) + # random range + scale = src.abs().mean() + # filling + dst[:, :, 0, :, :].copy_(src) + i = 1 + while i < dshape[2]: + if i + 2 < dshape[2]: + nn.init.uniform(tmp, a=-scale, b=scale) + dst[:, :, i, :, :].copy_(tmp) + dst[:, :, i + 1, :, :].copy_(src) + dst[:, :, i + 2, :, :].copy_(-tmp) + i += 3 + elif i + 1 < dshape[2]: + nn.init.uniform(tmp, a=-scale, b=scale) + dst[:, :, i, :, :].copy_(tmp) + dst[:, :, i + 1, :, :].copy_(-tmp) + i += 2 + else: + dst[:, :, i, :, :].copy_(src) + i += 1 + # shuffle + tmp = dst.numpy().swapaxes(2, -1) + shp = tmp.shape[:-1] + for ndx in np.ndindex(shp): + np.random.shuffle(tmp[ndx]) + dst = torch.from_numpy(tmp) + else: + raise NotImplementedError + + return dst + + # customized partialy loading function + src_state_keys = list(state_dict.keys()) + dst_state_keys = list(net.state_dict().keys()) + for name, param in state_dict.items(): + if name in dst_state_keys: + src_param_shape = param.shape + dst_param_shape = net.state_dict()[name].shape + if src_param_shape != dst_param_shape: + if name.startswith("classifier"): + continue + assert ( + len(src_param_shape) == 4 and len(dst_param_shape) == 5 + ), f"{name} mismatch" + if list(src_param_shape) == [dst_param_shape[i] for i in [0, 1, 3, 4]]: + if dst_param_shape[2] != 1: + param = filling_kernel( + src=param, dshape=dst_param_shape, method=method + ) + else: + param = param.view(dst_param_shape) + assert ( + dst_param_shape == param.shape + ), f"Initializer:: error({name}): 
{dst_param_shape} != {param.shape}" + net.state_dict()[name].copy_(param) + src_state_keys.remove(name) + dst_state_keys.remove(name) + + # indicate missing / ignored keys + if src_state_keys: + out = "['" + "', '".join(src_state_keys) + "']" + logging.info( + f"Initializer:: >> {len(src_state_keys)} params are " + f"unused: {out if len(out) < 300 else out[0:150] + ' ... ' + out[-150:]}" + ) + if dst_state_keys: + logging.info( + f"Initializer:: >> failed to load: \n{json.dumps(dst_state_keys, indent=4, sort_keys=True)}" + ) + + +class BN_AC_CONV3D(nn.Module): + def __init__( + self, + num_in, + num_filter, + kernel=(1, 1, 1), + pad=(0, 0, 0), + stride=(1, 1, 1), + g=1, + bias=False, + ): + super(BN_AC_CONV3D, self).__init__() + self.bn = nn.BatchNorm3d(num_in) + self.relu = nn.ReLU(inplace=True) + self.conv = nn.Conv3d( + num_in, + num_filter, + kernel_size=kernel, + padding=pad, + stride=stride, + groups=g, + bias=bias, + ) + + def forward(self, x): + h = self.relu(self.bn(x)) + h = self.conv(h) + return h + + +class MF_UNIT(nn.Module): + def __init__( + self, + num_in, + num_mid, + num_out, + g=1, + stride=(1, 1, 1), + first_block=False, + use_3d=True, + ): + super(MF_UNIT, self).__init__() + num_ix = int(num_mid / 4) + kt, pt = (3, 1) if use_3d else (1, 0) + # prepare input + self.conv_i1 = BN_AC_CONV3D( + num_in=num_in, num_filter=num_ix, kernel=(1, 1, 1), pad=(0, 0, 0) + ) + self.conv_i2 = BN_AC_CONV3D( + num_in=num_ix, num_filter=num_in, kernel=(1, 1, 1), pad=(0, 0, 0) + ) + # main part + self.conv_m1 = BN_AC_CONV3D( + num_in=num_in, + num_filter=num_mid, + kernel=(kt, 3, 3), + pad=(pt, 1, 1), + stride=stride, + g=g, + ) + if first_block: + self.conv_m2 = BN_AC_CONV3D( + num_in=num_mid, num_filter=num_out, kernel=(1, 1, 1), pad=(0, 0, 0) + ) + else: + self.conv_m2 = BN_AC_CONV3D( + num_in=num_mid, num_filter=num_out, kernel=(1, 3, 3), pad=(0, 1, 1), g=g + ) + # adapter + if first_block: + self.conv_w1 = BN_AC_CONV3D( + num_in=num_in, + num_filter=num_out, + kernel=(1, 1, 1), + pad=(0, 0, 0), + stride=stride, + ) + + def forward(self, x): + + h = self.conv_i1(x) + x_in = x + self.conv_i2(h) + + h = self.conv_m1(x_in) + h = self.conv_m2(h) + + if hasattr(self, "conv_w1"): + x = self.conv_w1(x) + + return h + x + + +class MFNET_3D(nn.Module): + def __init__(self, num_classes): + super(MFNET_3D, self).__init__() + + groups = 16 + k_sec = {2: 3, 3: 4, 4: 6, 5: 3} + + # conv1 - x224 (x16) + conv1_num_out = 16 + self.conv1 = nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv3d( + 3, + conv1_num_out, + kernel_size=(3, 5, 5), + padding=(1, 2, 2), + stride=(1, 2, 2), + bias=False, + ), + ), + ("bn", nn.BatchNorm3d(conv1_num_out)), + ("relu", nn.ReLU(inplace=True)), + ] + ) + ) + self.maxpool = nn.MaxPool3d( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1) + ) + + # conv2 - x56 (x8) + num_mid = 96 + conv2_num_out = 96 + self.conv2 = nn.Sequential( + OrderedDict( + [ + ( + "B%02d" % i, + MF_UNIT( + num_in=conv1_num_out if i == 1 else conv2_num_out, + num_mid=num_mid, + num_out=conv2_num_out, + stride=(2, 1, 1) if i == 1 else (1, 1, 1), + g=groups, + first_block=(i == 1), + ), + ) + for i in range(1, k_sec[2] + 1) + ] + ) + ) + + # conv3 - x28 (x8) + num_mid *= 2 + conv3_num_out = 2 * conv2_num_out + self.conv3 = nn.Sequential( + OrderedDict( + [ + ( + "B%02d" % i, + MF_UNIT( + num_in=conv2_num_out if i == 1 else conv3_num_out, + num_mid=num_mid, + num_out=conv3_num_out, + stride=(1, 2, 2) if i == 1 else (1, 1, 1), + g=groups, + first_block=(i == 1), + ), + ) + for i in 
range(1, k_sec[3] + 1) + ] + ) + ) + + # conv4 - x14 (x8) + num_mid *= 2 + conv4_num_out = 2 * conv3_num_out + self.conv4 = nn.Sequential( + OrderedDict( + [ + ( + "B%02d" % i, + MF_UNIT( + num_in=conv3_num_out if i == 1 else conv4_num_out, + num_mid=num_mid, + num_out=conv4_num_out, + stride=(1, 2, 2) if i == 1 else (1, 1, 1), + g=groups, + first_block=(i == 1), + ), + ) + for i in range(1, k_sec[4] + 1) + ] + ) + ) + + # conv5 - x7 (x8) + num_mid *= 2 + conv5_num_out = 2 * conv4_num_out + self.conv5 = nn.Sequential( + OrderedDict( + [ + ( + "B%02d" % i, + MF_UNIT( + num_in=conv4_num_out if i == 1 else conv5_num_out, + num_mid=num_mid, + num_out=conv5_num_out, + stride=(1, 2, 2) if i == 1 else (1, 1, 1), + g=groups, + first_block=(i == 1), + ), + ) + for i in range(1, k_sec[5] + 1) + ] + ) + ) + + # final + self.tail = nn.Sequential( + OrderedDict( + [("bn", nn.BatchNorm3d(conv5_num_out)), ("relu", nn.ReLU(inplace=True))] + ) + ) + + self.globalpool = nn.Sequential( + OrderedDict( + [ + ("avg", nn.AvgPool3d(kernel_size=(8, 7, 7), stride=(1, 1, 1))), + # ('dropout', nn.Dropout(p=0.5)), only for fine-tuning + ] + ) + ) + self.classifier = nn.Linear(conv5_num_out, num_classes) + + ############# + # Initialization + xavier(net=self) + + def forward(self, x): + assert x.shape[2] == 16 + + h = self.conv1(x) # x224 -> x112 + h = self.maxpool(h) # x112 -> x56 + + h = self.conv2(h) # x56 -> x56 + h = self.conv3(h) # x56 -> x28 + h = self.conv4(h) # x28 -> x14 + h = self.conv5(h) # x14 -> x7 + + h = self.tail(h) + h = self.globalpool(h) + + h = h.view(h.shape[0], -1) + h = self.classifier(h) + + return h + + +if __name__ == "__main__": + + logging.getLogger().setLevel(logging.DEBUG) + # --------- + net = MFNET_3D(num_classes=100) + data = torch.autograd.Variable(torch.randn(1, 3, 16, 224, 224)) + output = net(data) + torch.save({"state_dict": net.state_dict()}, "./tmp.pth") + print(output.shape) diff --git a/example/ucf101/ucf101/sampler.py new file mode 100644 index 0000000000..15c29ce6d7 --- /dev/null +++ b/example/ucf101/ucf101/sampler.py @@ -0,0 +1,60 @@ +from typing import Any, List, Union + +import numpy as np + + +class RandomSampling(object): + def __init__(self, num=16, interval=2, speed=[1.0, 1.0], seed=0): + assert num > 0, "must sample at least 1 frame" + self.num = num + self.interval = interval if type(interval) == list else [interval] + self.speed = speed + self.rng = np.random.RandomState(seed) + + def sampling(self, range_max=30, **kwargs) -> Union[List[Any], object]: + assert range_max > 0, ValueError(f"range_max = {range_max}") + interval = self.rng.choice(self.interval) + if self.num == 1: + return [self.rng.choice(range(0, range_max))] + # sampling + speed_min = self.speed[0] + speed_max = min(self.speed[1], (range_max - 1) / ((self.num - 1) * interval)) + if speed_max < speed_min: + return [self.rng.choice(range(0, range_max))] * self.num + random_interval = self.rng.uniform(speed_min, speed_max) * interval + frame_range = (self.num - 1) * random_interval + clip_start = self.rng.uniform(0, (range_max - 1) - frame_range) + clip_end = clip_start + frame_range + return np.linspace(clip_start, clip_end, self.num).astype(int).tolist() + + +class SequentialSampling(object): + def __init__(self, num, interval=1, shuffle=False, fix_cursor=False, seed=0): + self.memory = {} + self.num = num + self.interval = interval if type(interval) == list else [interval] + self.shuffle = shuffle + self.fix_cursor = fix_cursor + self.rng =
np.random.RandomState(seed) + + def sampling(self, range_max, v_id, prev_failed=False): + assert range_max > 0, ValueError(f"range_max = {range_max}") + num = self.num + interval = self.rng.choice(self.interval) + frame_range = (num - 1) * interval + 1 + # sampling clips + if v_id not in self.memory: + clips = list(range(0, range_max - (frame_range - 1), frame_range)) + if self.shuffle: + self.rng.shuffle(clips) + self.memory[v_id] = [-1, clips] + # pickup a clip + cursor, clips = self.memory[v_id] + if not clips: + return [self.rng.choice(range(0, range_max))] * num + cursor = (cursor + 1) % len(clips) + if prev_failed or not self.fix_cursor: + self.memory[v_id][0] = cursor + # sampling within clip + ids = range(clips[cursor], clips[cursor] + frame_range, interval) + return ids diff --git a/example/ucf101/ucf101/train.py b/example/ucf101/ucf101/train.py new file mode 100644 index 0000000000..2b68b9aefd --- /dev/null +++ b/example/ucf101/ucf101/train.py @@ -0,0 +1,673 @@ +import os +import time +import logging +from pathlib import Path + +import torch +import torch.backends.cudnn as cudnn +from model import MFNET_3D +from metric import Loss, Accuracy, MetricList +from sampler import RandomSampling, SequentialSampling +from transform import ( + Resize, + Compose, + ToTensor, + Normalize, + RandomHLS, + CenterCrop, + RandomCrop, + RandomScale, + RandomHorizontalFlip, +) +from lr_scheduler import MultiFactorScheduler +from video_iterator import VideoIter +from torch.utils.data import DataLoader + +ROOTDIR = Path(__file__).parent.parent + + +class Callback(object): + def __init__(self, with_header=False): + self.with_header = with_header + + def __call__(self): + raise NotImplementedError("To be implemented") + + def header(self, epoch=None, batch=None): + str_out = "" + if self.with_header: + if epoch is not None: + str_out += f"Epoch {('[%d]' % epoch).ljust(5, ' ')} " + if batch is not None: + str_out += f"Batch {('[%d]' % batch).ljust(6, ' ')} " + return str_out + + +class CallbackList(Callback): + def __init__(self, *args, with_header=True): + super(CallbackList, self).__init__(with_header=with_header) + assert all( + [issubclass(type(x), Callback) for x in args] + ), f"Callback inputs illegal: {args}" + self.callbacks = [callback for callback in args] + + def __call__(self, epoch=None, batch=None, silent=False, **kwargs): + str_out = self.header(epoch, batch) + + for callback in self.callbacks: + str_out += callback(**kwargs, silent=True) + " " + + if not silent: + logging.info(str_out) + return str_out + + +#################### +# CUSTOMIZED CALLBACKS +#################### + + +class SpeedMonitor(Callback): + def __init__(self, with_header=False): + super(SpeedMonitor, self).__init__(with_header=with_header) + + def __call__( + self, + sample_elapse, + update_elapse=None, + epoch=None, + batch=None, + silent=False, + **kwargs, + ): + str_out = self.header(epoch, batch) + + if sample_elapse is not None: + sample_freq = 1.0 / sample_elapse + if update_elapse is not None: + update_freq = 1.0 / update_elapse + str_out += "Speed {: >5.1f} (+{: >2.0f}) sample/sec ".format( + sample_freq, update_freq - sample_freq + ) + else: + str_out += "Speed {:.2f} sample/sec ".format(sample_freq) + + if not silent: + logging.info(str_out) + return str_out + + +class MetricPrinter(Callback): + def __init__(self, with_header=False): + super(MetricPrinter, self).__init__(with_header=with_header) + + def __call__(self, namevals, epoch=None, batch=None, silent=False, **kwargs): + str_out = self.header(epoch, 
batch) + + if namevals is not None: + for i, nameval in enumerate(namevals): + name, value = nameval[0] + str_out += "{} = {:.5f}".format(name, value) + str_out += ", " if i != (len(namevals) - 1) else " " + + if not silent: + logging.info(str_out) + return str_out + + +""" +Static Model +""" + + +class StaticModel(object): + def __init__(self, net, criterion=None, model_prefix="", **kwargs): + if kwargs: + logging.warning(f"Unknown kwargs: {kwargs}") + + # init params + self.net = net + self.model_prefix = model_prefix + self.criterion = criterion + + def load_state(self, state_dict, strict=False): + if strict: + self.net.load_state_dict(state_dict=state_dict) + else: + # customized partial load function + net_state_keys = list(self.net.state_dict().keys()) + for name, param in state_dict.items(): + if name in self.net.state_dict().keys(): + dst_param_shape = self.net.state_dict()[name].shape + if param.shape == dst_param_shape: + self.net.state_dict()[name].copy_(param.view(dst_param_shape)) + net_state_keys.remove(name) + # indicate missing keys + if net_state_keys: + logging.warning(f">> Failed to load: {net_state_keys}") + return False + return True + + def get_checkpoint_path(self, epoch): + assert self.model_prefix, "model_prefix undefined!" + + return "{}_ep-{:04d}.pth".format(self.model_prefix, epoch) + + def load_checkpoint(self, epoch, optimizer=None): + + load_path = self.get_checkpoint_path(epoch) + assert os.path.exists( + load_path + ), f"Failed to load: {load_path} (file does not exist)" + + checkpoint = torch.load(load_path) + + all_params_matched = self.load_state(checkpoint["state_dict"], strict=True) + + if optimizer: + if "optimizer" in checkpoint.keys() and all_params_matched: + optimizer.load_state_dict(checkpoint["optimizer"]) + logging.info( + f"Model & Optimizer states are resumed from: `{load_path}'" + ) + else: + logging.warning( + f">> Failed to load optimizer state from: `{load_path}'" + ) + else: + logging.info(f"Only model state resumed from: `{load_path}'") + + if "epoch" in checkpoint.keys(): + if checkpoint["epoch"] != epoch: + logging.warning( + f">> Epoch information inconsistent: {checkpoint['epoch']} vs {epoch}" + ) + + def save_checkpoint(self, epoch, optimizer_state=None): + + save_path = self.get_checkpoint_path(epoch) + save_folder = os.path.dirname(save_path) + + if not os.path.exists(save_folder): + logging.debug(f"mkdir {save_folder}") + os.makedirs(save_folder) + + if not optimizer_state: + torch.save({"epoch": epoch, "state_dict": self.net.state_dict()}, save_path) + logging.info(f"Checkpoint (only model) saved to: {save_path}") + else: + torch.save( + { + "epoch": epoch, + "state_dict": self.net.state_dict(), + "optimizer": optimizer_state, + }, + save_path, + ) + logging.info(f"Checkpoint (model & optimizer) saved to: {save_path}") + + def forward(self, data, target): + """typical forward function with: + single output and single loss + """ + + if torch.cuda.is_available(): + data = data.float().cuda() + target = target.cuda() + else: + data = data.float() + + if self.net.training: + input_var = torch.autograd.Variable(data, requires_grad=False) + target_var = torch.autograd.Variable(target, requires_grad=False) + else: + input_var = torch.autograd.Variable(data, volatile=True) + target_var = torch.autograd.Variable(target, volatile=True) + + output = self.net(input_var) + if ( + hasattr(self, "criterion") + and self.criterion is not None + and target is not None + ): + loss = self.criterion(output, target_var) + else: + loss = None + return
[output], [loss] + + +""" +Dynamic model that is able to update itself +""" + + +class DynamicModel(StaticModel): + def __init__( + self, + net, + criterion, + model_prefix="", + step_callback=None, + step_callback_freq=50, + epoch_callback=None, + save_checkpoint_freq=1, + opt_batch_size=None, + **kwargs, + ): + + # load parameters + if kwargs: + logging.warning(f"Unknown kwargs: {kwargs}") + + super(DynamicModel, self).__init__( + net, criterion=criterion, model_prefix=model_prefix + ) + + # load optional arguments + # - callbacks + self.callback_kwargs = { + "epoch": None, + "batch": None, + "sample_elapse": None, + "update_elapse": None, + "epoch_elapse": None, + "namevals": None, + "optimizer_dict": None, + } + + if not step_callback: + step_callback = CallbackList(SpeedMonitor(), MetricPrinter()) + if not epoch_callback: + epoch_callback = lambda **kwargs: None + + self.step_callback = step_callback + self.step_callback_freq = step_callback_freq + self.epoch_callback = epoch_callback + self.save_checkpoint_freq = save_checkpoint_freq + self.batch_size = opt_batch_size + + """ + In order to customize the callback function, + you will have to overwrite the functions below + """ + + def step_end_callback(self): + logging.debug(f"Step {self.i_step} finished!") + self.step_callback(**self.callback_kwargs) + + def epoch_end_callback(self): + self.epoch_callback(**self.callback_kwargs) + if self.callback_kwargs["epoch_elapse"] is not None: + logging.info( + "Epoch [{:d}] time cost: {:.2f} sec ({:.2f} h)".format( + self.callback_kwargs["epoch"], + self.callback_kwargs["epoch_elapse"], + self.callback_kwargs["epoch_elapse"] / 3600.0, + ) + ) + # TODO: save at epoch 0, or only for the best (lowest-loss) result? + if ( + self.callback_kwargs["epoch"] == 0 + or ((self.callback_kwargs["epoch"] + 1) % self.save_checkpoint_freq) == 0 + ): + self.save_checkpoint( + epoch=self.callback_kwargs["epoch"] + 1, + optimizer_state=self.callback_kwargs["optimizer_dict"], + ) + + """ + Learning rate + """ + + def adjust_learning_rate(self, lr, optimizer): + for param_group in optimizer.param_groups: + if "lr_mult" in param_group: + lr_mult = param_group["lr_mult"] + else: + lr_mult = 1.0 + param_group["lr"] = lr * lr_mult + + """ + Optimization + """ + + def fit( + self, + train_iter, + optimizer, + lr_scheduler, + eval_iter=None, + metrics=None, + epoch_start=0, + epoch_end=10000, + **kwargs, + ): + + """ + checking + """ + if kwargs: + logging.warning(f"Unknown kwargs: {kwargs}") + + """ + start the main loop + """ + for i_epoch in range(epoch_start, epoch_end): + self.callback_kwargs["epoch"] = i_epoch + epoch_start_time = time.time() + + ########### + # 1] TRAINING + ########### + metrics.reset() + self.net.train() + sum_sample_inst = 0 + sum_sample_elapse = 0.0 + sum_update_elapse = 0 + batch_start_time = time.time() + logging.info(f"Start epoch {i_epoch}:") + for i_batch, (data, target) in enumerate(train_iter): + self.callback_kwargs["batch"] = i_batch + + update_start_time = time.time() + + # [forward] making next step + outputs, losses = self.forward(data, target) + + # [backward] + optimizer.zero_grad() + for loss in losses: + loss.backward() + self.adjust_learning_rate(optimizer=optimizer, lr=lr_scheduler.update()) + optimizer.step() + + # [evaluation] update train metric + metrics.update( + [output.data.cpu() for output in outputs], + target.cpu(), + [loss.data.cpu() for loss in losses], + ) + + # timing each batch + sum_sample_elapse += time.time() - batch_start_time + sum_update_elapse += time.time() -
update_start_time + batch_start_time = time.time() + sum_sample_inst += data.shape[0] + + if (i_batch % self.step_callback_freq) == 0: + # retrieve eval results and reset metric + self.callback_kwargs["namevals"] = metrics.get_name_value() + metrics.reset() + # speed monitor + self.callback_kwargs["sample_elapse"] = ( + sum_sample_elapse / sum_sample_inst + ) + self.callback_kwargs["update_elapse"] = ( + sum_update_elapse / sum_sample_inst + ) + sum_update_elapse = 0 + sum_sample_elapse = 0 + sum_sample_inst = 0 + # callbacks + self.step_end_callback() + + ########### + # 2] END OF EPOCH + ########### + self.callback_kwargs["epoch_elapse"] = time.time() - epoch_start_time + self.callback_kwargs["optimizer_dict"] = optimizer.state_dict() + self.epoch_end_callback() + + ########### + # 3] Evaluation + ########### + if (eval_iter is not None) and ( + (i_epoch + 1) % max(1, int(self.save_checkpoint_freq / 2)) + ) == 0: + logging.info(f"Start evaluating epoch {i_epoch}:") + + metrics.reset() + self.net.eval() + sum_sample_elapse = 0.0 + sum_sample_inst = 0 + sum_forward_elapse = 0.0 + batch_start_time = time.time() + for i_batch, (data, target) in enumerate(eval_iter): + self.callback_kwargs["batch"] = i_batch + + forward_start_time = time.time() + + outputs, losses = self.forward(data, target) + + metrics.update( + [output.data.cpu() for output in outputs], + target.cpu(), + [loss.data.cpu() for loss in losses], + ) + + sum_forward_elapse += time.time() - forward_start_time + sum_sample_elapse += time.time() - batch_start_time + batch_start_time = time.time() + sum_sample_inst += data.shape[0] + + # evaluation callbacks + self.callback_kwargs["sample_elapse"] = ( + sum_sample_elapse / sum_sample_inst + ) + self.callback_kwargs["update_elapse"] = ( + sum_forward_elapse / sum_sample_inst + ) + self.callback_kwargs["namevals"] = metrics.get_name_value() + self.step_end_callback() + + logging.info("Optimization done!") + + +def get_ucf101( + data_root="/data", + clip_length=8, + train_interval=2, + val_interval=2, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + seed=0, +): + """data iter for ucf-101""" + logging.debug( + f"VideoIter:: clip_length = {clip_length}, " + f"interval = [train: {train_interval}, val: {val_interval}], seed = {seed}" + ) + + normalize = Normalize(mean=mean, std=std) + + train_sampler = RandomSampling( + num=clip_length, interval=train_interval, speed=[1.0, 1.0], seed=(seed + 0) + ) + train = VideoIter( + video_prefix=os.path.join(data_root, "UCF-101"), + txt_list=os.path.join(data_root, "train_list.txt"), + sampler=train_sampler, + force_color=True, + video_transform=Compose( + [ + RandomScale( + make_square=True, aspect_ratio=[0.8, 1.0 / 0.8], slen=[224, 288] + ), + RandomCrop((224, 224)), # insert a resize if needed + RandomHorizontalFlip(), + RandomHLS(vars=[15, 35, 25]), + ToTensor(), + normalize, + ], + aug_seed=(seed + 1), + ), + name="train", + shuffle_list_seed=(seed + 2), + ) + + val_sampler = SequentialSampling( + num=clip_length, interval=val_interval, fix_cursor=True, shuffle=True + ) + val = VideoIter( + video_prefix=os.path.join(data_root, "UCF-101"), + txt_list=os.path.join(data_root, "validation_list.txt"), + sampler=val_sampler, + force_color=True, + video_transform=Compose( + [ + Resize((256, 256)), + CenterCrop((224, 224)), + ToTensor(), + normalize, + ] + ), + name="test", + ) + + return train, val + + +def train_model( + start_with_pretrained=False, + num_workers=16, + resume_epoch=-1, + batch_size=4, + save_frequency=1, + lr_base=0.01, + 
lr_factor=0.1, + lr_steps=[400000, 800000], + end_epoch=100, + fine_tune=False, +): + + # data iterator + iter_seed = 101 + max(0, resume_epoch) * 100 + + train, val = get_ucf101( + data_root=str(ROOTDIR / "data"), + clip_length=16, + train_interval=2, + val_interval=2, + mean=[124 / 255, 117 / 255, 104 / 255], + std=[1 / (0.0167 * 255)] * 3, + seed=iter_seed, + ) + + train_iter = DataLoader( + train, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=False, + ) + + eval_iter = DataLoader( + val, + batch_size=2 * batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=False, + ) + # load model + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = MFNET_3D(num_classes=101) + model.to(device) + + # init from 2d + # state_dict_2d = torch.load(str(ROOTDIR / "models/pretrained/MFNet2D_ImageNet1k-0000.pth")) + # init_3d_from_2d_dict(net=model, state_dict=state_dict_2d, method="inflation") + + # wrapper (dynamic model) + dynamic_model = DynamicModel( + net=model, + criterion=torch.nn.CrossEntropyLoss().cuda() + if torch.cuda.is_available() + else torch.nn.CrossEntropyLoss(), + model_prefix=str(ROOTDIR / "models/log"), + step_callback_freq=50, + save_checkpoint_freq=save_frequency, + opt_batch_size=batch_size, # optional + ) + + # config optimization + param_base_layers = [] + param_new_layers = [] + name_base_layers = [] + for name, param in dynamic_model.net.named_parameters(): + if fine_tune: + if name.startswith("classifier"): + param_new_layers.append(param) + else: + param_base_layers.append(param) + name_base_layers.append(name) + else: + param_new_layers.append(param) + + if name_base_layers: + out = "['" + "', '".join(name_base_layers) + "']" + logging.info( + f"Optimizer:: >> reducing the learning rate of {len(name_base_layers)} " + f"params: {out if len(out) < 300 else out[0:150] + ' ... 
' + out[-150:]} + ) + if torch.cuda.is_available(): + dynamic_model.net = torch.nn.DataParallel(dynamic_model.net).cuda() + else: + dynamic_model.net = torch.nn.DataParallel(dynamic_model.net).cpu() + + optimizer = torch.optim.SGD( + [ + {"params": param_base_layers, "lr_mult": 0.2}, + {"params": param_new_layers, "lr_mult": 1.0}, + ], + lr=lr_base, + momentum=0.9, + weight_decay=0.0001, + nesterov=True, + ) + + # resume training: model and optimizer + if resume_epoch < 0: + epoch_start = 0 + step_counter = 0 + if start_with_pretrained: + # load params from pretrained 3d network + checkpoint = torch.load( + str(ROOTDIR / "models/pretrained/MFNet3D_UCF-101_Split-1_96.3.pth") + ) + dynamic_model.load_state(checkpoint["state_dict"], strict=False) + else: + dynamic_model.load_checkpoint(epoch=resume_epoch, optimizer=optimizer) + epoch_start = resume_epoch + step_counter = epoch_start * train_iter.__len__() + + # set learning rate scheduler + num_worker = 1 + + lr_scheduler = MultiFactorScheduler( + base_lr=lr_base, + steps=[int(x / (batch_size * num_worker)) for x in lr_steps], + factor=lr_factor, + step_counter=step_counter, + ) + # define evaluation metric + metrics = MetricList( + Loss(name="loss-ce"), + Accuracy(name="top1", topk=1), + Accuracy(name="top5", topk=5), + ) + # enable cudnn tune + cudnn.benchmark = True + + dynamic_model.fit( + train_iter=train_iter, + eval_iter=eval_iter, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + metrics=metrics, + epoch_start=epoch_start, + epoch_end=end_epoch, + ) + + +if __name__ == "__main__": + train_model() diff --git a/example/ucf101/ucf101/transform.py new file mode 100644 index 0000000000..774ad5428e --- /dev/null +++ b/example/ucf101/ucf101/transform.py @@ -0,0 +1,223 @@ +import cv2 +import numpy as np +import torch + + +class Compose(object): + def __init__(self, transforms, aug_seed=0): + self.transforms = transforms + for i, trans in enumerate(self.transforms): + trans.set_random_state(seed=(aug_seed + i)) + + def __call__(self, data): + for trans in self.transforms: + data = trans(data) + return data + + +class Transform(object): + def __init__(self): + self.rng = None + + def set_random_state(self, seed=None): + self.rng = np.random.RandomState(seed) + + +class Normalize(Transform): + def __init__(self, mean, std): + super().__init__() + self.mean = mean + self.std = std + + def __call__(self, tensor): + for trans, m, s in zip(tensor, self.mean, self.std): + trans.sub_(m).div_(s) + return tensor + + +class Resize(Transform): + def __init__(self, size, interpolation=cv2.INTER_LINEAR): + super().__init__() + self.size = size # [w, h] + self.interpolation = interpolation + + def __call__(self, data): + h, w, c = data.shape + + if isinstance(self.size, int): + slen = self.size + if min(w, h) == slen: + return data + if w < h: + new_w = self.size + new_h = int(self.size * h / w) + else: + new_w = int(self.size * w / h) + new_h = self.size + else: + new_w = self.size[0] + new_h = self.size[1] + + if (h != new_h) or (w != new_w): + scaled_data = cv2.resize(data, (new_w, new_h), interpolation=self.interpolation) + else: + scaled_data = data + + return scaled_data + + +class CenterCrop(Transform): + """Crops the given numpy array at the center to have a region of + the given size.
size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + + def __init__(self, size): + super().__init__() + if isinstance(size, int): + self.size = (size, size) + else: + self.size = size + + def __call__(self, data): + h, w, c = data.shape + th, tw = self.size + x1 = int(round((w - tw) / 2.0)) + y1 = int(round((h - th) / 2.0)) + cropped_data = data[y1 : (y1 + th), x1 : (x1 + tw), :] + return cropped_data + + +class RandomCrop(Transform): + def __init__(self, size): + super().__init__() + if isinstance(size, int): + self.size = (size, size) + else: + self.size = size + self.rng = np.random.RandomState(0) + + def __call__(self, data): + h, w, c = data.shape + th, tw = self.size + x1 = self.rng.choice(range(w - tw)) + y1 = self.rng.choice(range(h - th)) + cropped_data = data[y1 : (y1 + th), x1 : (x1 + tw), :] + return cropped_data + + +class RandomHorizontalFlip(Transform): + """Randomly horizontally flips the given numpy array with a probability of 0.5""" + + def __init__(self): + super().__init__() + self.rng = np.random.RandomState(0) + + def __call__(self, data): + if self.rng.rand() < 0.5: + data = np.fliplr(data) + data = np.ascontiguousarray(data) + return data + + +class RandomHLS(Transform): + def __init__(self, vars=[15, 35, 25]): + super().__init__() + self.vars = vars + self.rng = np.random.RandomState(0) + + def __call__(self, data): + h, w, c = data.shape + assert c % 3 == 0, f"input channel = {c}, illegal" + + random_vars = [int(round(self.rng.uniform(-x, x))) for x in self.vars] + + base = len(random_vars) + augmented_data = np.zeros( + data.shape, + ) + + for i_im in range(0, int(c / 3)): + augmented_data[:, :, 3 * i_im : (3 * i_im + 3)] = cv2.cvtColor( + data[:, :, 3 * i_im : (3 * i_im + 3)], cv2.COLOR_RGB2HLS + ) + + hls_limits = [180, 255, 255] + for ic in range(0, c): + var = random_vars[ic % base] + limit = hls_limits[ic % base] + augmented_data[:, :, ic] = np.minimum( + np.maximum(augmented_data[:, :, ic] + var, 0), limit + ) + + for i_im in range(0, int(c / 3)): + augmented_data[:, :, 3 * i_im : (3 * i_im + 3)] = cv2.cvtColor( + augmented_data[:, :, 3 * i_im : (3 * i_im + 3)].astype(np.uint8), + cv2.COLOR_HLS2RGB, + ) + + return augmented_data + + +class RandomScale(Transform): + """Rescales the input numpy array to the given 'size'. + 'size' will be the size of the smaller edge. 
+ For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation: Default: cv2.INTER_LINEAR + """ + + def __init__( + self, + make_square=False, + aspect_ratio=[1.0, 1.0], + slen=[224, 288], + interpolation=cv2.INTER_LINEAR, + ): + super().__init__() + assert slen[1] >= slen[0], f"slen ({slen}) should be in increasing order" + assert ( + aspect_ratio[1] >= aspect_ratio[0] + ), f"aspect_ratio ({aspect_ratio}) should be in increasing order" + self.slen = slen # [min factor, max factor] + self.aspect_ratio = aspect_ratio + self.make_square = make_square + self.interpolation = interpolation + self.rng = np.random.RandomState(0) + + def __call__(self, data): + h, w, c = data.shape + new_w = w + new_h = h if not self.make_square else w + if self.aspect_ratio: + random_aspect_ratio = self.rng.uniform( + self.aspect_ratio[0], self.aspect_ratio[1] + ) + if self.rng.rand() > 0.5: + random_aspect_ratio = 1.0 / random_aspect_ratio + new_w *= random_aspect_ratio + new_h /= random_aspect_ratio + resize_factor = self.rng.uniform(self.slen[0], self.slen[1]) / min(new_w, new_h) + new_w *= resize_factor + new_h *= resize_factor + scaled_data = cv2.resize( + data, (int(new_w + 1), int(new_h + 1)), interpolation=self.interpolation + ) + return scaled_data + + +class ToTensor(Transform): + def __init__(self, dim=3): + super().__init__() + self.dim = dim + + def __call__(self, clips): + if isinstance(clips, np.ndarray): + H, W, _ = clips.shape + # handle numpy array + clips = torch.from_numpy( + clips.reshape((H, W, -1, self.dim)).transpose((3, 2, 0, 1)) + ) + # backward compatibility + return clips.float() / 255.0 diff --git a/example/ucf101/ucf101/video_iterator.py new file mode 100644 index 0000000000..d0ff7ca606 --- /dev/null +++ b/example/ucf101/ucf101/video_iterator.py @@ -0,0 +1,352 @@ +import os +import logging + +import cv2 +import numpy as np +import torch.utils.data as data + + +class Video(object): + """basic Video class""" + + def __init__(self, vid_path): + self.cap = None + self.faulty_frame = None + self.frame_count = None + self.vid_path = None + self.open(vid_path) + + def __del__(self): + self.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.__del__() + + def reset(self): + self.close() + self.vid_path = None + self.frame_count = -1 + self.faulty_frame = None + return self + + def open(self, vid_path): + assert os.path.exists(vid_path), f"VideoIter:: cannot locate: `{vid_path}'" + + # close previous video & reset variables + self.reset() + + # try to open video + cap = cv2.VideoCapture(vid_path) + if cap.isOpened(): + self.cap = cap + self.vid_path = vid_path + else: + raise IOError(f"VideoIter:: failed to open video: `{vid_path}'") + + return self + + def count_frames(self, check_validity=False): + offset = 0 + if self.vid_path.endswith(".flv"): + offset = -1 + unverified_frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) + offset + if check_validity: + verified_frame_count = 0 + for i in range(unverified_frame_count): + self.cap.set(cv2.CAP_PROP_POS_FRAMES, i) + if not self.cap.grab(): + logging.warning( + f"VideoIter:: >> frame (start from 0) {i} corrupted in {self.vid_path}" + ) + break + verified_frame_count = i + 1 + self.frame_count = verified_frame_count + else: + self.frame_count = unverified_frame_count + assert ( + self.frame_count > 0 + ), f"VideoIter:: Video: `{self.vid_path}' has no frames" + return
self.frame_count + + def extract_frames(self, ids, force_color=True): + frames = self.extract_frames_fast(ids, force_color) + if frames is None: + # try slow method: + frames = self.extract_frames_slow(ids, force_color) + return frames + + def extract_frames_fast(self, ids, force_color=True): + assert self.cap is not None, "No opened video." + if len(ids) < 1: + return [] + + frames = [] + pre_idx = max(ids) + for idx in ids: + assert (self.frame_count < 0) or ( + idx < self.frame_count + ), f"ids: {ids} > total valid frames({self.frame_count})" + if pre_idx != (idx - 1): + self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + res, frame = self.cap.read() # in BGR/GRAY format + pre_idx = idx + if not res: + self.faulty_frame = idx + return None + if len(frame.shape) < 3: + if force_color: + # Convert Gray to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) + else: + # Convert BGR to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + return frames + + def extract_frames_slow(self, ids, force_color=True): + assert self.cap is not None, "No opened video." + if len(ids) < 1: + return [] + + frames = [None] * len(ids) + idx = min(ids) + self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + while idx <= max(ids): + res, frame = self.cap.read() # in BGR/GRAY format + if not res: + # end of the video + self.faulty_frame = idx + return None + if idx in ids: + # found a frame + if len(frame.shape) < 3: + if force_color: + # Convert Gray to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) + else: + # Convert BGR to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pos = [k for k, i in enumerate(ids) if i == idx] + for k in pos: + frames[k] = frame + idx += 1 + return frames + + def close(self): + if hasattr(self, "cap") and self.cap is not None: + self.cap.release() + self.cap = None + return self + + +class VideoIter(data.Dataset): + def __init__( + self, + video_prefix, + txt_list, + sampler, + video_transform=None, + name="", + force_color=True, + cached_info_path=None, + return_item_subpath=False, + shuffle_list_seed=None, + check_video=False, + tolerant_corrupted_video=None, + ): + super(VideoIter, self).__init__() + # load params + self.sampler = sampler + self.force_color = force_color + self.video_prefix = video_prefix + self.video_transform = video_transform + self.return_item_subpath = return_item_subpath + self.backup_item = None + if (not check_video) and (tolerant_corrupted_video is None): + tolerant_corrupted_video = True + self.tolerant_corrupted_video = tolerant_corrupted_video + self.rng = np.random.RandomState(shuffle_list_seed if shuffle_list_seed else 0) + # load video list + self.video_list = self._get_video_list( + video_prefix=video_prefix, + txt_list=txt_list, + check_video=check_video, + cached_info_path=cached_info_path, + ) + if shuffle_list_seed is not None: + self.rng.shuffle(self.video_list) + logging.info( + f"VideoIter:: iterator initialized (phase: '{name}', num: {len(self.video_list)})" + ) + + def getitem_from_raw_video(self, index): + # get current video info + v_id, label, vid_subpath, frame_count = self.video_list[index] + video_path = os.path.join(self.video_prefix, vid_subpath) + + faulty_frames = [] + successful_trial = False + try: + with Video(vid_path=video_path) as video: + if frame_count < 0: + frame_count = video.count_frames(check_validity=False) + for i_trial in range(20): + # dynamic sampling + sampled_ids = self.sampler.sampling( + range_max=frame_count, v_id=v_id, prev_failed=(i_trial > 0) + ) + if
+                    if set(sampled_ids).intersection(faulty_frames):
+                        continue
+                    # extracting frames
+                    sampled_frames = video.extract_frames(
+                        ids=sampled_ids, force_color=self.force_color
+                    )
+                    if sampled_frames is None:
+                        faulty_frames.append(video.faulty_frame)
+                    else:
+                        successful_trial = True
+                        break
+        except IOError as e:
+            logging.warning(f">> I/O error({e.errno}): {e.strerror}")
+
+        if not successful_trial:
+            assert (
+                self.backup_item is not None
+            ), f"VideoIter:: >> frames {faulty_frames} are corrupted & no backup is available. [{video_path}]"
+            logging.warning(
+                f">> frames {faulty_frames} are corrupted, using backup item! [{video_path}]"
+            )
+            with Video(vid_path=self.backup_item["video_path"]) as video:
+                sampled_frames = video.extract_frames(
+                    ids=self.backup_item["sampled_ids"], force_color=self.force_color
+                )
+        elif self.tolerant_corrupted_video:
+            # assume the error rate is less than 10%
+            if (self.backup_item is None) or (self.rng.rand() < 0.1):
+                self.backup_item = {
+                    "video_path": video_path,
+                    "sampled_ids": sampled_ids,
+                }
+
+        clip_input = np.concatenate(sampled_frames, axis=2)
+        # apply video augmentation
+        if self.video_transform is not None:
+            clip_input = self.video_transform(clip_input)
+
+        return clip_input, label, vid_subpath
+
+    def __getitem__(self, index):
+        success = False
+        while not success:
+            try:
+                clip_input, label, vid_subpath = self.getitem_from_raw_video(index)
+                success = True
+            except Exception as e:
+                index = self.rng.choice(range(0, self.__len__()))
+                logging.warning(
+                    f"VideoIter:: ERROR!! (force using another index: {index})\n{e}"
+                )
+
+        if self.return_item_subpath:
+            return clip_input, label, vid_subpath
+        else:
+            return clip_input, label
+
+    def __len__(self):
+        return len(self.video_list)
+
+    def _get_video_list(
+        self, video_prefix, txt_list, check_video=False, cached_info_path=None
+    ):
+        # format:
+        # [vid, label, video_subpath, frame_count]
+        assert os.path.exists(
+            video_prefix
+        ), f"VideoIter:: failed to locate: `{video_prefix}'"
+        assert os.path.exists(txt_list), f"VideoIter:: failed to locate: `{txt_list}'"
+
+        # try to load cached list
+        cached_video_info = {}
+        if cached_info_path:
+            if os.path.exists(cached_info_path):
+                f = open(cached_info_path, "r")
+                cached_video_prefix = f.readline().split()[1]
+                cached_txt_list = f.readline().split()[1]
+                if (cached_video_prefix == video_prefix.replace(" ", "")) and (
+                    cached_txt_list == txt_list.replace(" ", "")
+                ):
+                    logging.info(
+                        f"VideoIter:: loading cached video info from: `{cached_info_path}'"
+                    )
+                    lines = f.readlines()
+                    for line in lines:
+                        video_subpath, frame_count = line.split()
+                        cached_video_info.update({video_subpath: int(frame_count)})
+                else:
+                    logging.warning(
+                        ">> Cached video list mismatched: "
+                        f"(prefix:{cached_video_prefix}, list:{cached_txt_list}) != "
+                        f"expected (prefix:{video_prefix}, list:{txt_list})"
+                    )
+                f.close()
+            else:
+                if not os.path.exists(os.path.dirname(cached_info_path)):
+                    os.makedirs(os.path.dirname(cached_info_path))
+
+        # building dataset
+        video_list = []
+        new_video_info = {}
+        logging_interval = 100
+        with open(txt_list) as f:
+            lines = f.read().splitlines()
+            logging.info(f"VideoIter:: found {len(lines)} videos in `{txt_list}'")
+            for i, line in enumerate(lines):
+                v_id, label, video_subpath = line.split()
+                video_path = os.path.join(video_prefix, video_subpath)
+                if not os.path.exists(video_path):
+                    logging.warning(f"VideoIter:: >> cannot locate `{video_path}'")
+                    continue
+                if check_video:
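+                    # verified frame counts are expensive, so reuse cached values first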
+                    if video_subpath in cached_video_info:
+                        frame_count = cached_video_info[video_subpath]
+                    elif video_subpath in new_video_info:
+                        frame_count = new_video_info[video_subpath]
+                    else:
+                        with Video(vid_path=video_path) as video:
+                            frame_count = video.count_frames(check_validity=True)
+                        new_video_info.update({video_subpath: frame_count})
+                else:
+                    frame_count = -1
+                # e.g. [3417, 91, 'TennisSwing/v_TennisSwing_g02_c05.avi', -1]
+                info = [int(v_id), int(label), video_subpath, frame_count]
+                video_list.append(info)
+                if check_video and (i % logging_interval) == 0:
+                    logging.info(
+                        f"VideoIter:: - Checking: {i}/{len(lines)}, \tinfo: {info}"
+                    )
+
+        # caching video list
+        if cached_info_path and len(new_video_info) > 0:
+            logging.info(
+                f"VideoIter:: adding {len(new_video_info)} lines of new video info to: {cached_info_path}"
+            )
+            cached_video_info.update(new_video_info)
+            with open(cached_info_path, "w") as f:
+                # head
+                f.write(f"video_prefix: {video_prefix.replace(' ', '')}\n")
+                f.write(f"txt_list: {txt_list.replace(' ', '')}\n")
+                # content
+                for i, (video_subpath, frame_count) in enumerate(
+                    cached_video_info.items()
+                ):
+                    if i > 0:
+                        f.write("\n")
+                    f.write(f"{video_subpath}\t{frame_count}")
+
+        return video_list
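
The following sketch is a quick end-to-end smoke test of the VideoIter added above; it is illustrative only, not part of the patch. The only contract VideoIter imposes on its sampler is a sampling(range_max, v_id, prev_failed) method (see getitem_from_raw_video), so the evenly-spaced EvenSampling below is a hypothetical stand-in for the real sampler in ucf101/sampler.py. The data paths assume the layout produced by `make download-data`, and the script is assumed to run from example/ucf101 so that the ucf101 package is importable.

import numpy as np

from ucf101.transform import RandomScale, ToTensor
from ucf101.video_iterator import VideoIter


class EvenSampling:
    """Hypothetical sampler: `num` evenly spaced, ascending frame ids."""

    def __init__(self, num=16):
        self.num = num

    def sampling(self, range_max, v_id=None, prev_failed=False):
        # ids lie in [0, range_max); duplicates from short videos are
        # tolerated by Video.extract_frames_fast
        return np.linspace(0, range_max - 1, num=self.num).astype(int).tolist()


_scale = RandomScale(slen=[224, 288])
_to_tensor = ToTensor(dim=3)


def video_transform(clip):
    # VideoIter hands over (H, W, 3 * num_frames); returns a float tensor of
    # shape (3, num_frames, H', W') with values scaled to [0, 1]
    return _to_tensor(_scale(clip))


train = VideoIter(
    video_prefix="data/UCF-101",     # assumed unrar output directory
    txt_list="data/train_list.txt",  # lines of "<id> <label> <subpath>"
    sampler=EvenSampling(num=16),
    video_transform=video_transform,
    name="train",
    shuffle_list_seed=0,
)
clip, label = train[0]  # clip: (3, 16, H', W') float tensor; label: int

Note that RandomScale yields a different spatial size per clip, so a fixed-size crop is still required before clips can be collated into mini-batches with a torch DataLoader.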