inference/transnetv2.py (24 changes: 16 additions, 8 deletions)
@@ -32,7 +32,7 @@ def predict_raw(self, frames: np.ndarray):
 
         return single_frame_pred, all_frames_pred
 
-    def predict_frames(self, frames: np.ndarray):
+    def predict_frames(self, frames: np.ndarray, silent : bool = False):
         assert len(frames.shape) == 4 and frames.shape[1:] == self._input_size, \
             "[TransNetV2] Input shape must be [frames, height, width, 3]."
 
@@ -61,31 +61,39 @@ def input_iterator():
             predictions.append((single_frame_pred.numpy()[0, 25:75, 0],
                                 all_frames_pred.numpy()[0, 25:75, 0]))
 
-            print("\r[TransNetV2] Processing video frames {}/{}".format(
-                min(len(predictions) * 50, len(frames)), len(frames)
-            ), end="")
-        print("")
+            if (not silent):
+
+                print("\r[TransNetV2] Processing video frames {}/{}".format(
+                    min(len(predictions) * 50, len(frames)), len(frames)
+                ), end="")
+
+        if (not silent):
+
+            print("")
 
         single_frame_pred = np.concatenate([single_ for single_, all_ in predictions])
         all_frames_pred = np.concatenate([all_ for single_, all_ in predictions])
 
         return single_frame_pred[:len(frames)], all_frames_pred[:len(frames)]  # remove extra padded frames
 
-    def predict_video(self, video_fn: str):
+    def predict_video(self, video_fn: str, silent : bool = False):
         try:
             import ffmpeg
         except ModuleNotFoundError:
             raise ModuleNotFoundError("For `predict_video` function `ffmpeg` needs to be installed in order to extract "
                                       "individual frames from video file. Install `ffmpeg` command line tool and then "
                                       "install python wrapper by `pip install ffmpeg-python`.")
 
-        print("[TransNetV2] Extracting frames from {}".format(video_fn))
+        if (not silent):
+
+            print("[TransNetV2] Extracting frames from {}".format(video_fn))
 
         video_stream, err = ffmpeg.input(video_fn).output(
             "pipe:", format="rawvideo", pix_fmt="rgb24", s="48x27"
         ).run(capture_stdout=True, capture_stderr=True)
 
         video = np.frombuffer(video_stream, np.uint8).reshape([-1, 27, 48, 3])
-        return (video, *self.predict_frames(video))
+        return (video, *self.predict_frames(video, silent = silent))
 
     @staticmethod
     def predictions_to_scenes(predictions: np.ndarray, threshold: float = 0.5):
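For reference, a minimal usage sketch of the new flag (not part of the PR). The model construction and the "video.mp4" path are illustrative assumptions; `predictions_to_scenes` is the existing helper visible in the diff context above. With `silent=True`, `predict_video` skips the "Extracting frames" message and forwards the flag to `predict_frames`, which then skips the per-batch progress line.

from transnetv2 import TransNetV2  # inference/transnetv2.py in this repo

model = TransNetV2()  # assumption: default weight-directory resolution as in the repo

# silent=True suppresses both console messages; return values are unchanged.
video, single_frame_pred, all_frames_pred = model.predict_video("video.mp4", silent=True)

# Downstream use is unaffected, e.g. converting per-frame probabilities to scene boundaries.
scenes = TransNetV2.predictions_to_scenes(single_frame_pred, threshold=0.5)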