Skip to content

Commit

Permalink
Add GitHub Workflow with Pylint analyzer
Browse files Browse the repository at this point in the history
  • Loading branch information
Sirozha1337 committed Jan 5, 2024
1 parent fab2921 commit 1c0cdb6
Show file tree
Hide file tree
Showing 9 changed files with 285 additions and 65 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Run Pylint static analysis on every push.
# Missing-docstring checks (C0114 module, C0115 class, C0116 function)
# are disabled project-wide.
name: Pylint

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Single version today; the matrix makes it easy to add more.
        python-version: ["3.9"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pylint
          pip install -r requirements.txt
      - name: Analysing the code with pylint
        run: |
          pylint --disable=C0114 --disable=C0115 --disable=C0116 $(git ls-files '*.py')
35 changes: 25 additions & 10 deletions auto_subtitle/cli.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,36 @@
import argparse
from faster_whisper import available_models
from .utils.constants import LANGUAGE_CODES
from .main import process
from .utils.convert import str2bool, str2timeinterval


def main():
"""
Main entry point for the script.
Parses command line arguments, processes the inputs using the specified options,
and performs transcription or translation based on the specified task.
"""
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("video", nargs="+", type=str,
help="paths to video files to transcribe")
parser.add_argument("--audio_channel", default="0",
type=int, help="audio channel index to use")
parser.add_argument("--sample_interval", type=str2timeinterval, default=None,
help="generate subtitles for a specific fragment of the video (e.g. 01:02:05-01:03:45)")
help="generate subtitles for a specific \
fragment of the video (e.g. 01:02:05-01:03:45)")
parser.add_argument("--model", default="small",
choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--device", type=str, default="auto", choices=[
"cpu", "cuda", "auto"], help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
parser.add_argument("--device", type=str, default="auto",
choices=["cpu", "cuda", "auto"],
help="Device to use for computation (\"cpu\", \"cuda\", \"auto\")")
parser.add_argument("--compute_type", type=str, default="default", choices=[
"int8", "int8_float32", "int8_float16",
"int8_bfloat16", "int16", "float16",
"bfloat16", "float32"], help="Type to use for computation. See https://opennmt.net/CTranslate2/quantization.html.")
"int8", "int8_float32", "int8_float16", "int8_bfloat16",
"int16", "float16", "bfloat16", "float32"],
help="Type to use for computation. \
See https://opennmt.net/CTranslate2/quantization.html.")
parser.add_argument("--output_dir", "-o", type=str,
default=".", help="directory to save the outputs")
parser.add_argument("--output_srt", type=str2bool, default=False,
Expand All @@ -32,10 +43,14 @@ def main():
help="model parameter, tweak to increase accuracy")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True,
help="model parameter, tweak to increase accuracy")
parser.add_argument("--task", type=str, default="transcribe", choices=[
"transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default="auto", choices=["auto","af","am","ar","as","az","ba","be","bg","bn","bo","br","bs","ca","cs","cy","da","de","el","en","es","et","eu","fa","fi","fo","fr","gl","gu","ha","haw","he","hi","hr","ht","hu","hy","id","is","it","ja","jw","ka","kk","km","kn","ko","la","lb","ln","lo","lt","lv","mg","mi","mk","ml","mn","mr","ms","mt","my","ne","nl","nn","no","oc","pa","pl","ps","pt","ro","ru","sa","sd","si","sk","sl","sn","so","sq","sr","su","sv","sw","ta","te","tg","th","tk","tl","tr","tt","uk","ur","uz","vi","yi","yo","zh"],
help="What is the origin language of the video? If unset, it is detected automatically.")
parser.add_argument("--task", type=str, default="transcribe",
choices=["transcribe", "translate"],
help="whether to perform X->X speech recognition ('transcribe') \
or X->English translation ('translate')")
parser.add_argument("--language", type=str, default="auto",
choices=LANGUAGE_CODES,
help="What is the origin language of the video? \
If unset, it is detected automatically.")

args = parser.parse_args().__dict__

Expand Down
32 changes: 19 additions & 13 deletions auto_subtitle/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
from .utils.ffmpeg import get_audio, overlay_subtitles
from .utils.whisper import WhisperAI


def process(args: dict):
model_name: str = args.pop("model")
output_dir: str = args.pop("output_dir")
output_srt: bool = args.pop("output_srt")
srt_only: bool = args.pop("srt_only")
language: str = args.pop("language")
sample_interval: str = args.pop("sample_interval")
device: str = args.pop("device")
compute_type: str = args.pop("compute_type")


os.makedirs(output_dir, exist_ok=True)

if model_name.endswith(".en"):
Expand All @@ -25,33 +24,40 @@ def process(args: dict):
elif language != "auto":
args["language"] = language

audios = get_audio(args.pop("video"), args.pop('audio_channel'), sample_interval)
subtitles = get_subtitles(
audios, output_srt or srt_only, output_dir, model_name, device, compute_type, args
)
audios = get_audio(args.pop("video"), args.pop(
'audio_channel'), sample_interval)

model_args = {}
model_args["model_size_or_path"] = model_name
model_args["device"] = args.pop("device")
model_args["compute_type"] = args.pop("compute_type")

srt_output_dir = output_dir if output_srt or srt_only else tempfile.gettempdir()
subtitles = get_subtitles(audios, srt_output_dir, model_args, args)

if srt_only:
return

overlay_subtitles(subtitles, output_dir, sample_interval)

def get_subtitles(audio_paths: list, output_dir: str,
                  model_args: dict, transcribe_args: dict):
    """
    Transcribe each audio file and write an .srt subtitle file next to it.

    :param audio_paths: mapping of source video path -> extracted audio path
    :param output_dir: directory where the generated .srt files are written
    :param model_args: keyword arguments used to construct the WhisperAI model
    :param transcribe_args: keyword arguments forwarded to transcription
    :return: mapping of source video path -> generated .srt path
    """
    model = WhisperAI(model_args, transcribe_args)

    subtitles_path = {}

    for path, audio_path in audio_paths.items():
        print(
            f"Generating subtitles for {filename(path)}... This might take a while."
        )

        srt_path = os.path.join(output_dir, f"{filename(path)}.srt")

        segments = model.transcribe(audio_path)

        with open(srt_path, "w", encoding="utf-8") as srt:
            write_srt(segments, file=srt)

        subtitles_path[path] = srt_path

    return subtitles_path
105 changes: 105 additions & 0 deletions auto_subtitle/utils/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""
List of available language codes
"""
LANGUAGE_CODES = [
"af",
"am",
"ar",
"as",
"az",
"ba",
"be",
"bg",
"bn",
"bo",
"br",
"bs",
"ca",
"cs",
"cy",
"da",
"de",
"el",
"en",
"es",
"et",
"eu",
"fa",
"fi",
"fo",
"fr",
"gl",
"gu",
"ha",
"haw",
"he",
"hi",
"hr",
"ht",
"hu",
"hy",
"id",
"is",
"it",
"ja",
"jw",
"ka",
"kk",
"km",
"kn",
"ko",
"la",
"lb",
"ln",
"lo",
"lt",
"lv",
"mg",
"mi",
"mk",
"ml",
"mn",
"mr",
"ms",
"mt",
"my",
"ne",
"nl",
"nn",
"no",
"oc",
"pa",
"pl",
"ps",
"pt",
"ro",
"ru",
"sa",
"sd",
"si",
"sk",
"sl",
"sn",
"so",
"sq",
"sr",
"su",
"sv",
"sw",
"ta",
"te",
"tg",
"th",
"tk",
"tl",
"tr",
"tt",
"uk",
"ur",
"uz",
"vi",
"yi",
"yo",
"zh",
"yue",
]
46 changes: 26 additions & 20 deletions auto_subtitle/utils/convert.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from datetime import datetime, timedelta

def str2bool(string: str):
    """
    Convert a "true"/"false" string (any case) to the matching bool.

    :param string: text to convert, e.g. "True", "false"
    :return: True or False
    :raises ValueError: if the lowered string is neither "true" nor "false"
    """
    string = string.lower()
    str2val = {"true": True, "false": False}

    if string in str2val:
        return str2val[string]

    raise ValueError(
        f"Expected one of {set(str2val.keys())}, got {string}")


def str2timeinterval(string: str):
if string is None:
return None

if '-' not in string:
raise ValueError(
f"Expected time interval HH:mm:ss-HH:mm:ss or HH:mm-HH:mm or ss-ss, got {string}")

intervals = string.split('-')
if len(intervals) != 2:
raise ValueError(
Expand All @@ -28,42 +30,47 @@ def str2timeinterval(string):
if start >= end:
raise ValueError(
f"Expected time interval end to be higher than start, got {start} >= {end}")

return [start, end]

def time_to_timestamp(string: str):
    """
    Convert a colon-separated time string into a total number of seconds.

    Accepted shapes: "HH:mm:ss", "HH:mm" (hours and minutes), or "ss".

    :param string: time string of purely numeric parts
    :return: total seconds as an int
    :raises ValueError: if there are more than 3 parts or any part is non-numeric
    """
    split_time = string.split(':')
    # str.split always returns at least one element; the == 0 guard is kept
    # for defensive symmetry with the upper bound.
    if len(split_time) == 0 or len(split_time) > 3 or not all(x.isdigit() for x in split_time):
        raise ValueError(
            f"Expected HH:mm:ss or HH:mm or ss, got {string}")

    if len(split_time) == 1:
        return int(split_time[0])

    if len(split_time) == 2:
        # Two parts mean hours:minutes, not minutes:seconds.
        return int(split_time[0]) * 60 * 60 + int(split_time[1]) * 60

    return int(split_time[0]) * 60 * 60 + int(split_time[1]) * 60 + int(split_time[2])

def try_parse_timestamp(string: str):
    """
    Parse a timestamp string into seconds, trying the most specific
    pattern first: "%H:%M:%S", then "%H:%M", then "%S".

    :param string: timestamp text
    :return: seconds as an int, or None if no pattern matches
    """
    for pattern in ('%H:%M:%S', '%H:%M'):
        timestamp = parse_timestamp(string, pattern)
        if timestamp is not None:
            return timestamp

    return parse_timestamp(string, '%S')

def parse_timestamp(string: str, pattern: str):
    """
    Parse ``string`` with ``datetime.strptime`` and return the time of day
    expressed as whole seconds.

    :param string: timestamp text, e.g. "01:02:03"
    :param pattern: strptime format, e.g. "%H:%M:%S"
    :return: seconds as an int, or None if the pattern does not match
    """
    try:
        date = datetime.strptime(string, pattern)
    except ValueError:
        # strptime signals a non-matching pattern with ValueError; a bare
        # except here would also swallow unrelated errors.
        return None

    delta = timedelta(
        hours=date.hour, minutes=date.minute, seconds=date.second)
    return int(delta.total_seconds())


def format_timestamp(seconds: float, always_include_hours: bool = False):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
Expand All @@ -79,4 +86,3 @@ def format_timestamp(seconds: float, always_include_hours: bool = False):

hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d},{milliseconds:03d}"

Loading

0 comments on commit 1c0cdb6

Please sign in to comment.