Skip to content

Commit

Permalink
Replace exit with return for python API use
Browse files Browse the repository at this point in the history
  • Loading branch information
dlesbre committed Oct 27, 2024
1 parent 4dd4ed4 commit cade8cb
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 46 deletions.
91 changes: 51 additions & 40 deletions bibtexautocomplete/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
from sys import stdout
from tempfile import mkstemp
from typing import Any, Callable, Container, List, NoReturn, Optional, Set
from typing import Any, Callable, Container, List, Optional, Set

from bibtexparser.bibdatabase import UndefinedString

Expand Down Expand Up @@ -42,31 +42,37 @@
pass


def conflict(parser: MyParser, prefix: str, option1: str, option2: str) -> NoReturn:
parser.error(
"{StBold}Conflicting options:\n{Reset}"
+ " Specified both "
+ prefix
+ "{FgYellow}"
+ option1
+ "{Reset} and a {FgYellow}"
+ option2
+ "{Reset} option."
)
def conflict(parser: MyParser, prefix: str, option1: str, option2: str) -> int:
try:
parser.error(
"{StBold}Conflicting options:\n{Reset}"
+ " Specified both "
+ prefix
+ "{FgYellow}"
+ option1
+ "{Reset} and a {FgYellow}"
+ option2
+ "{Reset} option."
)
except ValueError:
return 2


def main(argv: Optional[List[str]] = None) -> None:
def main(argv: Optional[List[str]] = None) -> int:
"""The main function of bibtexautocomplete
Takes an argv like List as argument,
if none, uses sys.argv
see HELP_TEXT or main(["-h"]) for details"""
parser = make_parser()
if parser_autocomplete is not None:
parser_autocomplete(parser)
if argv is None:
args = parser.parse_args()
else:
args = parser.parse_args(argv)
try:
if argv is None:
args = parser.parse_args()
else:
args = parser.parse_args(argv)
except ValueError:
return 2

ANSICodes.use_ansi = stdout.isatty() and not args.no_color

Expand All @@ -85,14 +91,14 @@ def main(argv: Optional[List[str]] = None) -> None:
PREFIX=FIELD_PREFIX,
)
)
return
return 0
if args.version:
print(
"{NAME} version {VERSION} ({VERSION_DATE})".format(
NAME=SCRIPT_NAME, VERSION=VERSION_STR, VERSION_DATE=VERSION_DATE
)
)
return
return 0

if args.silent:
args.verbose = -args.silent
Expand All @@ -111,7 +117,7 @@ def main(argv: Optional[List[str]] = None) -> None:

lookups = OnlyExclude[str].from_nonempty(args.only_query, args.dont_query).filter(LOOKUPS, lambda x: x.name)
if args.only_query != [] and args.dont_query != []:
conflict(parser, "a ", "-q/--only-query", "-Q/--dont-query")
return conflict(parser, "a ", "-q/--only-query", "-Q/--dont-query")
if args.only_query != []:
# remove duplicate from list
args.only_query, dups = list_unduplicate(args.only_query)
Expand All @@ -122,11 +128,11 @@ def main(argv: Optional[List[str]] = None) -> None:

fields = OnlyExclude[FieldType].from_nonempty(args.only_complete, args.dont_complete)
if args.only_complete != [] and args.dont_complete != []:
conflict(parser, "a ", "-c/--only-complete", "-C/--dont-complete")
return conflict(parser, "a ", "-c/--only-complete", "-C/--dont-complete")

entries = OnlyExclude[str].from_nonempty(args.only_entry, args.exclude_entry)
if args.only_entry != [] and args.exclude_entry != []:
conflict(parser, "a ", "-e/--only-entry", "-E/--exclude-entry")
return conflict(parser, "a ", "-e/--only-entry", "-E/--exclude-entry")

if args.protect_all_uppercase:
fields_to_protect_uppercase: Container[str] = FieldNamesSet
Expand All @@ -135,11 +141,11 @@ def main(argv: Optional[List[str]] = None) -> None:
fields_to_protect_proto.default = False
fields_to_protect_uppercase = fields_to_protect_proto
if args.protect_all_uppercase and args.protect_uppercase != []:
conflict(parser, "", "--fpa/--protect-all-uppercase", "--fp/--protect-uppercase")
return conflict(parser, "", "--fpa/--protect-all-uppercase", "--fp/--protect-uppercase")
if args.protect_all_uppercase and args.dont_protect_uppercase != []:
conflict(parser, "", "--fpa/--protect-all-uppercase", "--FP/--dont-protect-uppercase")
return conflict(parser, "", "--fpa/--protect-all-uppercase", "--FP/--dont-protect-uppercase")
if args.protect_uppercase != [] and args.dont_protect_uppercase != []:
conflict(parser, "a ", "--fp/--protect-uppercase", "--FP/--dont-protect-uppercase")
return conflict(parser, "a ", "--fp/--protect-uppercase", "--FP/--dont-protect-uppercase")

if args.force_overwrite:
fields_to_overwrite: Set[FieldType] = FieldNamesSet
Expand All @@ -148,20 +154,23 @@ def main(argv: Optional[List[str]] = None) -> None:
overwrite.default = False
fields_to_overwrite = set(overwrite.filter(FieldNamesSet, lambda x: x))
if args.force_overwrite and args.overwrite != []:
conflict(parser, "", "-f/--force-overwrite", "-w/--overwrite")
return conflict(parser, "", "-f/--force-overwrite", "-w/--overwrite")
if args.force_overwrite and args.dont_overwrite != []:
conflict(parser, "", "-f/--force-overwrite", "-W/--dont-overwrite")
return conflict(parser, "", "-f/--force-overwrite", "-W/--dont-overwrite")
if args.overwrite != [] and args.dont_overwrite != []:
conflict(parser, "a ", "-w/--overwrite", "-W/--dont-overwrite")
return conflict(parser, "a ", "-w/--overwrite", "-W/--dont-overwrite")

if args.diff and args.inplace:
parser.error(
"Cannot use {FgYellow}-D/--diff{Reset} flag and {FgYellow}-i/--inplace{Reset} flag "
"simultaneously, as there\n"
" is a big risk of deleting data.\n"
" If that is truly what you want to do, specify the output file explictly\n"
" with {FgYellow}-o / --output {FgGreen}<filename>{Reset}."
)
try:
parser.error(
"Cannot use {FgYellow}-D/--diff{Reset} flag and {FgYellow}-i/--inplace{Reset} flag "
"simultaneously, as there\n"
" is a big risk of deleting data.\n"
" If that is truly what you want to do, specify the output file explictly\n"
" with {FgYellow}-o / --output {FgGreen}<filename>{Reset}."
)
except ValueError:
return 2

try:
completer = BibtexAutocomplete(
Expand Down Expand Up @@ -200,7 +209,7 @@ def main(argv: Optional[List[str]] = None) -> None:
logger.warn("Interrupted")
if completer.position == 0:
logger.info("No entries were completed")
return None
return 5
_, tempfile = mkstemp(suffix=".btac.bib", prefix="btac-interrupt-", text=True)
logger.header("Dumping data")
with open(tempfile, "w") as file:
Expand All @@ -226,12 +235,14 @@ def main(argv: Optional[List[str]] = None) -> None:
if i == completer.position:
logger.info("Only completed entries up to and including '{}'.\n".format(entry.get("ID", "<no_id>")))
break_next = True

return 5
except KeyboardInterrupt:
logger.warn("Interrupted x2")
return 7
except ValueError:
exit(2)
return 2
except UndefinedString:
exit(1)
return 1
except (IOError, UnicodeDecodeError):
exit(1)
return 1
return 0
3 changes: 2 additions & 1 deletion bibtexautocomplete/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import IO, Iterable, List, NoReturn, Optional, TypeVar

from ..bibtex.constants import FieldNamesSet
from ..utils.ansi import ANSICodes
from ..utils.constants import BTAC_FILENAME, CONNECTION_TIMEOUT, SCRIPT_NAME
from ..utils.logger import logger
from .apis import LOOKUP_NAMES
Expand Down Expand Up @@ -116,7 +117,7 @@ def print_usage(self, file: Optional[IO[str]] = None) -> None:
def error(self, message: str) -> NoReturn:
logger.critical(message + "\n", error="Invalid command line", NAME=SCRIPT_NAME)
self.print_usage(stderr)
raise ValueError(message)
raise ValueError(message.format(**ANSICodes.EmptyCodes))


def make_parser() -> MyParser:
Expand Down
7 changes: 2 additions & 5 deletions tests/test_6_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def get_value(self, res: SafeJSON) -> BibtexEntry:

@pytest.mark.parametrize(("argv", "files_to_compare"), tests)
def test_main(argv: List[str], files_to_compare: List[Tuple[str, str]]) -> None:
main(argv)
assert main(argv) == 0
FakeLookup.count = 0
day = datetime.today().strftime("%Y-%m-%d")
for expected, generated in files_to_compare:
Expand Down Expand Up @@ -637,10 +637,7 @@ def test_main(argv: List[str], files_to_compare: List[Tuple[str, str]]) -> None:

@pytest.mark.parametrize(("argv", "exit_code"), exit_tests)
def test_main_exit(argv: List[str], exit_code: int) -> None:
with pytest.raises(SystemExit) as test_exit:
main(argv)
assert test_exit.type is SystemExit
assert test_exit.value.code == exit_code
assert main(argv) == exit_code


def test_promote() -> None:
Expand Down

0 comments on commit cade8cb

Please sign in to comment.