From 7ace95427ad208618eee00746ed7dec024ffa7d1 Mon Sep 17 00:00:00 2001 From: Lukas Kahwe Smith Date: Fri, 7 Nov 2025 08:06:22 +0100 Subject: [PATCH 1/3] major refactoring to enhance readability AI supported * Split routes into focused modules and add central register_routes(app) (utility, auth, config, rephrase, prompt, check, debug) * Make Slack routes opt-in via settings.slack_enabled; return 503 until Bolt is initialized * Add app lifespan bootstrap (startup.py) to wire DB, HTTP, rules, LT, LLM, etc.; load test rules into Redis when configured * Add security headers and permissive CORS middleware (middleware.py) * Strongly type and cache settings (get_settings, reset_settings_cache); add minimum_versions typing * Centralize API version constants (CHECK_API_VERSION=2.4, REPHRASE_API_VERSION=1.0) and validators * Replace mutable Pydantic defaults with Field(default_factory=...); make BaseRequestIn.config factory-based * Extract rule engine helpers: pattern/phrase matching and utilities; use append_result to build results * Extract text utilities (parse_word_type, german_lemmatization) for reuse * Tighten Slack block builder and add Slack tests (route inclusion, 503 guard, blocks smoke test) * Update tests to pass context to parse_term_replacements/fetch_configs_for_request * Minor docstring and cleanup in model.py and review prompt (returns str | None, supports min_changes) --- README.md | 8 +- app/alternatives.py | 438 +---- app/alternatives_engine/formatting.py | 85 + app/alternatives_engine/sentences.py | 154 ++ app/alternatives_engine/utils.py | 89 + app/bolt.py | 15 +- app/categories.py | 10 + app/config_manager.py | 334 ++++ app/context.py | 4 +- app/dependencies.py | 67 + app/language_processor.py | 226 +++ app/main.py | 2571 +------------------------ app/middleware.py | 49 + app/model.py | 3 + app/models.py | 47 +- app/review_prompt.py | 6 +- app/routes/__init__.py | 41 + app/routes/auth.py | 106 + app/routes/check.py | 132 ++ app/routes/config_routes.py | 
130 ++ app/routes/debug.py | 267 +++ app/routes/prompt.py | 155 ++ app/routes/rephrase.py | 134 ++ app/routes/slack.py | 95 + app/routes/utility.py | 192 ++ app/rule_check.py | 376 +--- app/rule_engine/matchers/pattern.py | 229 +++ app/rule_engine/utils.py | 84 + app/rule_processors.py | 445 +++++ app/settings.py | 23 +- app/startup.py | 141 ++ app/text_utils.py | 150 ++ app/version_validators.py | 52 + tests/test_api.py | 16 +- tests/test_slack.py | 91 + 35 files changed, 3743 insertions(+), 3222 deletions(-) create mode 100644 app/alternatives_engine/formatting.py create mode 100644 app/alternatives_engine/sentences.py create mode 100644 app/alternatives_engine/utils.py create mode 100644 app/config_manager.py create mode 100644 app/dependencies.py create mode 100644 app/language_processor.py create mode 100644 app/middleware.py create mode 100644 app/routes/__init__.py create mode 100644 app/routes/auth.py create mode 100644 app/routes/check.py create mode 100644 app/routes/config_routes.py create mode 100644 app/routes/debug.py create mode 100644 app/routes/prompt.py create mode 100644 app/routes/rephrase.py create mode 100644 app/routes/slack.py create mode 100644 app/routes/utility.py create mode 100644 app/rule_engine/matchers/pattern.py create mode 100644 app/rule_engine/utils.py create mode 100644 app/rule_processors.py create mode 100644 app/startup.py create mode 100644 app/text_utils.py create mode 100644 app/version_validators.py create mode 100644 tests/test_slack.py diff --git a/README.md b/README.md index 34b632812..7b8fc2197 100644 --- a/README.md +++ b/README.md @@ -163,14 +163,20 @@ Used for LLM-powered features (e.g., grammatically correct alternatives, rephras ## Slack -Enable the `/witty` Slack command and event handling. +Opt-in Slack integration. When enabled, the API registers the `/slack/commands` endpoint and provides the `/witty` command handler. 
| Variable | Default | Description | | --------------------- | ------- | ------------------------------------------------------------------------------------- | +| SLACK_ENABLED | false | When true, initialize Slack Bolt and include the Slack routes. | | SLACK_SIGNING_SECRET | (empty) | Slack app signing secret used to verify requests. | | SLACK_BOT_TOKEN | (empty) | Bot token to call Slack APIs. | | SLACK_ORGANIZATION_ID | (empty) | Optional: fallback organization ID for config lookup when user email isn’t available. | +Notes + +- If `SLACK_ENABLED=false` (default), no Slack code is initialized and the Slack routes are not included. +- With `SLACK_ENABLED=true` but empty Slack credentials, the app uses a local/dev Slack client for testing (no external calls). + ## Authentication Supported methods: diff --git a/app/alternatives.py b/app/alternatives.py index ebd711165..0f366053d 100644 --- a/app/alternatives.py +++ b/app/alternatives.py @@ -6,12 +6,7 @@ Config, RuleType, GenderedRolesFormatType, - Alternative, - Config, - Rule, - RuleType, Article, - FrenchGenderSeparatorType, BasicWordType, ) from app.settings import Settings @@ -21,13 +16,16 @@ from app.verbs import Verbs from app.adjectives import Adjectives from app.categories import is_sub_category_enabled, make_category_advanced -from app.helper import upperfirst, find_common_prefix, check_word_case +from app.helper import check_word_case from app.query_definitions import declensions_config from copy import deepcopy from spacy.tokens import Doc from logging import Logger from pluralizefr import pluralize +from app.alternatives_engine import utils +from app.alternatives_engine import formatting +from app.alternatives_engine import sentences class Alternatives: @@ -426,6 +424,70 @@ async def add_german_article_to_alternative( return alternative + def add_article_to_alternative( + self, + lang: LangType, + alternative: Alternative, + article_index: int | None, + article: str, + separator: str, + ) -> 
Alternative: + if not article: + return alternative + + # Compose lemma with article using language-specific adjustments + new_lemma = utils.add_article(lang, alternative.lemma, article, separator) + alternative.lemma = new_lemma + + # Update word_types to include an ARTICLE at the beginning + alternative.word_types.insert( + 0, + { + "word_type": WordType.ARTICLE, + "lower_case": True, + "lemmatize": True, + }, + ) + + # If the article contains a gender separator (e.g., la·le), mirror the German handling + # by inserting an extra empty token and a second ARTICLE marker to keep alignment. + if separator and separator in article: + alternative.word_types.insert( + 1, + { + "word_type": "", + "lower_case": True, + "lemmatize": True, + }, + ) + alternative.word_types.insert( + 2, + { + "word_type": WordType.ARTICLE, + "lower_case": True, + "lemmatize": True, + }, + ) + + # If male/female forms are provided, compose them with corresponding + # masculine/feminine articles as well. Do not rely on is_gendered_noun + # being set before this call (some call sites set it afterwards). 
+ if alternative.male_form is not None and alternative.female_form is not None: + male_article, female_article = utils.article_binary_pair( + self.static_rules, lang, article + ) + + if alternative.male_form is not None: + alternative.male_form = utils.add_article( + lang, alternative.male_form, male_article, separator + ) + if alternative.female_form is not None: + alternative.female_form = utils.add_article( + lang, alternative.female_form, female_article, separator + ) + + return alternative + def fetch_german_article_for_flexion( self, flexion: str | None, gender: str, article: str ) -> Article | None: @@ -439,7 +501,9 @@ def fetch_german_article_for_flexion( ): return None - article_forms = self.static_rules[LangType.DE][gender + "_articles"][article][form] + article_forms = self.static_rules[LangType.DE][gender + "_articles"][article][ + form + ] if isinstance(article_forms, Article): article_forms.fallback = article @@ -869,9 +933,9 @@ async def german_gendered_alternatives( return [], binary_case male_form = male_forms[target_form] - male_form_with_prefix = self.add_german_prefix(male_form, prefix) + male_form_with_prefix = utils.add_german_prefix(male_form, prefix) female_form = female_forms[target_form] - female_form_with_prefix = self.add_german_prefix(female_form, prefix) + female_form_with_prefix = utils.add_german_prefix(female_form, prefix) if ( rule.dynamic.article is None or not is_singular @@ -895,7 +959,8 @@ async def german_gendered_alternatives( if male_form == female_form: lemma = prefix + male_form else: - lemma = self.inclusive_alternative( + lemma = formatting.inclusive_alternative( + self.static_rules, LangType.DE, male_form, female_form, @@ -910,7 +975,8 @@ async def german_gendered_alternatives( additional_prefix = "" for additional_word in additional_words: additional_prefix += ( - self.inclusive_alternative( + formatting.inclusive_alternative( + self.static_rules, LangType.DE, additional_word["male_form"], 
additional_word["female_form"], @@ -944,10 +1010,8 @@ async def german_gendered_alternatives( lemma = male_form_with_prefix if male_form != female_form: - conjunction = ( - self.static_rules[LangType.DE]["noun_conjunction"]["singular"] - if is_singular - else self.static_rules[LangType.DE]["noun_conjunction"]["plural"] + conjunction = utils.get_noun_conjunction( + self.static_rules, LangType.DE, is_singular ) lemma = female_form_with_prefix + conjunction + lemma @@ -1025,7 +1089,7 @@ async def german_gendered_alternatives( text, alternative_prefix + additional_prefix - + self.add_german_prefix(male_forms[form], prefix) + + utils.add_german_prefix(male_forms[form], prefix) + alternative_suffix, is_singular, True, @@ -1034,66 +1098,6 @@ async def german_gendered_alternatives( return alternatives, binary_case - def add_german_prefix(self, word: str, prefix: str) -> str: - if len(prefix) == 0 or word.startswith(prefix): - return word - - if not word.startswith("-") and not prefix.endswith("-"): - word = word[0].lower() + word[1:] - - return prefix + word - - def add_article(self, lang: LangType, text: str, article: str, separator: str): - if ( - lang == LangType.FR - and (article.endswith("le") or article == "la") - and text[0] in ["a", "e", "i", "o", "u", "h"] - ): - return "l'" + text - - if separator != FrenchGenderSeparatorType.POINT_MEDIAN: - article = article.replace(FrenchGenderSeparatorType.POINT_MEDIAN, separator) - - return article + " " + text - - def add_article_to_alternative( - self, - lang: LangType, - alternative: Alternative, - article_index: int, - article: str, - separator: str, - ): - alternative.lemma = self.add_article( - lang, alternative.lemma, article, separator - ) - - if isinstance(alternative.male_form, str): - alternative.male_form = self.add_article( - lang, - alternative.male_form, - self.get_article_by_index( - lang, - "masculine_articles", - article_index, - ), - separator, - ) - - if isinstance(alternative.female_form, str): - 
alternative.female_form = self.add_article( - lang, - alternative.female_form, - self.get_article_by_index( - lang, - "feminine_articles", - article_index, - ), - separator, - ) - - return alternative - async def noun_alternatives( self, lang: LangType, @@ -1103,228 +1107,28 @@ async def noun_alternatives( male_form: str, female_form: str, article: str | None = None, - ) -> dict[str]: + ) -> tuple[str, str, dict]: + """Build inclusive and binary noun alternatives for a pair of male/female forms. + + Returns a tuple of possibly adjusted male_form, female_form and a dict mapping + GenderedRolesFormatType to the constructed string. + """ sentence_male_tokens = self.model.fetch_tokens(lang, male_form) sentence_female_tokens = self.model.fetch_tokens(lang, female_form) - if len(sentence_male_tokens) != len(sentence_female_tokens): - return None, None, {} - inclusive_form = "" - binary_form = "" - male_form_sub_sentence = "" - female_form_sub_sentence = "" - sub_sentence_contains_noun = False - - if article: - article = article.lower() - - for token_index in range(len(sentence_male_tokens)): - if ( - sentence_male_tokens[token_index].text - != sentence_female_tokens[token_index].text - ): - inclusive_form += self.inclusive_alternative( - lang, - sentence_male_tokens[token_index].text, - sentence_female_tokens[token_index].text, - "", - separator, - noun_separator, - separate_gender_plural, - ) - conjunction = ( - self.static_rules[lang]["noun_conjunction"]["singular"] - if self.model.is_token_singular( - lang, sentence_male_tokens[token_index] - ) - else self.static_rules[lang]["noun_conjunction"]["plural"] - ) - if lang == LangType.FR: - if ( - token_index > 0 - and sentence_male_tokens[token_index - 1].lemma_ - in self.static_rules[lang]["masculine_articles"] - ): - is_noun = True - sub_sentence_contains_noun = True - else: - is_noun = ( - await self.model._fetch_word_type( - lang, - sentence_male_tokens[token_index], - WordType.NOUN, - True, - True, - ) - == 
WordType.NOUN - ) - if sub_sentence_contains_noun == True or is_noun: - sub_sentence_contains_noun = True - - if male_form_sub_sentence != "": - male_form_sub_sentence += sentence_male_tokens[ - token_index - 1 - ].whitespace_ - female_form_sub_sentence += sentence_female_tokens[ - token_index - 1 - ].whitespace_ - - male_form_sub_sentence += sentence_male_tokens[token_index].text - female_form_sub_sentence += ( - sentence_female_tokens[token_index].text - if token_index > 0 or is_noun - else sentence_female_tokens[token_index].text.lower() - ) - else: - binary_form += ( - sentence_female_tokens[token_index].text - + conjunction - + sentence_male_tokens[token_index].text - ) - else: - inclusive_form += sentence_male_tokens[token_index].text - if male_form_sub_sentence != "": - if sub_sentence_contains_noun: - binary_form += ( - male_form_sub_sentence - + conjunction - + female_form_sub_sentence - + sentence_female_tokens[token_index - 1].whitespace_ - ) - else: - # TODO add user preference to choose male form over female form - binary_form += ( - female_form_sub_sentence - + sentence_male_tokens[token_index - 1].whitespace_ - ) - male_form_sub_sentence = "" - female_form_sub_sentence = "" - sub_sentence_contains_noun = False - - binary_form += sentence_male_tokens[token_index].text - - inclusive_form += sentence_male_tokens[token_index].whitespace_ - - if male_form_sub_sentence == "": - binary_form += sentence_male_tokens[token_index].whitespace_ - - if male_form_sub_sentence != "": - if article: - if article in self.static_rules[lang]["articles_inclusive_map"]: - male_article = article - female_article = article - if article in self.static_rules[lang]["masculine_articles"]: - male_article = article - female_article = self.static_rules[lang]["articles_binary_map"][ - article - ] - else: - male_article = self.static_rules[lang]["articles_binary_map"][ - article - ] - female_article = article - - male_form_sub_sentence = self.add_article( - lang, 
male_form_sub_sentence, male_article, separator - ) - female_form_sub_sentence = self.add_article( - lang, female_form_sub_sentence, female_article, separator - ) - - binary_form += ( - male_form_sub_sentence + conjunction + female_form_sub_sentence - ) - - if article: - inclusive_form = self.add_article( - lang, - inclusive_form, - self.static_rules[lang]["articles_inclusive_map"][article], - separator, - ) - - return ( - male_form_sub_sentence, - female_form_sub_sentence, - { - GenderedRolesFormatType.INCLUSIVE_GENDER: inclusive_form, - GenderedRolesFormatType.BINARY_GENDER: binary_form, - }, + male_sub, female_sub, variants = await sentences.build_sentence_gendered_forms( + self.model, + self.static_rules, + lang, + sentence_male_tokens, + sentence_female_tokens, + separator, + noun_separator, + separate_gender_plural, + article, ) - def inclusive_alternative( - self, - lang: LangType, - male_form: str, - female_form: str, - prefix: str, - separator: str, - noun_separator: str, - separate_gender_plural: bool, - ): - if lang == LangType.DE: - if male_form.lower() in self.static_rules[lang]["masculine_articles"]: - return female_form + separator + male_form - - short_gender_star = True - common_prefix = ( - "" - if male_form.endswith("mann") - else find_common_prefix(male_form, female_form, False, False) - ) - if len(male_form) - len(common_prefix) > 2: - common_prefix = female_form - suffix = self.add_german_prefix(male_form, prefix) - short_gender_star = False - elif len(female_form) >= len(male_form): - # Mitarbeiterin + Mitarbeiter = Mitarbeiter - suffix = female_form[len(common_prefix) :] - else: - # Vorgesetze + Vorgesetzter = Vorgesetze - suffix = male_form[len(common_prefix) :] - - temp_separator = noun_separator - # In - if separator != noun_separator: - if short_gender_star: - suffix = upperfirst(suffix) - else: - temp_separator = "/" - - return self.add_german_prefix( - common_prefix + temp_separator + suffix, prefix - ) - - if lang == LangType.FR: - 
male_form_lower = male_form.lower() - if male_form_lower in self.static_rules[lang]["masculine_articles"]: - inclusive_form = self.static_rules[lang]["masculine_articles"][ - male_form_lower - ] - if male_form != male_form_lower: - inclusive_form = upperfirst(inclusive_form) - return inclusive_form - - common_prefix = find_common_prefix(male_form, female_form, False, False) - # Il est un poète - if len(common_prefix) < 3: - return prefix + male_form + separator + female_form.lower() - - if len(female_form) >= len(male_form): - suffix = female_form[len(common_prefix) :] - gender_prefix = male_form - else: - suffix = male_form[len(common_prefix) :] - gender_prefix = female_form - - # expérimentés / expérimentées => expérimenté·es - if gender_prefix.endswith("s"): - gender_prefix = gender_prefix[0:-1] - - if separate_gender_plural and suffix.endswith("s"): - suffix = suffix[0:-1] + separator + "s" - - return prefix + gender_prefix + separator + suffix + return male_sub or male_form, female_sub or female_form, variants def handle_single_tilde( self, alternative: Alternative, prefix: bool, is_singular: bool @@ -1338,7 +1142,7 @@ def handle_single_tilde( if word.count("~") == 1: slash = False if word.startswith("~"): - word = self.add_german_prefix(word[1:], prefix) + word = utils.add_german_prefix(word[1:], prefix) else: position = word.find("~") if word[position + 1].islower(): @@ -1457,10 +1261,11 @@ def french_nouns_with_articles( else: forms = {} for form in ["masculine", "feminine"]: - forms[form] = self.add_article( + forms[form] = utils.add_article( LangType.FR, alternative.lemma, - self.get_article_by_index( + utils.get_article_by_index( + self.static_rules, LangType.FR, form + "_articles", article_index, @@ -1469,13 +1274,10 @@ def french_nouns_with_articles( ) alternative.lemma = forms["masculine"] if forms["masculine"] != forms["feminine"]: - alternative.lemma += ( - self.static_rules[LangType.FR]["noun_conjunction"]["plural"] - if is_plural - else 
self.static_rules[LangType.FR]["noun_conjunction"][ - "singular" - ] - ) + forms["feminine"] + conjunction = utils.get_noun_conjunction( + self.static_rules, LangType.FR, not is_plural + ) + alternative.lemma += conjunction + forms["feminine"] alternative.male_form = forms["masculine"] alternative.female_form = forms["feminine"] @@ -1485,7 +1287,8 @@ def french_nouns_with_articles( LangType.FR, alternative, article_index, - self.get_article_by_index( + utils.get_article_by_index( + self.static_rules, LangType.FR, result["gender_1"] + "_articles", article_index, @@ -1496,30 +1299,3 @@ def french_nouns_with_articles( alternatives.append(alternative) return alternatives - - def get_article_by_index( - self, lang: LangType, articles_list: str, article_index: int - ): - return list(self.static_rules[lang][articles_list].keys())[article_index] - - def get_adjective_alternatives_french(self, male_form, female_form): - lemma = male_form + "~" + female_form - return [ - Alternative( - lemma, - [lemma], - [ - { - "word_type": "a", - "lower_case": True, - "lemmatize": True, - } - ], - False, - False, - False, - False, - False, - True, - ) - ] diff --git a/app/alternatives_engine/formatting.py b/app/alternatives_engine/formatting.py new file mode 100644 index 000000000..01375dba7 --- /dev/null +++ b/app/alternatives_engine/formatting.py @@ -0,0 +1,85 @@ +"""Formatting helpers for gendered/inclusive alternatives.""" + +from app.models import LangType +from app.helper import upperfirst, find_common_prefix + + +def inclusive_alternative( + static_rules: dict, + lang: LangType, + male_form: str, + female_form: str, + prefix: str, + separator: str, + noun_separator: str, + separate_gender_plural: bool, +): + if lang == LangType.DE: + if male_form.lower() in static_rules[lang]["masculine_articles"]: + return female_form + separator + male_form + + short_gender_star = True + common_prefix = ( + "" + if male_form.endswith("mann") + else find_common_prefix(male_form, female_form, 
False, False) + ) + if len(male_form) - len(common_prefix) > 2: + common_prefix = female_form + suffix = _add_german_prefix(male_form, prefix) + short_gender_star = False + elif len(female_form) >= len(male_form): + # Mitarbeiterin + Mitarbeiter = Mitarbeiter + suffix = female_form[len(common_prefix) :] + else: + # Vorgesetze + Vorgesetzter = Vorgesetze + suffix = male_form[len(common_prefix) :] + + temp_separator = noun_separator + # In + if separator != noun_separator: + if short_gender_star: + suffix = upperfirst(suffix) + else: + temp_separator = "/" + + return _add_german_prefix(common_prefix + temp_separator + suffix, prefix) + + if lang == LangType.FR: + male_form_lower = male_form.lower() + if male_form_lower in static_rules[lang]["masculine_articles"]: + inclusive_form = static_rules[lang]["masculine_articles"][male_form_lower] + if male_form != male_form_lower: + inclusive_form = upperfirst(inclusive_form) + return inclusive_form + + common_prefix = find_common_prefix(male_form, female_form, False, False) + # Il est un poète + if len(common_prefix) < 3: + return prefix + male_form + separator + female_form.lower() + + if len(female_form) >= len(male_form): + suffix = female_form[len(common_prefix) :] + gender_prefix = male_form + else: + suffix = male_form[len(common_prefix) :] + gender_prefix = female_form + + # expérimentés / expérimentées => expérimenté·es + if gender_prefix.endswith("s"): + gender_prefix = gender_prefix[0:-1] + + if separate_gender_plural and suffix.endswith("s"): + suffix = suffix[0:-1] + separator + "s" + + return prefix + gender_prefix + separator + suffix + + +def _add_german_prefix(word: str, prefix: str) -> str: + if len(prefix) == 0 or word.startswith(prefix): + return word + + if not word.startswith("-") and not prefix.endswith("-"): + word = word[0].lower() + word[1:] + + return prefix + word diff --git a/app/alternatives_engine/sentences.py b/app/alternatives_engine/sentences.py new file mode 100644 index 
000000000..c5f0f2d1f --- /dev/null +++ b/app/alternatives_engine/sentences.py @@ -0,0 +1,154 @@ +"""Sentence-level gendered form builders for alternatives.""" + +from spacy.tokens import Doc +from app.models import LangType, WordType, GenderedRolesFormatType +from app.alternatives_engine import formatting +from app.alternatives_engine import utils + + +async def build_sentence_gendered_forms( + model, + static_rules: dict, + lang: LangType, + sentence_male_tokens: Doc, + sentence_female_tokens: Doc, + separator: str, + noun_separator: str, + separate_gender_plural: bool, + article: str | None, +): + inclusive_form = "" + binary_form = "" + male_form_sub_sentence = "" + female_form_sub_sentence = "" + sub_sentence_contains_noun = False + + if article: + article = article.lower() + + for token_index in range(len(sentence_male_tokens)): + if ( + sentence_male_tokens[token_index].text + != sentence_female_tokens[token_index].text + ): + inclusive_form += formatting.inclusive_alternative( + static_rules, + lang, + sentence_male_tokens[token_index].text, + sentence_female_tokens[token_index].text, + "", + separator, + noun_separator, + separate_gender_plural, + ) + conjunction = ( + static_rules[lang]["noun_conjunction"]["singular"] + if model.is_token_singular(lang, sentence_male_tokens[token_index]) + else static_rules[lang]["noun_conjunction"]["plural"] + ) + if lang == LangType.FR: + if ( + token_index > 0 + and sentence_male_tokens[token_index - 1].lemma_ + in static_rules[lang]["masculine_articles"] + ): + is_noun = True + sub_sentence_contains_noun = True + else: + is_noun = ( + await model._fetch_word_type( + lang, + sentence_male_tokens[token_index], + WordType.NOUN, + True, + True, + ) + == WordType.NOUN + ) + if sub_sentence_contains_noun == True or is_noun: + sub_sentence_contains_noun = True + + if male_form_sub_sentence != "": + male_form_sub_sentence += sentence_male_tokens[ + token_index - 1 + ].whitespace_ + female_form_sub_sentence += 
sentence_female_tokens[ + token_index - 1 + ].whitespace_ + + male_form_sub_sentence += sentence_male_tokens[token_index].text + female_form_sub_sentence += ( + sentence_female_tokens[token_index].text + if token_index > 0 or is_noun + else sentence_female_tokens[token_index].text.lower() + ) + else: + binary_form += ( + sentence_female_tokens[token_index].text + + conjunction + + sentence_male_tokens[token_index].text + ) + else: + inclusive_form += sentence_male_tokens[token_index].text + if male_form_sub_sentence != "": + if sub_sentence_contains_noun: + binary_form += ( + male_form_sub_sentence + + conjunction + + female_form_sub_sentence + + sentence_female_tokens[token_index - 1].whitespace_ + ) + else: + # TODO add user preference to choose male form over female form + binary_form += ( + female_form_sub_sentence + + sentence_male_tokens[token_index - 1].whitespace_ + ) + male_form_sub_sentence = "" + female_form_sub_sentence = "" + sub_sentence_contains_noun = False + + binary_form += sentence_male_tokens[token_index].text + + inclusive_form += sentence_male_tokens[token_index].whitespace_ + + if male_form_sub_sentence == "": + binary_form += sentence_male_tokens[token_index].whitespace_ + + if male_form_sub_sentence != "": + if article: + if article in static_rules[lang]["articles_inclusive_map"]: + male_article = article + female_article = article + if article in static_rules[lang]["masculine_articles"]: + male_article = article + female_article = static_rules[lang]["articles_binary_map"][article] + else: + male_article = static_rules[lang]["articles_binary_map"][article] + female_article = article + + male_form_sub_sentence = utils.add_article( + lang, male_form_sub_sentence, male_article, separator + ) + female_form_sub_sentence = utils.add_article( + lang, female_form_sub_sentence, female_article, separator + ) + + binary_form += male_form_sub_sentence + conjunction + female_form_sub_sentence + + if article: + inclusive_form = utils.add_article( + lang, 
+ inclusive_form, + static_rules[lang]["articles_inclusive_map"][article], + separator, + ) + + return ( + male_form_sub_sentence, + female_form_sub_sentence, + { + GenderedRolesFormatType.INCLUSIVE_GENDER: inclusive_form, + GenderedRolesFormatType.BINARY_GENDER: binary_form, + }, + ) diff --git a/app/alternatives_engine/utils.py b/app/alternatives_engine/utils.py new file mode 100644 index 000000000..72db85a6b --- /dev/null +++ b/app/alternatives_engine/utils.py @@ -0,0 +1,89 @@ +"""Utilities for alternatives generation and formatting.""" + +from app.models import ( + LangType, + FrenchGenderSeparatorType, + Alternative, +) + + +def article_binary_pair( + static_rules: dict, lang: LangType, article: str +) -> tuple[str, str]: + if ( + "inclusive_articles" in static_rules[lang] + and article in static_rules[lang]["inclusive_articles"] + ): + masculine_base = static_rules[lang]["articles_map"][article] + male_article = masculine_base + female_article = static_rules[lang]["articles_binary_map"][masculine_base] + return male_article, female_article + + if article in static_rules[lang].get("masculine_articles", {}): + return article, static_rules[lang]["articles_binary_map"][article] + + # assume feminine + return static_rules[lang]["articles_binary_map"][article], article + + +def get_noun_conjunction(static_rules: dict, lang: LangType, is_singular: bool) -> str: + return ( + static_rules[lang]["noun_conjunction"]["singular"] + if is_singular + else static_rules[lang]["noun_conjunction"]["plural"] + ) + + +def add_german_prefix(word: str, prefix: str) -> str: + if not prefix or word.startswith(prefix): + return word + + if not word.startswith("-") and not prefix.endswith("-"): + word = word[0].lower() + word[1:] + + return prefix + word + + +def add_article(lang: LangType, text: str, article: str, separator: str) -> str: + if ( + lang == LangType.FR + and (article.endswith("le") or article == "la") + and text[0] in ["a", "e", "i", "o", "u", "h"] + ): + return "l'" + 
text + + if separator != FrenchGenderSeparatorType.POINT_MEDIAN: + article = article.replace(FrenchGenderSeparatorType.POINT_MEDIAN, separator) + + return article + " " + text + + +def get_article_by_index( + static_rules: dict, lang: LangType, articles_list: str, article_index: int +): + return list(static_rules[lang][articles_list].keys())[article_index] + + +def build_french_adjective_alternatives( + male_form: str, female_form: str +) -> list[Alternative]: + lemma = male_form + "~" + female_form + return [ + Alternative( + lemma, + [lemma], + [ + { + "word_type": "a", + "lower_case": True, + "lemmatize": True, + } + ], + False, + False, + False, + False, + False, + True, + ) + ] diff --git a/app/bolt.py b/app/bolt.py index eebb9b4c5..400de2b59 100644 --- a/app/bolt.py +++ b/app/bolt.py @@ -1,5 +1,6 @@ from app.settings import Settings from app.models import ResultOut, Language +from app.context import AppContext from slack_bolt.app.async_app import AsyncApp from slack_sdk.models.blocks import ( @@ -10,7 +11,7 @@ from slack_sdk.web.async_client import AsyncWebClient -def get_bolt(settings: Settings): +def get_bolt(settings: Settings, context: AppContext) -> AsyncApp: if settings.slack_bot_token and settings.slack_signing_secret: # pragma: no cover bolt = AsyncApp( token=settings.slack_bot_token, signing_secret=settings.slack_signing_secret @@ -24,6 +25,12 @@ def get_bolt(settings: Settings): ), ) + # Inject AppContext into Slack Bolt's context for all listeners + @bolt.middleware + async def inject_app_context(context_, next): # type: ignore[no-redef] + context_["app_context"] = context + return await next() + return bolt @@ -73,9 +80,10 @@ async def process_command_witty( ) ) - if len(result.alternatives): + alternatives_list = result.alternatives or [] + if alternatives_list: alternatives = "" - for alternative in result.alternatives: + for alternative in alternatives_list: if alternative.remove: alternatives += f"\n• ~{alternative.text}~" else: @@ -92,3 
+100,4 @@ async def process_command_witty( ) await respond(blocks=blocks) + return None diff --git a/app/categories.py b/app/categories.py index 5e968dd5b..78fa200d6 100644 --- a/app/categories.py +++ b/app/categories.py @@ -2,6 +2,16 @@ from functools import lru_cache +# Centralized list of inclusive language-related categories used across the API. +inclusive_categories = [ + "communal", + "d_and_i", + "emotional_security", + "inclusive", + "orthography", +] + + @lru_cache() def load_json_data(file_name): with open(file_name) as file: diff --git a/app/config_manager.py b/app/config_manager.py new file mode 100644 index 000000000..cb8f93f7c --- /dev/null +++ b/app/config_manager.py @@ -0,0 +1,334 @@ +"""Configuration management functions for user and organization settings.""" + +from typing import Optional + +from fastapi import HTTPException + +from app.context import AppContext +from app.text_utils import parse_word_type +from app.models import ( + BaseRequestIn, + CheckRequestIn, + LangType, + ResultConf, + RuleConfig, +) +from app.categories import ( + get_category_keys, + get_parent_category_name, + make_category_advanced, +) +from app.categories import inclusive_categories + + +def parse_term_replacement( + lemma: str, term_replacement: dict, context: AppContext +) -> dict: + word_type = ( + term_replacement["word_type"] if "word_type" in term_replacement else "~" + ) + + word_type, lower_case, lemmatize = parse_word_type(word_type) + + word_type = tuple( + [ + { + "word_type": word_type, + "lower_case": lower_case, + "lemmatize": lemmatize, + } + ] + ) + + if lower_case and not lemmatize: + lemma = lemma.lower() + + if list(filter(lemma.endswith, context.term_replacement_langs)) != []: + term_replacement["lang"] = lemma[-2:] + term_replacement["lemma"] = lemma[0:-3] + else: + term_replacement["lang"] = None + term_replacement["lemma"] = lemma + + lang = LangType.EN if term_replacement["lang"] is None else term_replacement["lang"] + term_replacement["words"] 
= context.model.tokenize(term_replacement["lemma"], lang) + term_replacement["word_types"] = word_type * len(term_replacement["words"]) + + term_replacement["false_positives"] = [] + term_replacement["parsed_alternatives"] = [] + for alternative in term_replacement["alternatives"]: + term_replacement["false_positives"].append(alternative) + + alternative = {"lemma": alternative} + alternative["words"] = context.model.tokenize(alternative["lemma"], lang) + alternative["word_types"] = word_type * len(alternative["words"]) + term_replacement["parsed_alternatives"].append(alternative) + + return term_replacement + + +def parse_term_replacements( + term_replacements_source: dict | None, context: AppContext +) -> dict: + term_replacements = {} + if term_replacements_source is not None: + for lemma in term_replacements_source: + term_replacement = dict(term_replacements_source[lemma]) + term_replacement = parse_term_replacement(lemma, term_replacement, context) + + term_replacements[lemma] = term_replacement + + return term_replacements + + +async def fetch_user_organization_configs( + email: str, context: AppContext +) -> dict | None: + configs = await context.redis.fetch_user_configs_from_redis(email) + + configs["organization_name"] = None + configs["organization_config_hash"] = None + configs["organization_domains"] = None + configs["organization_trial_ends_at"] = None + + if "organization_id" in configs and configs["organization_id"] is not None: + organization_configs = ( + await context.redis.fetch_organization_configs_from_redis( + configs["organization_id"] + ) + ) + + if not configs.get("plan"): + configs["plan"] = organization_configs["plan"] + + if "trial_ends_at" in organization_configs: + configs["organization_trial_ends_at"] = organization_configs[ + "trial_ends_at" + ] + + configs["organization_name"] = organization_configs["name"] + + configs["organization_config_hash"] = organization_configs.get("config_hash") + + configs["organization_domains"] = 
organization_configs.get("domains", {}) + + configs["organization_config"] = organization_configs["config"] + + configs["organization_term_replacements"] = organization_configs[ + "term_replacements" + ] + + configs["organization_false_positives"] = organization_configs[ + "false_positives" + ] + else: + configs["organization_id"] = None + + return configs + + +def apply_configs( + check_request_in: CheckRequestIn, + configs: dict, + plan: str, + force_disables: bool = True, +): + disabled_categories = check_request_in.config.disabled_categories + if "force_categories" not in configs or configs["force_categories"] is None: + configs["force_categories"] = [] + + for config in configs: + if config == "force_categories": + continue + + data = configs[config] + if data is None: + continue + + if config == "categories": + for category in data: + category_data = data[category] + if category_data["status"] != "force": + continue + + # BC handling for old category names + if category.startswith("advanced_"): + category = make_category_advanced( + category.removeprefix("advanced_") + ) + + if category_data["value"]: + if category in disabled_categories: + disabled_categories.remove(category) + else: + force_disables_category = force_disables + if not force_disables_category and len(configs["force_categories"]): + parent_category = get_parent_category_name(category) + force_disables_category = ( + parent_category in configs["force_categories"] + ) + + if force_disables_category and category not in disabled_categories: + disabled_categories.append(category) + elif config == "store_context": + if ( + plan is not None + and plan != "witty_free" + and data["status"] == "force" + and not data["value"] + ): + check_request_in.config.__setattr__("store_context", False) + elif config == "llm_alternatives": + if ( + plan is not None + and plan != "witty_free" + and data["status"] == "force" + and data["value"] + ): + check_request_in.config.__setattr__("llm_alternatives", True) + 
elif data["status"] == "force": + check_request_in.config.__setattr__(config, data["value"]) + + for category in inclusive_categories: + if ( + category in check_request_in.config.disabled_categories + and category not in disabled_categories + ): + disabled_categories.append(category) + + check_request_in.config.__setattr__("disabled_categories", disabled_categories) + check_request_in.config.__setattr__("plan", plan) + + +async def fetch_configs_for_request( + request_in: BaseRequestIn, user_email: Optional[str], context: AppContext +) -> dict: + request_in.config.__setattr__("store_context", True) + request_in.config.__setattr__("llm_alternatives", False) + request_in.config.__setattr__("plan", None) + request_in.config.__setattr__( + "alternatives_max_count", context.settings.alternatives_max_count + ) + + if not user_email: + request_in.config.__setattr__("disabled_categories", get_category_keys(True)) + request_in.config.__setattr__("plan", None) + + return {} + + try: + configs = await fetch_user_organization_configs(user_email, context) + except HTTPException: + return {} + + apply_configs(request_in, configs["config"], configs["plan"]) + + if "organization_config" in configs: + apply_configs( + request_in, + configs["organization_config"], + configs["plan"], + False, + ) + + configs["term_replacements"] |= configs["organization_term_replacements"] + configs["false_positives"] = list( + set(configs["false_positives"] + configs["organization_false_positives"]) + ) + + return configs + + +async def fetch_organization_configs_for_request( + request_in: BaseRequestIn, organization_id: Optional[str], context: AppContext +) -> dict: + request_in.config.__setattr__("store_context", True) + request_in.config.__setattr__("llm_alternatives", False) + + if not organization_id: + return {} + + try: + configs = await context.redis.fetch_organization_configs_from_redis( + organization_id + ) + except HTTPException: + return {} + + for config in configs["configs"]: + if 
configs["configs"][config]["status"] == "suggestion": + configs["configs"][config]["status"] = "force" + + apply_configs(request_in, configs["config"], configs["plan"]) + + return configs + + +def fetch_config_change( + configs: dict, + check_request_in: Optional[BaseRequestIn] = None, +) -> bool | None: + if not check_request_in: + return True + + if ( + "config_hash" in configs + and check_request_in.config_hash != configs["config_hash"] + ): + return True + + if ( + "organization_config_hash" in configs + and check_request_in.organization_config_hash + != configs["organization_config_hash"] + ): + return True + + return None + + +def fetch_result_conf(configs: dict) -> ResultConf | None: + if "config" not in configs: + return None + + organization_config = ( + RuleConfig.model_validate(configs["organization_config"]) + if "organization_config" in configs + else None + ) + + plan = configs["plan"] + + config = RuleConfig.model_validate(configs["config"]) + + return ResultConf( + id=configs["id"], + name=configs["name"], + plan=plan, + config=config, + organization_id=configs["organization_id"], + organization_name=configs["organization_name"], + organization_config=organization_config, + domains=configs["domains"], + organization_domains=configs["organization_domains"], + config_hash=configs["config_hash"], + organization_config_hash=configs["organization_config_hash"], + organization_trial_ends_at=configs["organization_trial_ends_at"], + ) + + +def debug_configs(request_in: BaseRequestIn) -> dict: + if "none" in request_in.config.disabled_categories: + request_in.config.__setattr__("disabled_categories", []) + elif request_in.config.disabled_categories == []: + request_in.config.__setattr__( + "disabled_categories", ["plain_language_advanced"] + ) + + configs = { + "categories": {}, + "llm_alternatives": {"status": "suggestion", "value": True}, + } + apply_configs(request_in, configs, "witty_teams") + + return configs diff --git a/app/context.py b/app/context.py 
index f8f514e6d..66f65679a 100644 --- a/app/context.py +++ b/app/context.py @@ -31,7 +31,7 @@ from app.rules import fetch_static_rules from app.rule_check import RuleCheck from app.sentry import set_up_sentry_sdk -from app.settings import Settings +from app.settings import Settings, get_settings from app.translations import translations from app.verbs import Verbs @@ -65,7 +65,7 @@ def __init__(self): self.declensions_config = declensions_config self.verb_form_map = verb_form_map self.categories = get_categories() - self.settings = Settings.factory() + self.settings = get_settings() self.logger = Logger.factory(self.settings) self.logger.debug("app started with settings: %s", self.settings) diff --git a/app/dependencies.py b/app/dependencies.py new file mode 100644 index 000000000..a8dec3afe --- /dev/null +++ b/app/dependencies.py @@ -0,0 +1,67 @@ +"""FastAPI dependencies for authentication and authorization. + +Also provides a DI-friendly accessor for the application context stored on +FastAPI's app.state during lifespan. No legacy global context fallback. 
+""" + +import secrets +from typing import Optional + +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from app.context import AppContext + + +def get_app_context(request: Request) -> AppContext: + context = getattr(request.app.state, "context", None) + if context is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Application context is not initialized", + ) + return context + + +security = HTTPBasic(auto_error=False) + + +def fetch_current_username( + request: Request, + credentials: Optional[HTTPBasicCredentials] = Depends(security), +) -> str: # pragma: no cover + # Resolve context (used below) + context = get_app_context(request) + + # Credentials are missing + if not credentials: + # Auth is disabled, just proceed + if not context.settings.api_docs_auth_enabled: + return "anon" + + # Auth is enabled, raise 401 + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Basic"}, + ) + + # Verify the credentials as usual + if not context.settings.api_docs_username or not context.settings.api_docs_password: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Incorrect user configuration", + ) + + correct_username = secrets.compare_digest( + credentials.username, context.settings.api_docs_username + ) + correct_password = secrets.compare_digest( + credentials.password, context.settings.api_docs_password + ) + if not (correct_username and correct_password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Basic"}, + ) + + return credentials.username diff --git a/app/language_processor.py b/app/language_processor.py new file mode 100644 index 000000000..4ce6b8d01 --- /dev/null +++ b/app/language_processor.py @@ -0,0 +1,226 @@ +"""Language processing functions for text analysis and rule 
application.""" + +from collections import defaultdict +import json + +from spacy.tokens import Doc + +from app.context import AppContext +from app.models import ( + Alternative, + CheckRequestIn, + Client, + Config, + Language, + LangType, + ResultOut, + Rule, + WordType, +) +from app.helper import utf16_offsets +from app.rule_processors import witty_rules + + +def fetch_text( + check_request_in: CheckRequestIn, + supported_langs: list, + context: AppContext, +) -> tuple[str, Language | None, bool]: + text = check_request_in.text + limit_reached = len(text) > context.settings.text_max_length + if limit_reached: + text = text[0 : context.settings.text_max_length] + text = text.rsplit(" ", 1)[0] + + locale = context.lang_detection.get_locale( + supported_langs, + text, + check_request_in.lang, + check_request_in.config.preferred_languages, + check_request_in.config.preferred_variants, + ) + + language = None if locale is None else context.languages[locale] + + return text, language, limit_reached + + +async def apply_language_rules( + client: Client, + config: Config, + configs: dict, + language: Language, + text: str, + context: AppContext, +) -> list: + tokens = context.model.fetch_tokens(language.lang, text) + offsets = utf16_offsets(text) + + term_replacements = fetch_term_replacements(configs, language.lang, context) + + list_results = await witty_rules( + config, + term_replacements, + client, + tokens, + offsets, + language, + text, + context, + ) + + list_results = await context.languagetool.apply_languagetool_rules( + config, client, language, text, tokens, offsets + ) + await context_false_positives(language.lang, tokens, list_results, context) + + return apply_false_positives(list_results, configs) + + +def fetch_term_replacements( + configs: dict, + lang: LangType, + context: AppContext, +) -> list[Rule]: + if "term_replacements" not in configs: + return [] + + word_types = [WordType.VERB, WordType.NOUN, WordType.ADJECTIVE] + term_replacement_rules = [] 
+ for lemma in configs["term_replacements"]: + if lemma[-3:] in context.term_replacement_langs: + if not lemma.endswith(lang): + continue + + term_replacement = configs["term_replacements"][lemma] + + alternatives = [] + for alternative in term_replacement["parsed_alternatives"]: + alternatives.append( + Alternative( + alternative["lemma"], + alternative["words"], + alternative["word_types"], + ) + ) + + rule = Rule( + term_replacement["lemma"], + lang, + term_replacement["lemma"], + term_replacement["words"], + term_replacement["word_types"], + "corporate_rules", + alternatives, + ) + + if term_replacement["explanation"] is not None: + rule.explanation = term_replacement["explanation"].get("text") + rule.url = term_replacement["explanation"].get("url") + rule.icon = term_replacement["explanation"].get("icon") + + if term_replacement["word_types"][0]["lower_case"]: + rule.false_positives = term_replacement["false_positives"] + else: + rule.case_sensitive_false_positives = term_replacement["false_positives"] + + # If it is not a lemmatized rule + rule.adapt_alternatives = ( + term_replacement["word_types"][0]["word_type"] in word_types + ) + term_replacement_rules.append(rule) + + return term_replacement_rules + + +def apply_false_positives( + list_results: list, + configs: dict, +) -> list: + if len(list_results) == 0: + return list_results + + false_positives = [] + if "false_positives" in configs: + false_positives = configs["false_positives"] + + if len(false_positives): + for result in list_results.copy(): + if result.text in false_positives: + list_results.remove(result) + + return list_results + + +async def context_false_positives( + lang: LangType, + tokens: Doc, + list_results: list[ResultOut], + context: AppContext, +) -> list[ResultOut]: + if ( + lang not in context.settings.context_checker + or len(context.static_rules[lang]["context_check"]) == 0 + ): + return list_results + + sentences = {} + sentences_to_check = defaultdict(list) + for result_index 
in range(len(list_results)): + result = list_results[result_index] + if result.text_id in context.static_rules[lang]["context_check"]: + if len(sentences) == 0: + for sentence in tokens.sents: + sentences[sentence.end_char] = sentence.text + + sentence = None + for end_char in sentences: + if result.end <= end_char: + sentence = sentences[end_char] + break + + if sentence is None: + continue + + sentences_to_check[sentence].append(result_index) + + if sentences_to_check == {}: + return list_results + + sentences = list(sentences_to_check.keys()) + + headers = { + "Content-Type": "application/json", + "Authorization": ( + "Bearer " + context.settings.context_checker[lang]["api_key"] + ), + } + + payload = { + "data": sentences, + } + + context_results = await context.http.fetch_json_post( + context.settings.context_checker[lang]["url"], + json.dumps(payload), + headers, + "context checker", + ) + + keys_to_remove = [] + for sentence_index in range(len(sentences)): + sentence = sentences[sentence_index] + if context_results[sentence_index] == "1": + continue + + for result_key in sentences_to_check[sentence]: + if result_key in keys_to_remove: + continue + + keys_to_remove.append(result_key) + + # Ensure we remove from the end so that the list indexes remain the same + keys_to_remove.sort(reverse=True) + for key_to_remove in keys_to_remove: + list_results.pop(key_to_remove) + + return list_results diff --git a/app/main.py b/app/main.py index db89a448c..251fab762 100644 --- a/app/main.py +++ b/app/main.py @@ -1,2578 +1,55 @@ -import os import uvicorn -import json -import secrets -from typing import Optional, Union -from collections import defaultdict -from inspect import currentframe -import logging -import re - -from spacy import displacy -from spacy.tokens import Doc - -from fastapi import ( - FastAPI, - Request, - Response, - HTTPException, - Depends, - status, -) - from contextlib import asynccontextmanager -from fastapi.encoders import jsonable_encoder -from 
fastapi.openapi.docs import get_swagger_ui_html -from fastapi.openapi.utils import get_openapi -from fastapi.security import ( - HTTPBasic, - HTTPBasicCredentials, - HTTPBearer, - APIKeyHeader, -) -from fastapi.exceptions import RequestValidationError -from fastapi.middleware.cors import CORSMiddleware -from starlette.responses import RedirectResponse - -from app.auth_service import get_unverified_token_claims, fetch_user +from fastapi import FastAPI -import secure +from app.startup import lifespan -from cmp_version import VersionString - -from slack_sdk import WebClient -from slack_bolt.adapter.fastapi.async_handler import AsyncSlackRequestHandler -from slack_bolt import Ack, Respond -from app.bolt import get_bolt, process_command_witty -from app.models import ( - Client, - Config, - GermanGenderEndingType, - LangType, - Language, - BaseRequestIn, - RephraseRequestIn, - CheckRequestIn, - Result, - ResultOut, - ResultsOut, - PromptOut, - RephrasesOut, - UserConfRequest, - OrganizationConfRequest, - ConfResponse, - UserConfResponse, - RuleConfig, - ResultConf, - ErrorMessage, - PrettyJSONResponse, - RuleIn, - Alternative, - Rule, - BasicWordType, - WordType, - MetricsType, - ReviewType, -) -from app.helper import is_valid_text, remove_gender_ending, utf16_offsets -from app.db import Db -from app.emoji_check import EmojiCheck -from app.rule_check import RuleCheck -from app.regex_check import RegexCheck -from app.nouns import Nouns -from app.verbs import Verbs -from app.adjectives import Adjectives -from app.prompt import Prompt -from app.llm_alternatives import LlmAlternatives -from app.review_prompt import ReviewPrompt -from app.categories import ( - get_category_keys, - get_parent_category_name, - is_sub_category_enabled, - make_category_advanced, -) -from app.alternatives import Alternatives -from app.languagetool import LanguageTool -from app.db import Db -from app.http import Http +from app.middleware import setup_middleware +from app.routes import 
register_routes from app.context import AppContext context = AppContext() +# Create FastAPI application with lifecycle management @asynccontextmanager -async def lifespan(app: FastAPI): - if os.environ.get("BLACKFIRE_ENABLE_CONTINUOUS_PROFILING"): - try: - from blackfire_conprof.profiler import Profiler - - app_name = os.environ.get("PLATFORM_APPLICATION_NAME") - # app_name += "-worker-%d" % (os.getpid(),) - profiler = Profiler(application_name=app_name) - profiler.start() - - print("Profiler started for %s" % app_name) - except: - pass - - global context - - context.http = Http(context.settings, context.logger) - - sqlite_logger = logging.getLogger("aiosqlite") - sqlite_logger.setLevel(logging.ERROR) - - sqlite_logger.setLevel(logging.WARNING) - - context.db = await Db.factory(context.settings, context.languages) - context.model.db = context.db +async def app_lifespan(app: FastAPI): + """Wrapper for lifespan with context.""" + async with lifespan(app, context): + yield - if context.settings.testing_rules: - rules = json.loads(context.settings.testing_rules) - rules["term_replacements"] = parse_term_replacements(rules["term_replacements"]) - email = rules["email"] - context.redis.db.set(context.redis.get_user_id(email), json.dumps(rules)) - - if context.settings.testing_organization_rules: - organization_rules = json.loads(context.settings.testing_organization_rules) - organization_rules["term_replacements"] = parse_term_replacements( - organization_rules["term_replacements"] - ) - key = organization_rules["id"] - context.redis.db.set(key, json.dumps(organization_rules)) - - if context.settings.redis_log_emails: - log_emails = json.loads(context.settings.redis_log_emails) - for email in log_emails: - context.redis.db.lpush("debug_emails", email) - - context.nouns = Nouns( - context.settings, - context.logger, - context.static_rules, - context.model, - context.db, - ) - context.verbs = Verbs( - context.settings, - context.logger, - context.static_rules, - 
context.model, - context.db, - ) - context.adjectives = Adjectives(context.settings, context.logger, context.db) - context.alternatives = Alternatives( - context.settings, - context.logger, - context.static_rules, - context.db, - context.model, - context.nouns, - context.verbs, - context.adjectives, - ) - context.languagetool = LanguageTool( - context.settings, - context.logger, - context.static_rules, - context.db, - context.categories, - context.http, - ) - context.prompt = Prompt(context.settings) - context.llm_alternatives = LlmAlternatives( - context.settings, context.alternatives, context.prompt - ) - context.rule_check = RuleCheck( - context.settings, - context.logger, - context.static_rules, - context.model, - context.db, - context.nouns, - context.verbs, - context.adjectives, - context.alternatives, - ) - context.regex_check = RegexCheck( - context.settings, context.logger, context.static_rules, context.nouns - ) - context.emoji_check = EmojiCheck( - context.settings, context.logger, context.static_rules - ) - - yield - - await context.http.close() - await context.db.close() - - -application_name = "Witty NLP API" app = FastAPI( - title=application_name, + title="Witty NLP API", version=context.version, terms_of_service=context.settings.terms_of_service, contact=context.settings.contact, docs_url=None, redoc_url=None, openapi_url=None, - lifespan=lifespan, -) - -security = HTTPBasic(auto_error=False) - -csp = secure.ContentSecurityPolicy().set("default-src 'self' cdn.jsdelivr.net") -hsts = secure.StrictTransportSecurity().include_subdomains().preload().max_age(31536000) -referrer = secure.ReferrerPolicy().no_referrer() -cache_value = secure.CacheControl().no_cache() -xfo = secure.XFrameOptions().deny() - -secure_headers = secure.Secure( - csp=csp, - hsts=hsts, - referrer=referrer, - cache=cache_value, - xfo=xfo, -) - -disabled_categories_api = [ - "communal", - "d_and_i", - "emotional_security", - "inclusive", - "orthography", -] - - 
-@app.middleware("http") -async def add_security_headers(request, call_next): - response = await call_next(request) - await secure_headers.set_headers_async(response) - return response - - -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -bolt = get_bolt(context.settings) -bolt_handler = AsyncSlackRequestHandler(bolt) - - -def fetch_current_username( - credentials: Optional[HTTPBasicCredentials] = Depends(security), -) -> str: # pragma: no cover - """Verify HTTP Basic Auth credentials for API docs access. - - Args: - credentials: HTTP Basic Auth credentials from request - - Returns: - Username string ("anon" if auth disabled, or verified username) - - Raises: - HTTPException: If credentials invalid or auth misconfigured - """ - # Credentials are missing - if not credentials: - # Auth is disabled, just proceed - if not context.settings.api_docs_auth_enabled: - return "anon" - - # Auth is enabled, raise 401 - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - headers={"WWW-Authenticate": "Basic"}, - ) - - # Verify the credentials as usual - if not context.settings.api_docs_username or not context.settings.api_docs_password: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Incorrect user configuration", - ) - - correct_username = secrets.compare_digest( - credentials.username, context.settings.api_docs_username - ) - correct_password = secrets.compare_digest( - credentials.password, context.settings.api_docs_password - ) - if not (correct_username and correct_password): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", - headers={"WWW-Authenticate": "Basic"}, - ) - - return credentials.username - - -@app.post( - "/debug/rephrase", - response_model=Union[RephrasesOut, Result], - response_model_exclude_none=True, - dependencies=[ - Depends(HTTPBearer(auto_error=False)), 
- Depends(APIKeyHeader(name="x-key", auto_error=False)), - ], - include_in_schema=not context.settings.is_prod, -) -async def post_debug_rephrase( - request: Request, - response: Response, - rephrase_request_in: RephraseRequestIn, - username: str = Depends(fetch_current_username), -): - return await rephrase_sentence(request, response, rephrase_request_in) - - -@app.post( - "/v1.0/rephrase", - response_model=Union[RephrasesOut, Result], - response_model_exclude_none=True, - dependencies=[ - Depends(HTTPBearer(auto_error=False)), - Depends(APIKeyHeader(name="x-key", auto_error=False)), - ], -) -async def post_rephrase_v1_0( - request: Request, - response: Response, - rephrase_request_in: RephraseRequestIn, -): - return await rephrase_sentence(request, response, rephrase_request_in, "1.0") - - -async def rephrase_sentence( - request: Request, - response: Response, - rephrase_request_in: RephraseRequestIn, - version: str | None = None, -): - client = parse_client(rephrase_request_in.client) - check_client_version(client) - - if version is not None: - if rephrase_request_in.model is not None: - return Result.factory("Model can only be set in debug mode") - - rephrase_api_version(version) - - user_email = await fetch_user( - request, context.settings, context.redis, context.http - ) - if user_email is None: - response.status_code = status.HTTP_401_UNAUTHORIZED - return Result.factory("User not found") - - configs = await fetch_configs_for_request(rephrase_request_in, user_email) - - if ( - rephrase_request_in.config.plan is None - or not rephrase_request_in.config.plan.startswith("witty_") - ): - response.status_code = status.HTTP_401_UNAUTHORIZED - return Result.factory("No valid plan on user") - - if not rephrase_request_in.config.llm_alternatives: - response.status_code = status.HTTP_403_FORBIDDEN - return Result.factory("Rephrasing via LLM not enabled on user") - else: - # debug - configs = {} - - context.redis.store_metrics(request, configs, version, "rephrase") - 
- try: - result = RephrasesOut.factory( - rephrase_request_in.sentence, - await context.llm_alternatives.handle(rephrase_request_in), - ) - except Exception as e: - response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - message = "An error occurred" - if not version: - message = f"{message}: {e}" - - return Result.factory(message) - - return result - - -@app.post( - "/debug/review_prompt", - response_model=Union[str, Result], - response_model_exclude_none=True, - dependencies=[ - Depends(HTTPBearer(auto_error=False)), - Depends(APIKeyHeader(name="x-key", auto_error=False)), - ], - include_in_schema=not context.settings.is_prod, -) -async def debug_review_prompt( - request: Request, - response: Response, - check_request_in: CheckRequestIn, - review_type: ReviewType = ReviewType.EXPLAIN_EDITS, -) -> Result | str: - for category in disabled_categories_api: - if category in check_request_in.config.disabled_categories: - continue - - check_request_in.config.disabled_categories.append(category) - - check_result = await check(request, response, check_request_in, None) - if isinstance(check_result, Result): - return check_result - - if len(check_result.results) == 0: - return "WITTYNOCHANGES" - - return ReviewPrompt.handle( - check_result.results, review_type, check_request_in.text, 1900 - ) - - -@app.post( - "/debug/prompt", - response_model=Union[Result, PromptOut, None], - response_model_exclude_none=True, - include_in_schema=not context.settings.is_prod, -) -async def debug_prompt( - request: Request, - response: Response, - check_request_in: CheckRequestIn, - username: str = Depends(fetch_current_username), -) -> Result | PromptOut: - configs = debug_configs(check_request_in) - return await prompt(response, check_request_in, configs) - - -@app.post( - "/v1.0/prompt", - response_model=Union[Result, PromptOut, None], - response_model_exclude_none=True, -) -async def post_prompt( - request: Request, - response: Response, - check_request_in: CheckRequestIn, - 
user_email: str, - username: str = Depends(fetch_current_username), -) -> Result | PromptOut: - configs = await fetch_configs_for_request(check_request_in, user_email) - if configs == {}: - response.status_code = status.HTTP_401_UNAUTHORIZED - return Result.factory("User config missing") - - context.redis.store_metrics(request, configs, "1.0", "prompt") - return await prompt(response, check_request_in, configs) - - -async def prompt( - response: Response, - check_request_in: CheckRequestIn, - configs: dict, -) -> Result | PromptOut: - if ( - check_request_in.config.plan is None - or not check_request_in.config.plan.startswith("witty_") - ): - response.status_code = status.HTTP_402_PAYMENT_REQUIRED - return Result.factory("Plan missing") - - if not check_request_in.config.llm_alternatives: - response.status_code = status.HTTP_401_UNAUTHORIZED - return Result.factory("User config disallows LLM use") - - check_request_in.text = await context.prompt.handle( - check_request_in.text, None, None, 0.4 - ) - check_request_in.text = context.prompt.parse_json(check_request_in.text) - - for category in disabled_categories_api: - if category in check_request_in.config.disabled_categories: - continue - - check_request_in.config.disabled_categories.append(category) - - text, language, limit_reached = fetch_text(check_request_in, context.langs) - - if language is None: - response.status_code = status.HTTP_422_UNPROCESSABLE_ENTITY - return Result.factory("Language could not be determined") - - client = parse_client(check_request_in.client) - check_result = await apply_language_rules( - client, check_request_in.config, configs, language, text - ) - - reviewed_response = None - if len(check_result) != 0: - review_prompt = ReviewPrompt.handle( - check_result, ReviewType.INCLUDE_PREVIOUS, check_request_in.text - ) - - reviewed_response = await context.prompt.handle(review_prompt) - reviewed_response = context.prompt.parse_json(reviewed_response) - - return PromptOut( - 
initial_response=check_request_in.text, - reviewed_response=reviewed_response, - check_results=check_result, - limit_reached=limit_reached, - ) - - -@bolt.command("/witty") -async def handle_command_witty( - body: dict, ack: Ack, respond: Respond, client: WebClient -): # pragma: no cover - await ack() - - check_request_in = CheckRequestIn(client="slack:1.0.0", text=body["text"]) - text, language, limit_reached = fetch_text(check_request_in, context.langs) - - if language is None: - await respond(f"Witty could not determine a language for '{text}'.") - return - - configs = {} - - try: - user = await client.users_info(user=body["user_id"]) - configs = await fetch_configs_for_request( - check_request_in, user.data["user"]["profile"]["email"] - ) - except KeyError: - pass - - if configs == {} and context.settings.slack_organization_id: - configs = await fetch_organization_configs_for_request( - check_request_in, context.settings.slack_organization_id - ) - - check_request_in.config.__setattr__("alternatives_max_count", None) - client = parse_client(check_request_in.client) - results = await apply_language_rules( - client, check_request_in.config, configs, language, text - ) - - return await process_command_witty(text, language, limit_reached, results, respond) - - -@app.post("/slack/commands") -async def post_slack_commands(request: Request): # pragma: no cover - return await bolt_handler.handle(request) - - -@app.get("/health") -async def get_health(check_external: bool = False): - health = {} - - langs = { - LangType.EN: "Hello guys", - LangType.DE: "Hallo Kunde", - LangType.FR: "Je m'appelle Luc", - } - - for lang in context.langs: - try: - context.model.fetch_tokens(lang, langs[lang]) - health["model_" + lang] = True - except Exception: - health["model_" + lang] = False - - if check_external: - try: - languagetool_health = await context.http.fetch_json_get( - context.settings.languagetool_api + "/healthcheck", - {}, - {}, - "LanguageTool", - 
context.settings.languagetool_verify_ssl, - False, - ) - - health["spelling"] = languagetool_health == "OK" - except Exception: - health["spelling"] = False - - try: - health["config"] = context.redis.db.ping() - except Exception: - health["config"] = False - - content = jsonable_encoder(health) - - for key in health: - if not health[key]: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=content, - ) - - return content - - -@app.get("/lt", include_in_schema=not context.settings.is_prod) -def get_lt(username: str = Depends(fetch_current_username)): - return context.settings.languagetool_api - - -@app.get("/settings", include_in_schema=not context.settings.is_prod) -def get_settings(username: str = Depends(fetch_current_username)): - return context.settings - - -@app.get("/docs", include_in_schema=False) -def get_swagger_documentation( - username: str = Depends(fetch_current_username), - include_in_schema=not context.settings.is_prod, -): # pragma: no cover - return get_swagger_ui_html(openapi_url="/openapi.json", title="docs") - - -@app.get("/openapi.json", include_in_schema=False) -def get_openapi_json( - username: str = Depends(fetch_current_username), -): # pragma: no cover - return get_openapi(title=app.title, version=app.version, routes=app.routes) - - -# public routes -@app.get("/", include_in_schema=False) -def get_root(): - if ( - not context.settings.is_prod and context.settings.testing is False - ): # pragma: no cover - return RedirectResponse(url="/docs", status_code=302) - - return application_name + ": https://witty.works" - - -@app.get( - "/save_openapi_json", - include_in_schema=not context.settings.is_prod, -) -def get_save_openapi_json( - username: str = Depends(fetch_current_username), -): # pragma: no cover - openapi_data = app.openapi() - for path in openapi_data["paths"].copy(): - if "v2.0" not in path: - del openapi_data["paths"][path] - - with open("openapi.json", "w") as file: - json.dump(openapi_data, 
file, indent=4, sort_keys=True) - - -@app.get( - "/debug/german_gender_ending", - include_in_schema=not context.settings.is_prod, - response_class=PrettyJSONResponse, -) -async def get_german_gender_ending( - alternative: str, - german_gender_ending: GermanGenderEndingType | None = None, - username: str = Depends(fetch_current_username), -): - """Debug endpoint to test German gender ending alternatives. - - Args: - alternative: German word to generate alternatives for - german_gender_ending: Gender ending type (defaults to inclusive + binary) - username: Authenticated username - - Returns: - List of gendered alternatives - """ - if not german_gender_ending: - inclusive = True - binary = True - else: - inclusive = Config.gendered_roles_format_inclusive(german_gender_ending) - binary = Config.gendered_roles_format_binary(german_gender_ending) - - alternatives, _ = await context.alternatives.german_gendered_alternatives( - Rule(""), - Alternative(alternative), - inclusive, - binary, - GermanGenderEndingType.STAR[0], - GermanGenderEndingType.STAR[0], - False, - ) - - return alternatives - - -@app.get( - "/debug/declension", - include_in_schema=not context.settings.is_prod, - response_class=PrettyJSONResponse, -) -async def get_declension_debug( - lang: LangType, - word_type: BasicWordType, - word: str, -): - return await context.db.fetch_declensions(lang, word_type, word) - - -@app.get( - "/debug/align_form", - include_in_schema=not context.settings.is_prod, - response_class=PrettyJSONResponse, + lifespan=app_lifespan, ) -async def get_align_form_debug( - lang: LangType, - word_type: BasicWordType, - index: int, - source_text: str, - target_text: str, -): - source_tokens = context.model.fetch_tokens(lang, source_text) - target_tokens = context.model.fetch_tokens(lang, target_text) - - target_form = await context.alternatives.find_form( - lang, word_type, index, source_tokens - ) - - if WordType.VERB == word_type: - return await context.verbs.align_form_verb( - lang, - 
target_form, - source_tokens[0].text, - source_tokens[0].lemma_, - target_tokens[0], - ) - - if WordType.ADJECTIVE == word_type: - return await context.adjectives.align_form_adjective( - lang, - target_form, - source_tokens[0].text, - source_tokens[0].lemma_, - target_tokens[0], - ) - - # if WordType.NOUN == word_type: - return context.nouns.align_form_noun( - lang, - target_form, - target_tokens[0], - ) +# Setup middleware (security headers, CORS) +setup_middleware(app) -@app.get( - "/debug/configs", - include_in_schema=not context.settings.is_prod, - dependencies=[ - Depends(HTTPBearer(auto_error=False)), - Depends(APIKeyHeader(name="x-key", auto_error=False)), - ], -) -async def get_config_debug( - user_email: str, - username: str = Depends(fetch_current_username), -): # pragma: no cover - check_request_in = CheckRequestIn(text="") +# Register all route modules +register_routes(app) +# Initialize Slack Bolt app and handlers with AppContext if enabled +if context.settings.slack_enabled: try: - configs = await fetch_user_organization_configs(user_email) - except HTTPException: - try: - configs = await context.redis.fetch_user_configs_from_redis(user_email) - except HTTPException: - configs = {} - - result_configs = await fetch_configs_for_request(check_request_in, user_email) - del result_configs["organization_config"] - del result_configs["organization_domains"] - del result_configs["organization_false_positives"] - del result_configs["organization_term_replacements"] - - return { - "configs": configs, - "result_configs": result_configs, - "check_request_in": check_request_in, - } - - -@app.post( - "/debug/auth", - include_in_schema=not context.settings.is_prod, - dependencies=[ - Depends(HTTPBearer(auto_error=False)), - Depends(APIKeyHeader(name="x-key", auto_error=False)), - ], -) -async def post_auth_debug( - request: Request, check_request_in: CheckRequestIn -): # pragma: no cover - user_email = await fetch_user( - request, context.settings, context.redis, 
context.http - ) - if not user_email: - return None - - configs = await fetch_configs_for_request(check_request_in, user_email) - - if "authorization" in request.headers and request.headers[ - "authorization" - ].lower().startswith("bearer"): - unverified_claims = get_unverified_token_claims(request) - else: - unverified_claims = "using auth token override" - - return { - "claim": unverified_claims, - "configs": configs, - "check_request_in": check_request_in, - } - - -@app.get( - "/debug/metrics", - dependencies=[ - Depends(HTTPBearer(auto_error=False)), - Depends(APIKeyHeader(name="x-key", auto_error=False)), - ], -) -async def get_user_configs( - key: MetricsType, - top_x: int | None, - username: str = Depends(fetch_current_username), -): - result = {} - keys = MetricsType if key == MetricsType.ALL else [key] - for _key in keys: - if _key == MetricsType.ALL: - continue - - metrics = context.redis.db.hgetall(_key) - - for a in metrics: - metrics[a] = int(metrics[a]) - - result[_key] = { - k: metrics[k] for k in sorted(metrics, key=metrics.get, reverse=True) - } - if top_x is not None: - result[_key] = { - dkey: value for dkey, value in list(result[_key].items())[0:top_x] - } - - if key != MetricsType.ALL: - return result[key] - - return result - - -@app.post( - "/v2.0/auth", - response_model=Union[ResultConf, dict, None], - response_model_exclude_none=True, - dependencies=[ - Depends(HTTPBearer(auto_error=False)), - Depends(APIKeyHeader(name="x-key", auto_error=False)), - ], -) -async def post_auth_2_0( - request: Request, check_request_in: BaseRequestIn | None = None -): - client = parse_client( - check_request_in.client if check_request_in is not None else None - ) - check_client_version(client) - - user_email = await fetch_user( - request, context.settings, context.redis, context.http - ) - configs = ( - await fetch_configs_for_request(CheckRequestIn(text=""), user_email) - if user_email - else {} - ) - - context.redis.store_metrics(request, configs, "2.0", 
"auth") - - if configs == {}: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - ) - - config = fetch_result_conf(configs) - - if "team_analytics" in configs and not configs["team_analytics"]: - config.organization_id = None - - return config - - -@app.post( - "/debug/rule", - include_in_schema=not context.settings.is_prod, - response_model=list[ResultOut], - response_model_exclude_none=True, -) -async def post_debug_rule( - rule_data: RuleIn, - username: str = Depends(fetch_current_username), -): - language = context.languages[rule_data.lang] - config = Config(plan="witty_teams") - - tokens = context.model.fetch_tokens(language.lang, rule_data.text) - for token in tokens: - word_type = await context.model.fetch_word_type(language.lang, token) - for lemmatization in rule_data.lemmatizations: - if token.text.lower() == lemmatization.text.lower() and ( - word_type == lemmatization.word_type or lemmatization.word_type == "" - ): - token.lemma_ = lemmatization.text - break - - offsets = utf16_offsets(rule_data.text) - false_positive_matcher = context.model.fetch_false_positive_matchers( - language.lang, tokens - ) - - if rule_data.alternatives is not None: - alternative_list = [] - for alternative_in in rule_data.alternatives: - alternative = Alternative( - alternative_in.lemma, - context.model.tokenize(alternative_in.lemma, rule_data.lang), - alternative_in.word_types, - alternative_in.is_remove, - alternative_in.is_inspiration, - alternative_in.is_placeholder, - alternative_in.is_advanced, - alternative_in.is_collective_noun, - alternative_in.is_gendered_noun, - alternative_in.label, - ) - - alternative_list.append(alternative) - else: - alternative_list = None - - rule = Rule( - "test", - rule_data.lang, - rule_data.lemma, - context.model.tokenize(rule_data.lemma, rule_data.lang), - rule_data.word_types, - rule_data.subcategories, - None, - rule_data.actual_word_types, - ) - - rule.dynamic.alternatives = alternative_list - rule.pattern = 
rule_data.pattern - rule.is_pattern_match = rule_data.is_pattern_match - rule.false_positives = rule_data.false_positives - rule.label = rule_data.label - rule.type = rule_data.type - rule.entity_type = rule_data.entity_type - rule.pluralization = rule_data.pluralization - rule.adapt_alternatives = bool(len(alternative_list)) - - rules = [rule] - - list_full = [] - client = parse_client("debug:" + context.version) - - token_index = 0 - token_count = len(tokens) - while token_index < token_count: - if rule_data.lang == LangType.DE: - tokens[token_index].lemma_ = await german_lemmatization(tokens, token_index) - - for rule in rules: - await context.rule_check.handle( - config, - client, - language, - rule_data.text, - token_index, - tokens, - offsets, - list_full, - [rule], - false_positive_matcher, - ) - - token_index += 1 - - return list_full - - -@app.get( - "/debug/spacy", - include_in_schema=not context.settings.is_prod, - response_class=PrettyJSONResponse, -) -async def get_debug_spacy( - text: str, - lang: LangType, - detailed: bool = False, - username: str = Depends(fetch_current_username), -): - results = [] - tokens = context.model.fetch_tokens(lang, text) - - word_type_parts = [] - for token_index in range(len(tokens)): - token = tokens[token_index] - - if lang == LangType.DE: - token.lemma_ = await german_lemmatization(tokens, token_index) - word_type = await context.model.fetch_word_type(lang, token) - - # Get string value for word_type, prefix with '~' if text != lemma - word_type_str = getattr(word_type, "value", word_type) - if token.text != token.lemma_: - word_type_str = f"~{word_type_str}" - word_type_parts.append(word_type_str) - - token_info = { - "text": token.text, - "lemma": token.lemma_, - "word_type": word_type, - "is_singular": context.model.is_token_singular(lang, token), - "ner": token.ent_type_, - } - - if detailed: - token_info["start"] = token.idx - token_info["whitespace"] = token.whitespace_ - token_info["emoji_desc"] = 
token._.emoji_desc - token_info["is_emoji"] = token._.is_emoji - token_info["morph"] = token.morph.to_dict() - token_info["tag"] = token.tag_ - token_info["pos"] = token.pos_ - token_info["dep"] = token.dep_ - token_info["head"] = token.head.text - - dependent = None - children = [] - for a in token.ancestors: - for atok in a.children: - children.append( - {"dep": atok.dep_, "token": atok.text, "ner": atok.ent_type_} - ) - if dependent is None and atok.dep_ in ["pobj", "dobj"]: - dependent = atok.text - - token_info["dependent"] = dependent - token_info["children"] = children - - results.append(token_info) - - if detailed: - noun_chunks = [] - for chunk in tokens.noun_chunks: - noun_chunks.append( - { - "text": chunk.text, - "start": chunk.start, - "end": chunk.end, - } - ) - - results = [{"noun chunks": noun_chunks}] + results - - # Build the word type rule string - word_type_rule = "|".join(word_type_parts) - return [{"auto-detected word type": word_type_rule}] + results - - -@app.get( - "/debug/displacy", - include_in_schema=not context.settings.is_prod, -) -async def get_debug_displacy( - text: str, - lang: LangType, - username: str = Depends(fetch_current_username), -): - tokens = context.model.fetch_tokens(lang, text) - - sentence_spans = list(tokens.sents) - data = displacy.render(sentence_spans, style="dep") - return Response(content=data, media_type="image/svg+xml") - - -@app.get( - "/debug/german_noun", - include_in_schema=not context.settings.is_prod, - response_class=PrettyJSONResponse, -) -async def get_debug_german_noun( - word: str, - username: str = Depends(fetch_current_username), -): - return await context.nouns.german_noun_lookup(word) - - -@app.post( - "/debug/check", - response_model=Union[ResultsOut, Result], - response_model_exclude_none=True, - dependencies=[ - Depends(HTTPBearer(auto_error=False)), - Depends(APIKeyHeader(name="x-key", auto_error=False)), - ], - include_in_schema=not context.settings.is_prod, -) -async def post_debug_check( 
- request: Request, - response: Response, - check_request_in: CheckRequestIn, - username: str = Depends(fetch_current_username), -): - return await check(request, response, check_request_in) - - -@app.post( - "/v2.4/check", - response_model=Union[ResultsOut, Result], - response_model_exclude_none=True, - dependencies=[ - Depends(HTTPBearer(auto_error=False)), - Depends(APIKeyHeader(name="x-key", auto_error=False)), - ], -) -async def post_check_v2_4( - request: Request, - response: Response, - check_request_in: CheckRequestIn, -): - return await check(request, response, check_request_in, "2.4") - - -@app.get("/lemmatize") -async def get_lemmatize( - text: str, - lang: LangType, - all: bool = False, - username: str = Depends(fetch_current_username), -): - tokens = context.model.fetch_tokens(lang, text) - if all: - return tuple([i.lemma_ for i in tokens]) - - if len(tokens) != 1: - return None - - return tokens[0].lemma_ - - -@app.get("/tokenize") -async def get_tokenize( - text: str, - lang: LangType, - username: str = Depends(fetch_current_username), -): - return context.model.tokenize(text, lang) - - -@app.get("/parse-word-types") -async def get_tokenize( - text: str, - word_types: str, - lang: LangType, - username: str = Depends(fetch_current_username), -): - tokens = context.model.fetch_tokens(lang, text) - word_type_list = word_types.split("|") - - if len(tokens) != len(word_type_list): - raise RequestValidationError( - f"Word type '{word_types}' count does not match text token count '{len(tokens)}' for text '{text}'." 
- ) - - parsed_word_types = [] - for word_type in word_type_list: - parsed_word_type, lower_case, lemmatize = parse_word_type(word_type) - - if ( - parsed_word_type != "" - and parsed_word_type not in context.supported_word_types - ): - raise RequestValidationError( - f"Word type '{word_type}' within '{word_types}' contains unsupported word type '{parsed_word_type}'" - ) - - parsed_word_types.append( - { - "word_type": parsed_word_type, - "lower_case": lower_case, - "lemmatize": lemmatize, - } - ) - - return parsed_word_types - - -@app.post( - "/organization/configs", - status_code=status.HTTP_204_NO_CONTENT, -) -async def post_organization_configs( - organization_configs: OrganizationConfRequest, - username: str = Depends(fetch_current_username), -): - organization_configs.term_replacements = parse_term_replacements( - organization_configs.term_replacements - ) - - context.redis.db.set( - organization_configs.id, organization_configs.model_dump_json() - ) - - -@app.delete( - "/organization/configs", - status_code=status.HTTP_204_NO_CONTENT, -) -async def delete_organization_configs( - organization_id: str, - username: str = Depends(fetch_current_username), -): - context.redis.db.delete(organization_id) - - -@app.get( - "/organization/configs", - response_model=ConfResponse, - response_model_exclude_none=True, - responses={404: {"model": ErrorMessage}}, -) -async def get_organization_configs( - organization_id: str, - username: str = Depends(fetch_current_username), -): - return await context.redis.fetch_organization_configs_from_redis(organization_id) - - -@app.post( - "/user/configs", - status_code=status.HTTP_204_NO_CONTENT, -) -async def post_user_configs( - user_configs: UserConfRequest, username: str = Depends(fetch_current_username) -): - user_configs.term_replacements = parse_term_replacements( - user_configs.term_replacements - ) - - context.redis.db.set( - context.redis.get_user_id(user_configs.email), user_configs.model_dump_json() - ) - - -@app.delete( 
- "/user/configs", - status_code=status.HTTP_204_NO_CONTENT, -) -async def delete_user_configs( - email: str, - username: str = Depends(fetch_current_username), -): - context.redis.db.delete(context.redis.get_user_id(email)) - - -@app.get( - "/user/configs", - response_model=UserConfResponse, - response_model_exclude_none=True, - responses={404: {"model": ErrorMessage}}, -) -async def get_user_configs( - email: str, - username: str = Depends(fetch_current_username), -): - return await fetch_user_organization_configs(email) - - -@app.get( - "/user/logs", - response_class=PrettyJSONResponse, -) -async def get_user_logs( - email: str, - username: str = Depends(fetch_current_username), -): - return context.redis.get_user_logs(email) - - -@app.get( - "/api_key", - status_code=status.HTTP_204_NO_CONTENT, - responses={404: {"model": ErrorMessage}}, -) -async def get_api_key( - api_key: str, - username: str = Depends(fetch_current_username), -): - """Get email associated with an API key. - - Args: - api_key: API key to lookup - username: Authenticated username - - Returns: - Email address associated with the key - - Raises: - HTTPException: If API key not found - """ - email = context.redis.db.get(f"api_key:{api_key}") - if not email: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="API key not found" - ) - - return email - - -@app.post( - "/api_key", - status_code=status.HTTP_204_NO_CONTENT, -) -async def post_api_key( - api_key: str, email: str, username: str = Depends(fetch_current_username) -): - """Associate an email with an API key. 
- - Args: - api_key: API key to store - email: Email address to associate - username: Authenticated username - """ - context.redis.db.set(f"api_key:{api_key}", email) - - -@app.delete( - "/api_key", - status_code=status.HTTP_204_NO_CONTENT, -) -async def delete_api_key( - api_key: str, - username: str = Depends(fetch_current_username), -): - context.redis.db.delete("api_key:" + api_key) - - -def fetch_text( - check_request_in: CheckRequestIn, supported_langs: list -) -> tuple[str, Language | None, bool]: - text = check_request_in.text - limit_reached = len(text) > context.settings.text_max_length - if limit_reached: - text = text[0 : context.settings.text_max_length] - text = text.rsplit(" ", 1)[0] - - locale = context.lang_detection.get_locale( - supported_langs, - text, - check_request_in.lang, - check_request_in.config.preferred_languages, - check_request_in.config.preferred_variants, - ) - - language = None if locale is None else context.languages[locale] - - return text, language, limit_reached - - -def parse_term_replacement(lemma, term_replacement: dict): - word_type = ( - term_replacement["word_type"] if "word_type" in term_replacement else "~" - ) - - word_type, lower_case, lemmatize = parse_word_type(word_type) - - word_type = tuple( - [ - { - "word_type": word_type, - "lower_case": lower_case, - "lemmatize": lemmatize, - } - ] - ) - - if lower_case and not lemmatize: - lemma = lemma.lower() - - if list(filter(lemma.endswith, context.term_replacement_langs)) != []: - term_replacement["lang"] = lemma[-2:] - term_replacement["lemma"] = lemma[0:-3] - else: - term_replacement["lang"] = None - term_replacement["lemma"] = lemma - - lang = LangType.EN if term_replacement["lang"] is None else term_replacement["lang"] - term_replacement["words"] = context.model.tokenize(term_replacement["lemma"], lang) - term_replacement["word_types"] = word_type * len(term_replacement["words"]) - - term_replacement["false_positives"] = [] - term_replacement["parsed_alternatives"] 
= [] - for alternative in term_replacement["alternatives"]: - term_replacement["false_positives"].append(alternative) - - alternative = {"lemma": alternative} - alternative["words"] = context.model.tokenize(alternative["lemma"], lang) - alternative["word_types"] = word_type * len(alternative["words"]) - term_replacement["parsed_alternatives"].append(alternative) - - return term_replacement - - -def parse_term_replacements(term_replacements_source: dict | None = None): - term_replacements = {} - if term_replacements_source is not None: - for lemma in term_replacements_source: - term_replacement = dict(term_replacements_source[lemma]) - term_replacement = parse_term_replacement(lemma, term_replacement) - - term_replacements[lemma] = term_replacement - - return term_replacements - - -async def fetch_user_organization_configs(email: str) -> dict | None: - configs = await context.redis.fetch_user_configs_from_redis(email) - - configs["organization_name"] = None - configs["organization_config_hash"] = None - configs["organization_domains"] = None - configs["organization_trial_ends_at"] = None - - if "organization_id" in configs and configs["organization_id"] is not None: - organization_configs = ( - await context.redis.fetch_organization_configs_from_redis( - configs["organization_id"] - ) - ) - - if not configs.get("plan"): - configs["plan"] = organization_configs["plan"] - - if "trial_ends_at" in organization_configs: - configs["organization_trial_ends_at"] = organization_configs[ - "trial_ends_at" - ] - - configs["organization_name"] = organization_configs["name"] - - configs["organization_config_hash"] = organization_configs.get("config_hash") - - configs["organization_domains"] = organization_configs.get("domains", {}) - - configs["organization_config"] = organization_configs["config"] - - configs["organization_term_replacements"] = organization_configs[ - "term_replacements" - ] - - configs["organization_false_positives"] = organization_configs[ - "false_positives" 
- ] - else: - configs["organization_id"] = None - - return configs - - -def apply_configs( - check_request_in: CheckRequestIn, - configs: dict, - plan: str, - force_disables: bool = True, -): - disabled_categories = check_request_in.config.disabled_categories - if "force_categories" not in configs or configs["force_categories"] is None: - configs["force_categories"] = [] - - for config in configs: - if config == "force_categories": - continue - - data = configs[config] - if data is None: - continue - - if config == "categories": - for category in data: - category_data = data[category] - if category_data["status"] != "force": - continue - - # BC handling for old category names -> needs to be fixed in the dashboard - if category.startswith("advanced_"): - category = make_category_advanced( - category.removeprefix("advanced_") - ) - - if category_data["value"]: - if category in disabled_categories: - disabled_categories.remove(category) - else: - force_disables_category = force_disables - if not force_disables_category and len(configs["force_categories"]): - parent_category = get_parent_category_name(category) - force_disables_category = ( - parent_category in configs["force_categories"] - ) - - if force_disables_category and category not in disabled_categories: - disabled_categories.append(category) - elif config == "store_context": - if ( - plan is not None - and plan != "witty_free" - and data["status"] == "force" - and not data["value"] - ): - check_request_in.config.__setattr__("store_context", False) - elif config == "llm_alternatives": - if ( - plan is not None - and plan != "witty_free" - and data["status"] == "force" - and data["value"] - ): - check_request_in.config.__setattr__("llm_alternatives", True) - elif data["status"] == "force": - check_request_in.config.__setattr__(config, data["value"]) - - for category in disabled_categories_api: - if ( - category in check_request_in.config.disabled_categories - and category not in disabled_categories - ): - 
disabled_categories.append(category) - - check_request_in.config.__setattr__("disabled_categories", disabled_categories) - check_request_in.config.__setattr__("plan", plan) - - -async def fetch_configs_for_request( - request_in: BaseRequestIn, user_email=Optional[str] -) -> dict: - request_in.config.__setattr__("store_context", True) - request_in.config.__setattr__("llm_alternatives", False) - request_in.config.__setattr__("plan", None) - request_in.config.__setattr__( - "alternatives_max_count", context.settings.alternatives_max_count - ) - - if not user_email: - request_in.config.__setattr__("disabled_categories", get_category_keys(True)) - request_in.config.__setattr__("plan", None) - - return {} - - try: - configs = await fetch_user_organization_configs(user_email) - except HTTPException: - return {} - - apply_configs(request_in, configs["config"], configs["plan"]) - - if "organization_config" in configs: - apply_configs( - request_in, - configs["organization_config"], - configs["plan"], - False, - ) - - configs["term_replacements"] |= configs["organization_term_replacements"] - configs["false_positives"] = list( - set(configs["false_positives"] + configs["organization_false_positives"]) - ) - - return configs - - -async def fetch_organization_configs_for_request( - request_in: BaseRequestIn, organization_id=Optional[str] -) -> dict: - request_in.config.__setattr__("store_context", True) - request_in.config.__setattr__("llm_alternatives", False) - - if not organization_id: - return {} - - try: - configs = await context.redis.fetch_organization_configs_from_redis( - organization_id - ) - except HTTPException: - return {} - - for config in configs["configs"]: - if configs["configs"][config]["status"] == "suggestion": - configs["configs"][config]["status"] = "force" - - apply_configs(request_in, configs["config"], configs["plan"]) - - return configs - - -def rephrase_api_version(version: str): - if version != "1.0": # pragma: no cover - raise HTTPException( - 
status_code=status.HTTP_400_BAD_REQUEST, - detail=f"API version '{version}' not supported, please use version '1.0'.", - ) - - -def check_api_version(version: str): - if version != "2.4": # pragma: no cover - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"API version '{version}' not supported, please use version '2.4'.", - ) - - -def check_client_version(client: Client): - if ( - client.name in context.settings.minimum_versions - and client.version - < VersionString(context.settings.minimum_versions[client.name]) - ): # pragma: no cover - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Client version '{client.version}' not supported, please use at least '{context.settings.minimum_versions[client.name]}'.", - ) - - -def debug_configs( - request_in: BaseRequestIn, -): - if "none" in request_in.config.disabled_categories: - request_in.config.__setattr__("disabled_categories", []) - elif request_in.config.disabled_categories == []: - request_in.config.__setattr__( - "disabled_categories", ["plain_language_advanced"] - ) - - configs = { - "categories": {}, - "llm_alternatives": {"status": "suggestion", "value": True}, - } - apply_configs(request_in, configs, "witty_teams") - - return configs - - -async def check( - request: Request, - response: Response, - check_request_in: CheckRequestIn, - version: str | None = None, -) -> Result | ResultsOut: - client = parse_client(check_request_in.client) - check_client_version(client) - - user_email = None - if version is not None: - check_api_version(version) - - user_email = await fetch_user( - request, context.settings, context.redis, context.http - ) - - configs = await fetch_configs_for_request(check_request_in, user_email) - else: - configs = debug_configs(check_request_in) - - context.redis.store_metrics(request, configs, version, "check") - - context.redis.store_request_log( - check_request_in, - user_email, - request, - configs, - version, - "check", - ) - - if ( - 
check_request_in.config.plan is not None - and check_request_in.config.plan.startswith("witty_") - ): - text, language, limit_reached = fetch_text(check_request_in, context.langs) - - if language is None: - response.status_code = status.HTTP_422_UNPROCESSABLE_ENTITY - return Result.factory("Language could not be determined") - - results = await apply_language_rules( - client, check_request_in.config, configs, language, text - ) - - lang = language.lang - - if isinstance(results, Result): - return results - else: - results = [] - lang = LangType.EN - limit_reached = False - - notifications = None - if "notifications" in configs and configs["notifications"] > 0: - notifications = configs["notifications"] - - has_consented_to_mailing = None - if "has_consented_to_mailing" in configs: - has_consented_to_mailing = configs["has_consented_to_mailing"] - - if not isinstance(results, Result): - results = ResultsOut( - results=results, - language=lang, - limit_reached=limit_reached, - config_changed=fetch_config_change(configs, check_request_in), - notifications=notifications, - has_consented_to_mailing=has_consented_to_mailing, - gender_separator=check_request_in.config.get_gender_separator(lang), - ) - - context.redis.store_response_log( - user_email, - results, - ) - - return results - - -def fetch_config_change( - configs: dict, - check_request_in: Optional[BaseRequestIn] = None, -) -> bool | None: - if not check_request_in: - return True - - if ( - "config_hash" in configs - and check_request_in.config_hash != configs["config_hash"] - ): - return True - - if ( - "organization_config_hash" in configs - and check_request_in.organization_config_hash - != configs["organization_config_hash"] - ): - return True - - return None - - -def fetch_result_conf(configs: dict) -> ResultConf | None: - if "config" not in configs: - return None - - organization_config = ( - RuleConfig.model_validate(configs["organization_config"]) - if "organization_config" in configs - else None - ) - - 
plan = configs["plan"] - - config = RuleConfig.model_validate(configs["config"]) - - return ResultConf( - id=configs["id"], - name=configs["name"], - plan=plan, - config=config, - organization_id=configs["organization_id"], - organization_name=configs["organization_name"], - organization_config=organization_config, - domains=configs["domains"], - organization_domains=configs["organization_domains"], - config_hash=configs["config_hash"], - organization_config_hash=configs["organization_config_hash"], - organization_trial_ends_at=configs["organization_trial_ends_at"], - ) - - -def parse_client(client: str) -> Client: - if client is None: - client = "0.0.0" - - if ":" in client: - client = client.split(":") - else: - client = ["web-ext", client] - - return Client(name=client[0], version=client[1]) - - -async def apply_language_rules( - client: Client, - config: Config, - configs: dict, - language: Language, - text: str, -) -> list: - tokens = context.model.fetch_tokens(language.lang, text) - offsets = utf16_offsets(text) - - term_replacements = fetch_term_replacements(configs, language.lang) - - list_results = await witty_rules( - config, - term_replacements, - client, - tokens, - offsets, - language, - text, - ) - - list_results = await context.languagetool.apply_languagetool_rules( - config, client, language, text, tokens, offsets - ) + await context_false_positives(language.lang, tokens, list_results) - - return apply_false_positives(list_results, configs) - - -def fetch_term_replacements( - configs: dict, - lang: LangType, -) -> list[Rule]: - if "term_replacements" not in configs: - return [] - - word_types = [WordType.VERB, WordType.NOUN, WordType.ADJECTIVE] - term_replacement_rules = [] - for lemma in configs["term_replacements"]: - if lemma[-3:] in context.term_replacement_langs: - if not lemma.endswith(lang): - continue - - term_replacement = configs["term_replacements"][lemma] - - alternatives = [] - for alternative in term_replacement["parsed_alternatives"]: 
def apply_false_positives(
    list_results: list,
    configs: dict,
) -> list:
    """Drop results whose text the organization flagged as a false positive.

    Filters ``list_results`` in place (callers may hold a reference to the
    same list object) and returns it for convenience.

    Args:
        list_results: Rule-check results; each item exposes a ``.text`` attribute.
        configs: Per-request config dict; may contain a "false_positives" list.

    Returns:
        The same list object, with flagged results removed.
    """
    if not list_results:
        return list_results

    # Set lookup keeps this O(n) instead of O(n * len(false_positives))
    # with a repeated list.remove() scan.
    false_positives = set(configs.get("false_positives", []))

    if false_positives:
        # Slice-assign so callers holding a reference see the filtered list.
        list_results[:] = [
            result for result in list_results if result.text not in false_positives
        ]

    return list_results
def check_continue(
    list_full: list, token_index: int, new_token_index: int, tokens: Doc, func_name: str
):
    """Decide whether a handler consumed tokens so the caller should continue.

    Returns True only when the handler advanced past ``token_index``.  A
    handler that moved backwards indicates a bug: it is logged (with the
    caller's line number via the frame stack) and treated as "no progress".
    """
    if new_token_index > token_index:
        return True

    if new_token_index < token_index:
        caller_frame = currentframe()
        last_text_id = list_full[-1].text_id if list_full else ""

        context.logger.error(
            "Incorrect new_token_index on line %i using '%s': expected %i < %i for '%s' versus '%s' for text_id '%s'",
            caller_frame.f_back.f_lineno,
            func_name,
            token_index,
            new_token_index,
            tokens[token_index].text,
            tokens[new_token_index].text,
            last_text_id,
        )

    return False
token.lemma_ - or not token.text[0].isupper() - or len(token.text) <= 3 - ): - return token.lemma_ - - word = remove_gender_ending(token.text) - - result = await context.nouns.german_noun_lookup(word, token) - if result is not None: - target = "male_form" if result["male_form"] else "base_form" - return result[target] - case WordType.VERB: - verb_form = token.morph.get("VerbForm") - verb_form = verb_form[0] if len(verb_form) else "" - - if verb_form not in context.verb_form_map: - return token.lemma_ - - column_name = False - if isinstance(context.verb_form_map[verb_form], dict): - tense = token.morph.get("Tense") - tense = tense[0] if len(tense) else "" - person = token.morph.get("Person") - person = person[0] if len(person) else "" - - if ( - tense in context.verb_form_map[verb_form] - and person in context.verb_form_map[verb_form][tense] - ): - parameters = [token.text + "%"] - operator = "LIKE" - column_name = context.verb_form_map[verb_form][tense][person] - else: - if token_index > 0 and tokens[token_index - 1].text == "zu": - parameters = ["zu " + token.text] - prev = True - else: - parameters = [token.text] - prev = False - operator = "=" - column_name = context.verb_form_map[verb_form] - - if column_name: - table_name = context.declensions_config[LangType.DE][WordType.VERB][ - "name" - ] - query = f"SELECT base_form, {column_name} FROM {table_name} WHERE {column_name} {operator} ? LIMIT 1" - - rows = await context.db.fetch_rows(query, parameters) - if len(rows): - if operator == "LIKE": - token_index_offset = 1 - for sentence_token in token.sent: - if sentence_token.i > token.i: - token_index_offset += 1 - if ( - tokens[token_index].text - + " " - + sentence_token.text.lower() - == rows[0][1] - ): - if ( - sentence_token.text.lower() == "schwarz" - and token_index_offset > 2 - ): - sentence_token._.connected_token = token - token._.child_token = sentence_token - token._.label = sentence_token._.label = ( - tokens[token_index].text - + " .. 
" - + sentence_token.text.lower() - ) - else: - token._.token_index_offset = token_index_offset - - await context.db.fetch_declensions( - LangType.DE, WordType.VERB, rows[0][0], token - ) - - token._.form = column_name - if token_index_offset == 2: - token._.text = ( - tokens[token_index].text - + tokens[token_index].whitespace_ - + sentence_token.text - ) - return rows[0][0] - return token.lemma_ - - if prev: - token._.start = tokens[token_index - 1].idx - token._.text = ( - tokens[token_index - 1].text - + tokens[token_index - 1].whitespace_ - + tokens[token_index].text - ) - token._.form = column_name - - await context.db.fetch_declensions( - LangType.DE, WordType.VERB, rows[0][0] - ) - return rows[0][0] - - return token.lemma_ - - -async def german_gender_endings( - config: Config, - client: Client, - tokens: Doc, - offsets: dict, - language: Language, - text: str, - token_index: int, - list_full: list, -) -> int: - # shallow check to see if any of the delimeters is even contained - if not re.search("[/):_*I]", text): - return token_index - - subcategory = "d_and_i" - if is_sub_category_enabled(config.disabled_categories, subcategory): - word_types = ( - (-1, 1, config.german_gender_ending[0]) - if config.german_gender_ending.startswith("/") - else (None, None, config.german_gender_ending[0]) - ) - - endings = [ - Rule( - config.german_gender_ending + "", - LangType.DE, - config._gendereddenom_ending[config.german_gender_ending], - None, - config._gendereddenom_ending_word_type[config.german_gender_ending], - subcategory, - ), - ] - - if config.german_gender_ending in config._gendereddenom_ending_article: - endings.append( - Rule( - config.german_gender_ending + " article", - LangType.DE, - config._gendereddenom_ending_article[config.german_gender_ending], - None, - word_types, - subcategory, - ) - ) - - new_token_index = await context.regex_check.handle( - config, - client, - language, - text, - token_index, - tokens, - offsets, - list_full, - endings, - ) - - 
if check_continue( - list_full, token_index, new_token_index, tokens, "regex_match" - ): - return new_token_index - - subcategory = "gendered_denominations_ending_advanced" - if is_sub_category_enabled( - config.disabled_categories, subcategory - ) and Config.gendered_roles_format_inclusive(config.gendered_roles_format): - endings = [] - for key, regexp in config._gendereddenom_ending.items(): - if config.german_gender_ending == key: - continue - - ending = Rule( - key + "", - LangType.DE, - regexp, - None, - config._gendereddenom_ending_word_type[key], - subcategory, - [Alternative(config.german_gender_ending)], - ) - - endings.append(ending) - - if ( - # GermanGenderEndingType.SLASH_DASH is redundant to GermanGenderEndingType.SLASH - key != GermanGenderEndingType.SLASH_DASH - # only check if relevant regexp is defined - and key in config._gendereddenom_ending_article - ): - word_types = ( - (-1, 2, key[0]) if key.startswith("/") else (None, None, key[0]) - ) - - ending = Rule( - key + "article", - LangType.DE, - config._gendereddenom_ending_article[key], - None, - word_types, - subcategory, - ) - - endings.append(ending) - - new_token_index = await context.regex_check.handle( - config, - client, - language, - text, - token_index, - tokens, - offsets, - list_full, - endings, - ) - - if check_continue( - list_full, token_index, new_token_index, tokens, "regex_match" - ): - return new_token_index - - return token_index - - -async def witty_rules( - config: Config, - term_replacements: list[Rule], - client: Client, - tokens: Doc, - offsets: dict, - language: Language, - text: str, -) -> list: - false_positive_matcher = context.model.fetch_false_positive_matchers( - language.lang, tokens - ) - - list_full = [] - - new_token_index = 0 - token_count = len(tokens) - while new_token_index < token_count: - token_index = new_token_index - - token = tokens[token_index] - if token._.connected_token is not None: - new_token_index += 1 - continue - - if language.lang == 
LangType.DE: - token.lemma_ = await german_lemmatization(tokens, token_index) - - if len(term_replacements): - new_token_index = await context.rule_check.handle( - config, - client, - language, - text, - token_index, - tokens, - offsets, - list_full, - term_replacements, - ) - - if check_continue( - list_full, token_index, new_token_index, tokens, "rule_check" - ): - continue - - if is_sub_category_enabled( - config.disabled_categories, "gender_specific_abbreviation" - ): - new_token_index = await context.regex_check.handle( - config, - client, - language, - text, - token_index, - tokens, - offsets, - list_full, - context.static_rules["m_f_regexes"], - ) - - if check_continue( - list_full, token_index, new_token_index, tokens, "regex_match" - ): - continue - - new_token_index = context.emoji_check.handle( - config, - client, - language, - text, - token_index, - tokens, - offsets, - list_full, - ) - - if check_continue( - list_full, - token_index, - new_token_index, - tokens, - "detect_non_inclusive_emoji", - ): - continue - - if token.text.startswith("#"): - new_token_index = await context.regex_check.handle( - config, - client, - language, - text, - token_index, - tokens, - offsets, - list_full, - context.static_rules[language.lang]["hashtags"], - ) - - if check_continue( - list_full, token_index, new_token_index, tokens, "regex_match" - ): - continue - - valid_text = is_valid_text(language.lang, token.text) - if valid_text: - new_token_index = await context.rule_check.handle( - config, - client, - language, - text, - token_index, - tokens, - offsets, - list_full, - None, - false_positive_matcher, - ) - - if check_continue( - list_full, token_index, new_token_index, tokens, "rule_check" - ): - continue - - if language.lang == LangType.DE: - new_token_index = await context.rule_check.handle( - config, - client, - language, - text, - token_index, - tokens, - offsets, - list_full, - None, - false_positive_matcher, - True, - ) - - if check_continue( - list_full, 
def is_bullet_point(line: str) -> bool:
    """Return True when the stripped line opens a bullet or numbered list item.

    Recognizes common bullet glyphs ("-", "*", "•", ...) as well as numbered
    markers such as "1.", "2)" or "3:".
    """
    stripped = line.strip()
    if not stripped:
        return False

    # Common bullet glyphs.
    if stripped.startswith(("-", "*", "•", "‣", "⁃", "⁌", "⁍", "◘", "◦", "⦾", "⦿")):
        return True

    # Numbered list markers like "1.", "2)", "3:".
    return re.match(r"^\d+[).:]", stripped) is not None
host="0.0.0.0", diff --git a/app/middleware.py b/app/middleware.py new file mode 100644 index 000000000..edf592bae --- /dev/null +++ b/app/middleware.py @@ -0,0 +1,49 @@ +"""Middleware configuration for security headers and CORS.""" + +import secure +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + + +# Security headers configuration +# Allow Swagger UI assets from jsdelivr and inline scripts/styles used by the +# generated Swagger HTML. Keep default-src restrictive and explicitly allow +# script-src and style-src for docs to render. +csp = ( + secure.ContentSecurityPolicy() + .set("default-src 'self' cdn.jsdelivr.net") + .set("script-src 'self' 'unsafe-inline' cdn.jsdelivr.net") + .set("style-src 'self' 'unsafe-inline' cdn.jsdelivr.net") +) +hsts = secure.StrictTransportSecurity().include_subdomains().preload().max_age(31536000) +referrer = secure.ReferrerPolicy().no_referrer() +cache_value = secure.CacheControl().no_cache() +xfo = secure.XFrameOptions().deny() + +secure_headers = secure.Secure( + csp=csp, + hsts=hsts, + referrer=referrer, + cache=cache_value, + xfo=xfo, +) + + +async def add_security_headers(request, call_next): + response = await call_next(request) + await secure_headers.set_headers_async(response) + return response + + +def setup_middleware(app: FastAPI): + # Security headers middleware + app.middleware("http")(add_security_headers) + + # CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) diff --git a/app/model.py b/app/model.py index a65dc4648..09d9253d7 100644 --- a/app/model.py +++ b/app/model.py @@ -143,6 +143,9 @@ async def _fetch_word_type( single_word: bool = False, strict: bool = False, ) -> str: + """Hacky approach to fix some issues in spaCy POS detection. + It was optimized for spaCy large models for the current rule set. 
+ Fine tuning the spaCy models is probably the cleaner approach.""" # https://machinelearningknowledge.ai/tutorial-on-spacy-part-of-speech-pos-tagging/ # https://github.com/explosion/spaCy/blob/master/spacy/glossary.py diff --git a/app/models.py b/app/models.py index e333af544..788ac2c42 100644 --- a/app/models.py +++ b/app/models.py @@ -1,4 +1,4 @@ -from pydantic import field_validator, BaseModel +from pydantic import field_validator, BaseModel, Field from typing import Union, Optional, Annotated, Any from annotated_types import Len from enum import Enum @@ -27,6 +27,17 @@ class Client(BaseModel): name: Optional[str] = None version: Optional[str] = None + @classmethod + def parse(cls, version: Optional[str]) -> "Client": + if version is None: + version = "0.0.0" + + name = "web-ext" + if ":" in version: + name, version = version.split(":", 2) + + return cls(name=name, version=version) + class Language(object): translations = {} @@ -347,7 +358,7 @@ def get_article(self, gender: str, lemma: str) -> str | None: class RuleDynamic(BaseModel): alternatives: Optional[list] = None - false_positives: Optional[list[str]] = [] + false_positives: Optional[list[str]] = Field(default_factory=list) subcategory: Optional[str] = None article: Optional[Article] = None @@ -358,9 +369,9 @@ class Rule(Lemma): parent_id: Optional[int] lang: str actual_word_types: Optional[str] = None - subcategories: Optional[list[str]] = [] + subcategories: Optional[list[str]] = None is_advanced: bool = False - alternatives: Optional[list[Alternative]] = [] + alternatives: Optional[list[Alternative]] = None false_positives: Optional[list[str]] = None case_sensitive_false_positives: Optional[list[str]] = None explanation: Optional[str] = None @@ -489,15 +500,15 @@ class RuleIn(BaseModel): word_types: list | dict actual_word_types: Optional[str] = None subcategories: list[str] - alternatives: Optional[list[AlternativeIn]] = [] - false_positives: Optional[list[str]] = [] + alternatives: 
Optional[list[AlternativeIn]] = Field(default_factory=list) + false_positives: Optional[list[str]] = Field(default_factory=list) label: Optional[str] = None pattern: Optional[str] = None is_pattern_match: Optional[bool] = None type: Optional[RuleType] = RuleType.DEFAULT entity_type: Optional[EntityType] = EntityType.DEFAULT pluralization: Optional[PluralizationType] = PluralizationType.DEFAULT - lemmatizations: Optional[list[LemmatizationIn]] = [] + lemmatizations: Optional[list[LemmatizationIn]] = Field(default_factory=list) class Config(BaseModel): @@ -580,7 +591,7 @@ class Config(BaseModel): FrenchGenderSeparatorType.POINT_MEDIAN ) - disabled_categories: list = [] + disabled_categories: list = Field(default_factory=list) gendered_roles_format: GenderedRolesFormatType = GenderedRolesFormatType.BOTH show_inspiration_alternatives: bool = False alternatives_max_count: Optional[int] = None @@ -722,8 +733,8 @@ class RuleConfig(BaseModel): german_gender_ending: Optional[GermanGenderEndingConfigType] = None french_gender_separator: Optional[FrenchGenderSeparatorConfigType] = None gendered_roles_format: Optional[GenderedRolesFormatConfigType] = None - categories: Optional[dict[str, BooleanConfigType]] = {} - force_categories: Optional[list[str]] = [] + categories: Optional[dict[str, BooleanConfigType]] = Field(default_factory=dict) + force_categories: Optional[list[str]] = Field(default_factory=list) addons: Optional[list[str]] = None show_inspiration_alternatives: Optional[BooleanConfigType] = None @@ -785,8 +796,8 @@ class ConfRequest(BaseModel): name: str plan: Optional[str] = None config: RuleConfig - false_positives: list[str] = [] - term_replacements: dict[str, TermReplacement | dict] = {} + false_positives: list[str] = Field(default_factory=list) + term_replacements: dict[str, TermReplacement | dict] = Field(default_factory=dict) domains: Optional[DomainConfig] = None config_hash: Optional[str] = None sync_date: Optional[str] = None @@ -809,8 +820,8 @@ class 
ConfResponse(BaseModel): name: str plan: Optional[str] = None config: RuleConfig - false_positives: list[str] = [] - term_replacements: dict[str, TermReplacement] = {} + false_positives: list[str] = Field(default_factory=list) + term_replacements: dict[str, TermReplacement] = Field(default_factory=dict) domains: Optional[DomainConfig] = None config_hash: Optional[str] = None @@ -820,8 +831,10 @@ class UserConfResponse(ConfRequest): organization_id: Optional[str] = None organization_name: Optional[str] = None organization_config: Optional[RuleConfig] = None - organization_false_positives: Optional[list[str]] = [] - organization_term_replacements: Optional[dict[str, TermReplacement]] = {} + organization_false_positives: Optional[list[str]] = Field(default_factory=list) + organization_term_replacements: Optional[dict[str, TermReplacement]] = Field( + default_factory=dict + ) organization_domains: Optional[DomainConfig] = None organization_config_hash: Optional[str] = None organization_trial_ends_at: Optional[str] = None @@ -832,7 +845,7 @@ class UserConfResponse(ConfRequest): class BaseRequestIn(BaseModel): client: Optional[str] = None - config: Optional[Config] = Config() + config: Optional[Config] = Field(default_factory=Config) config_hash: Optional[str] = None organization_config_hash: Optional[str] = None diff --git a/app/review_prompt.py b/app/review_prompt.py index 9841d1bb0..180fcb1a5 100644 --- a/app/review_prompt.py +++ b/app/review_prompt.py @@ -9,7 +9,8 @@ def handle( review_type: ReviewType, previous_prompt: str | None = None, max_prompt_length: int | None = None, - ): + min_changes: int | None = 0, + ) -> str | None: prompt = ( 'You are an expert in inclusive language. You are tasked with editing the "previous response".' 
+ "\n" @@ -66,4 +67,7 @@ def handle( while len(json.dumps(changes)) > max_prompt_length - len(prompt): changes.pop() + if len(changes) <= min_changes: + return None + return prompt + '\nBelow is the "issues list":\n' + json.dumps(changes) diff --git a/app/routes/__init__.py b/app/routes/__init__.py new file mode 100644 index 000000000..7018b3ea8 --- /dev/null +++ b/app/routes/__init__.py @@ -0,0 +1,41 @@ +"""Route registration for all API endpoints.""" + +from app.routes import ( + utility, + auth, + config_routes, + rephrase, + prompt, + check, + debug, +) +from app.settings import get_settings + + +def register_routes(app): + # Register utility routes + app.include_router(utility.router, tags=["utility"]) + + # Register authentication routes + app.include_router(auth.router, tags=["auth"]) + + # Register configuration management routes + app.include_router(config_routes.router, tags=["config"]) + + # Register rephrase routes + app.include_router(rephrase.router, tags=["rephrase"]) + + # Register prompt routes + app.include_router(prompt.router, tags=["prompt"]) + + # Register Slack integration routes if enabled + if get_settings().slack_enabled: + from app.routes import slack as slack_module + + app.include_router(slack_module.router, tags=["slack"]) + + # Register text checking routes (core functionality) + app.include_router(check.router, tags=["check"]) + + # Register debug routes + app.include_router(debug.router, tags=["debug"]) diff --git a/app/routes/auth.py b/app/routes/auth.py new file mode 100644 index 000000000..24c885a89 --- /dev/null +++ b/app/routes/auth.py @@ -0,0 +1,106 @@ +""" +Authentication Routes +Handles authentication, authorization, and API key management. 
+""" + +from typing import Union + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.security import HTTPBearer +from fastapi.security.api_key import APIKeyHeader + +from app.context import AppContext +from app.dependencies import fetch_current_username, get_app_context +from app.models import BaseRequestIn, CheckRequestIn, ErrorMessage, ResultConf, Client +from app.version_validators import client_version +from app.config_manager import fetch_configs_for_request, fetch_result_conf +from app.auth_service import fetch_user + +router = APIRouter() + + +@router.post( + "/v2.0/auth", + response_model=Union[ResultConf, dict, None], + response_model_exclude_none=True, + dependencies=[ + Depends(HTTPBearer(auto_error=False)), + Depends(APIKeyHeader(name="x-key", auto_error=False)), + ], +) +async def post_auth_2_0( + request: Request, + check_request_in: BaseRequestIn | None = None, + context: AppContext = Depends(get_app_context), +): + client = Client.parse( + check_request_in.client if check_request_in is not None else None + ) + client_version(client) + + user_email = await fetch_user( + request, context.settings, context.redis, context.http + ) + configs = ( + await fetch_configs_for_request(CheckRequestIn(text=""), user_email, context) + if user_email + else {} + ) + + context.redis.store_metrics(request, configs, "2.0", "auth") + + if configs == {}: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + ) + + config = fetch_result_conf(configs) + + if "team_analytics" in configs and not configs["team_analytics"]: + config.organization_id = None + + return config + + +@router.get( + "/api_key", + status_code=status.HTTP_204_NO_CONTENT, + responses={404: {"model": ErrorMessage}}, +) +async def get_api_key( + api_key: str, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + email = context.redis.db.get(f"api_key:{api_key}") + if not email: + raise HTTPException( + 
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found" + ) + + return email + + +@router.post( + "/api_key", + status_code=status.HTTP_204_NO_CONTENT, +) +async def post_api_key( + api_key: str, + email: str, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + context.redis.db.set(f"api_key:{api_key}", email) + + +@router.delete( + "/api_key", + status_code=status.HTTP_204_NO_CONTENT, +) +async def delete_api_key( + api_key: str, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + context.redis.db.delete("api_key:" + api_key) diff --git a/app/routes/check.py b/app/routes/check.py new file mode 100644 index 000000000..157653ed1 --- /dev/null +++ b/app/routes/check.py @@ -0,0 +1,132 @@ +""" +Check Routes +Core text checking functionality with language rules. +""" + +from typing import Union + +from fastapi import APIRouter, Depends, Request, Response, status +from fastapi.security import HTTPBearer +from fastapi.security.api_key import APIKeyHeader + +from app.context import AppContext +from app.dependencies import get_app_context +from app.models import CheckRequestIn, Result, ResultsOut, Client +from app.version_validators import ( + client_version, + check_api_version, + CHECK_API_VERSION, +) +from app.config_manager import ( + fetch_configs_for_request, + debug_configs, + fetch_config_change, +) +from app.language_processor import fetch_text, apply_language_rules +from app.auth_service import fetch_user + +router = APIRouter() + + +@router.post( + f"/v{CHECK_API_VERSION}/check", + response_model=Union[ResultsOut, Result], + response_model_exclude_none=True, + dependencies=[ + Depends(HTTPBearer(auto_error=False)), + Depends(APIKeyHeader(name="x-key", auto_error=False)), + ], +) +async def post_check_v2_4( + request: Request, + response: Response, + check_request_in: CheckRequestIn, + context: AppContext = Depends(get_app_context), +) -> 
async def check(
    request: Request,
    response: Response,
    check_request_in: CheckRequestIn,
    version: str | None = None,
    context: AppContext | None = None,
) -> Result | ResultsOut:
    """Shared implementation behind the /check endpoints.

    When ``version`` is given, authenticates the caller and loads their
    configs; when it is None, runs in debug mode with debug_configs.
    Returns a ResultsOut envelope, or a bare Result on early failure
    (undetectable language, or a Result produced by the rule engine).
    """
    client = Client.parse(check_request_in.client)
    client_version(client)

    user_email = None
    # NOTE(review): assert is stripped under `python -O`; an explicit raise
    # would be safer — confirm how the service is launched.
    assert context is not None, "AppContext must be provided"
    if version is not None:
        check_api_version(version)

        user_email = await fetch_user(
            request, context.settings, context.redis, context.http
        )
        configs = await fetch_configs_for_request(check_request_in, user_email, context)
    else:
        # Debug path: no user lookup, synthetic configs.
        configs = debug_configs(check_request_in)

    context.redis.store_metrics(request, configs, version, "check")

    context.redis.store_request_log(
        check_request_in,
        user_email,
        request,
        configs,
        version,
        "check",
    )

    if (
        check_request_in.config.plan is not None
        and check_request_in.config.plan.startswith("witty_")
    ):
        text, language, limit_reached = fetch_text(
            check_request_in, context.langs, context
        )

        if language is None:
            # Early return: note this path skips store_response_log below.
            response.status_code = status.HTTP_422_UNPROCESSABLE_ENTITY
            return Result.factory("Language could not be determined")

        results = await apply_language_rules(
            client, check_request_in.config, configs, language, text, context
        )

        lang = language.lang

        if isinstance(results, Result):
            # Rule engine signalled an error Result; also skips response log.
            return results
    else:
        results = []
        # Default to English when plan is not witty_*; use string to align with other paths
        lang = "en"
        limit_reached = False

    notifications = None
    if "notifications" in configs and configs["notifications"] > 0:
        notifications = configs["notifications"]

    has_consented_to_mailing = None
    if "has_consented_to_mailing" in configs:
        has_consented_to_mailing = configs["has_consented_to_mailing"]

    # NOTE(review): Result instances returned early above, so this guard
    # appears always-true here — presumably defensive; confirm.
    if not isinstance(results, Result):
        results = ResultsOut(
            results=results,
            language=lang,
            limit_reached=limit_reached,
            config_changed=fetch_config_change(configs, check_request_in),
            notifications=notifications,
            has_consented_to_mailing=has_consented_to_mailing,
            gender_separator=check_request_in.config.get_gender_separator(lang),
        )

    context.redis.store_response_log(
        user_email,
        results,
    )

    return results
response_model=ConfResponse, + response_model_exclude_none=True, + responses={404: {"model": ErrorMessage}}, +) +async def get_organization_configs( + organization_id: str, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + return await context.redis.fetch_organization_configs_from_redis(organization_id) + + +# User Configuration Routes + + +@router.post( + "/user/configs", + status_code=status.HTTP_204_NO_CONTENT, +) +async def post_user_configs( + user_configs: UserConfRequest, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + user_configs.term_replacements = parse_term_replacements( + user_configs.term_replacements, context + ) + context.redis.db.set( + context.redis.get_user_id(user_configs.email), user_configs.model_dump_json() + ) + + +@router.delete( + "/user/configs", + status_code=status.HTTP_204_NO_CONTENT, +) +async def delete_user_configs( + email: str, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + context.redis.db.delete(context.redis.get_user_id(email)) + + +@router.get( + "/user/configs", + response_model=UserConfResponse, + response_model_exclude_none=True, + responses={404: {"model": ErrorMessage}}, +) +async def get_user_configs( + email: str, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + return await fetch_user_organization_configs(email, context) + + +# User Logs Route + + +@router.get( + "/user/logs", + response_class=PrettyJSONResponse, +) +async def get_user_logs( + email: str, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + return context.redis.get_user_logs(email) diff --git a/app/routes/debug.py b/app/routes/debug.py new file mode 100644 index 000000000..e159577d4 --- /dev/null +++ b/app/routes/debug.py @@ -0,0 +1,267 @@ +""" +Debug Routes +Debugging 
endpoints for testing rules, spacy analysis, and language features. +Only available in non-production environments. +""" + +from typing import Union + +from fastapi import APIRouter, Depends, Request, Response, status +from fastapi.security import HTTPBearer +from fastapi.security.api_key import APIKeyHeader +from spacy import displacy + +from app.context import AppContext +from app.dependencies import fetch_current_username, get_app_context +from app.settings import get_settings +from app.models import ( + Alternative, + Client, + CheckRequestIn, + Config, + LangType, + Result, + ResultOut, + ResultsOut, + Rule, + RuleIn, +) +from app.text_utils import german_lemmatization +from app.routes.check import check +from app.models import PrettyJSONResponse +from app.helper import utf16_offsets + +router = APIRouter() + + +@router.post( + "/debug/rule", + include_in_schema=not get_settings().is_prod, + response_model=list[ResultOut], + response_model_exclude_none=True, +) +async def post_debug_rule( + rule_data: RuleIn, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + language = context.languages[rule_data.lang] + config = Config(plan="witty_teams") + + tokens = context.model.fetch_tokens(language.lang, rule_data.text) + for token in tokens: + word_type = await context.model.fetch_word_type(language.lang, token) + for lemmatization in rule_data.lemmatizations: + if token.text.lower() == lemmatization.text.lower() and ( + word_type == lemmatization.word_type or lemmatization.word_type == "" + ): + token.lemma_ = lemmatization.text + break + + offsets = utf16_offsets(rule_data.text) + false_positive_matcher = context.model.fetch_false_positive_matchers( + language.lang, tokens + ) + + if rule_data.alternatives is not None: + alternative_list = [] + for alternative_in in rule_data.alternatives: + alternative = Alternative( + alternative_in.lemma, + context.model.tokenize(alternative_in.lemma, rule_data.lang), + 
alternative_in.word_types, + alternative_in.is_remove, + alternative_in.is_inspiration, + alternative_in.is_placeholder, + alternative_in.is_advanced, + alternative_in.is_collective_noun, + alternative_in.is_gendered_noun, + alternative_in.label, + ) + + alternative_list.append(alternative) + else: + alternative_list = None + + rule = Rule( + "test", + rule_data.lang, + rule_data.lemma, + context.model.tokenize(rule_data.lemma, rule_data.lang), + rule_data.word_types, + rule_data.subcategories, + None, + rule_data.actual_word_types, + ) + + rule.dynamic.alternatives = alternative_list + rule.pattern = rule_data.pattern + rule.is_pattern_match = rule_data.is_pattern_match + rule.false_positives = rule_data.false_positives + rule.label = rule_data.label + rule.type = rule_data.type + rule.entity_type = rule_data.entity_type + rule.pluralization = rule_data.pluralization + rule.adapt_alternatives = bool(len(alternative_list)) + + rules = [rule] + + list_full = [] + client = Client.parse("debug:" + context.version) + + token_index = 0 + token_count = len(tokens) + while token_index < token_count: + if rule_data.lang == LangType.DE: + tokens[token_index].lemma_ = await german_lemmatization( + tokens, token_index, context + ) + + for rule in rules: + await context.rule_check.handle( + config, + client, + language, + rule_data.text, + token_index, + tokens, + offsets, + list_full, + [rule], + false_positive_matcher, + ) + + token_index += 1 + + return list_full + + +@router.get( + "/debug/spacy", + include_in_schema=not get_settings().is_prod, + response_class=PrettyJSONResponse, +) +async def get_debug_spacy( + text: str, + lang: LangType, + detailed: bool = False, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + results = [] + tokens = context.model.fetch_tokens(lang, text) + + word_type_parts = [] + for token_index in range(len(tokens)): + token = tokens[token_index] + + if lang == LangType.DE: + token.lemma_ = 
await german_lemmatization(tokens, token_index, context) + word_type = await context.model.fetch_word_type(lang, token) + + # Get string value for word_type, prefix with '~' if text != lemma + word_type_str = getattr(word_type, "value", word_type) + if token.text != token.lemma_: + word_type_str = f"~{word_type_str}" + word_type_parts.append(word_type_str) + + token_info = { + "text": token.text, + "lemma": token.lemma_, + "word_type": word_type, + "is_singular": context.model.is_token_singular(lang, token), + "ner": token.ent_type_, + } + + if detailed: + token_info["start"] = token.idx + token_info["whitespace"] = token.whitespace_ + token_info["emoji_desc"] = token._.emoji_desc + token_info["is_emoji"] = token._.is_emoji + token_info["morph"] = token.morph.to_dict() + token_info["tag"] = token.tag_ + token_info["pos"] = token.pos_ + token_info["dep"] = token.dep_ + token_info["head"] = token.head.text + + dependent = None + children = [] + for a in token.ancestors: + for atok in a.children: + children.append( + {"dep": atok.dep_, "token": atok.text, "ner": atok.ent_type_} + ) + if dependent is None and atok.dep_ in ["pobj", "dobj"]: + dependent = atok.text + + token_info["dependent"] = dependent + token_info["children"] = children + + results.append(token_info) + + if detailed: + noun_chunks = [] + for chunk in tokens.noun_chunks: + noun_chunks.append( + { + "text": chunk.text, + "start": chunk.start, + "end": chunk.end, + } + ) + + results = [{"noun chunks": noun_chunks}] + results + + # Build the word type rule string + word_type_rule = "|".join(word_type_parts) + return [{"auto-detected word type": word_type_rule}] + results + + +@router.get( + "/debug/displacy", + include_in_schema=not get_settings().is_prod, +) +async def get_debug_displacy( + text: str, + lang: LangType, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + tokens = context.model.fetch_tokens(lang, text) + + sentence_spans = 
list(tokens.sents) + data = displacy.render(sentence_spans, style="dep") + return Response(content=data, media_type="image/svg+xml") + + +@router.get( + "/debug/german_noun", + include_in_schema=not get_settings().is_prod, + response_class=PrettyJSONResponse, +) +async def get_debug_german_noun( + word: str, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + return await context.nouns.german_noun_lookup(word) + + +@router.post( + "/debug/check", + response_model=Union[ResultsOut, Result], + response_model_exclude_none=True, + dependencies=[ + Depends(HTTPBearer(auto_error=False)), + Depends(APIKeyHeader(name="x-key", auto_error=False)), + ], + include_in_schema=not get_settings().is_prod, +) +async def post_debug_check( + request: Request, + response: Response, + check_request_in: CheckRequestIn, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +): + return await check(request, response, check_request_in, None, context) diff --git a/app/routes/prompt.py b/app/routes/prompt.py new file mode 100644 index 000000000..d33695116 --- /dev/null +++ b/app/routes/prompt.py @@ -0,0 +1,155 @@ +""" +Prompt Routes +Handles LLM prompt processing and review functionality. 
+""" + +from typing import Union + +from fastapi import APIRouter, Depends, Request, Response, status +from fastapi.security import HTTPBearer +from fastapi.security.api_key import APIKeyHeader + +from app.context import AppContext +from app.categories import inclusive_categories +from app.dependencies import fetch_current_username, get_app_context +from app.settings import get_settings +from app.models import CheckRequestIn, PromptOut, Result, ReviewType, Client +from app.version_validators import REPHRASE_API_VERSION +from app.config_manager import fetch_configs_for_request, debug_configs +from app.language_processor import fetch_text, apply_language_rules +from app.review_prompt import ReviewPrompt +from app.routes.check import check + +router = APIRouter() + + +@router.post( + "/debug/review_prompt", + response_model=Union[str, Result], + response_model_exclude_none=True, + dependencies=[ + Depends(HTTPBearer(auto_error=False)), + Depends(APIKeyHeader(name="x-key", auto_error=False)), + ], + include_in_schema=not get_settings().is_prod, +) +async def debug_review_prompt( + request: Request, + response: Response, + check_request_in: CheckRequestIn, + context: AppContext = Depends(get_app_context), + review_type: ReviewType = ReviewType.EXPLAIN_EDITS, +) -> Result | str: + for category in inclusive_categories: + if category in check_request_in.config.disabled_categories: + continue + + check_request_in.config.disabled_categories.append(category) + + check_result = await check(request, response, check_request_in, None, context) + if isinstance(check_result, Result): + return check_result + + result = ReviewPrompt.handle( + check_result.results, review_type, check_request_in.text, 1900 + ) + if result is None: + result = "WITTYNOCHANGES" + + return result + + +@router.post( + "/debug/prompt", + response_model=Union[Result, PromptOut, None], + response_model_exclude_none=True, + include_in_schema=not get_settings().is_prod, +) +async def debug_prompt( + request: 
Request, + response: Response, + check_request_in: CheckRequestIn, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +) -> Result | PromptOut: + configs = debug_configs(check_request_in) + return await prompt(response, check_request_in, configs, context) + + +@router.post( + f"/v{REPHRASE_API_VERSION}/prompt", + response_model=Union[Result, PromptOut, None], + response_model_exclude_none=True, +) +async def post_prompt( + request: Request, + response: Response, + check_request_in: CheckRequestIn, + user_email: str, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +) -> Result | PromptOut: + configs = await fetch_configs_for_request(check_request_in, user_email, context) + if configs == {}: + response.status_code = status.HTTP_401_UNAUTHORIZED + return Result.factory("User config missing") + + context.redis.store_metrics(request, configs, REPHRASE_API_VERSION, "prompt") + return await prompt(response, check_request_in, configs, context) + + +async def prompt( + response: Response, + check_request_in: CheckRequestIn, + configs: dict, + context: AppContext, +) -> Result | PromptOut: + if ( + check_request_in.config.plan is None + or not check_request_in.config.plan.startswith("witty_") + ): + response.status_code = status.HTTP_402_PAYMENT_REQUIRED + return Result.factory("Plan missing") + + if not check_request_in.config.llm_alternatives: + response.status_code = status.HTTP_401_UNAUTHORIZED + return Result.factory("User config disallows LLM use") + + check_request_in.text = await context.prompt.handle( + check_request_in.text, None, None, 0.4 + ) + check_request_in.text = context.prompt.parse_json(check_request_in.text) + + for category in inclusive_categories: + if category in check_request_in.config.disabled_categories: + continue + + check_request_in.config.disabled_categories.append(category) + + text, language, limit_reached = fetch_text(check_request_in, 
context.langs, context) + + if language is None: + response.status_code = status.HTTP_422_UNPROCESSABLE_ENTITY + return Result.factory("Language could not be determined") + + client = Client.parse(check_request_in.client) + check_result = await apply_language_rules( + client, check_request_in.config, configs, language, text, context + ) + + reviewed_response = None + if len(check_result) != 0: + review_prompt = ReviewPrompt.handle( + check_result, ReviewType.INCLUDE_PREVIOUS, check_request_in.text + ) + + if review_prompt is not None: + reviewed_response = await context.prompt.handle(review_prompt) + reviewed_response = context.prompt.parse_json(reviewed_response) + + return PromptOut( + initial_response=check_request_in.text, + reviewed_response=reviewed_response, + check_results=check_result, + limit_reached=limit_reached, + ) diff --git a/app/routes/rephrase.py b/app/routes/rephrase.py new file mode 100644 index 000000000..97a7dea65 --- /dev/null +++ b/app/routes/rephrase.py @@ -0,0 +1,134 @@ +""" +Rephrase Routes +Handles text rephrasing using LLM alternatives. 
+""" + +from typing import Union + +from fastapi import APIRouter, Depends, Request, Response, status +from fastapi.security import HTTPBearer +from fastapi.security.api_key import APIKeyHeader + +from app.context import AppContext +from app.dependencies import fetch_current_username, get_app_context +from app.settings import get_settings +from app.models import RephraseRequestIn, RephrasesOut, Result, Client +from app.version_validators import ( + client_version, + rephrase_api_version, + REPHRASE_API_VERSION, +) +from app.config_manager import fetch_configs_for_request +from app.auth_service import fetch_user + +router = APIRouter() + + +@router.post( + "/debug/rephrase", + response_model=Union[RephrasesOut, Result], + response_model_exclude_none=True, + dependencies=[ + Depends(HTTPBearer(auto_error=False)), + Depends(APIKeyHeader(name="x-key", auto_error=False)), + ], + include_in_schema=not get_settings().is_prod, +) +async def post_debug_rephrase( + request: Request, + response: Response, + rephrase_request_in: RephraseRequestIn, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +) -> Union[RephrasesOut, Result]: + """Debug endpoint for rephrasing text. 
+ + Args: + request: FastAPI request object + response: FastAPI response object + rephrase_request_in: Rephrase request data + username: Authenticated username + + Returns: + Rephrased alternatives or error result + """ + return await rephrase_sentence(request, response, rephrase_request_in, context) + + +@router.post( + f"/v{REPHRASE_API_VERSION}/rephrase", + response_model=Union[RephrasesOut, Result], + response_model_exclude_none=True, + dependencies=[ + Depends(HTTPBearer(auto_error=False)), + Depends(APIKeyHeader(name="x-key", auto_error=False)), + ], +) +async def post_rephrase_v1_0( + request: Request, + response: Response, + rephrase_request_in: RephraseRequestIn, + context: AppContext = Depends(get_app_context), +) -> Union[RephrasesOut, Result]: + return await rephrase_sentence( + request, response, rephrase_request_in, context, REPHRASE_API_VERSION + ) + + +async def rephrase_sentence( + request: Request, + response: Response, + rephrase_request_in: RephraseRequestIn, + context: AppContext, + version: str | None = None, +) -> Union[RephrasesOut, Result]: + client = Client.parse(rephrase_request_in.client) + client_version(client) + + if version is not None: + if rephrase_request_in.model is not None: + return Result.factory("Model can only be set in debug mode") + + rephrase_api_version(version) + + user_email = await fetch_user( + request, context.settings, context.redis, context.http + ) + if user_email is None: + response.status_code = status.HTTP_401_UNAUTHORIZED + return Result.factory("User not found") + + configs = await fetch_configs_for_request( + rephrase_request_in, user_email, context + ) + + if ( + rephrase_request_in.config.plan is None + or not rephrase_request_in.config.plan.startswith("witty_") + ): + response.status_code = status.HTTP_401_UNAUTHORIZED + return Result.factory("No valid plan on user") + + if not rephrase_request_in.config.llm_alternatives: + response.status_code = status.HTTP_403_FORBIDDEN + return 
Result.factory("Rephrasing via LLM not enabled on user") + else: + # debug mode + configs = {} + + context.redis.store_metrics(request, configs, version, "rephrase") + + try: + result = RephrasesOut.factory( + rephrase_request_in.sentence, + await context.llm_alternatives.handle(rephrase_request_in), + ) + except Exception as e: + response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + message = "An error occurred" + if not version: + message = f"{message}: {e}" + + return Result.factory(message) + + return result diff --git a/app/routes/slack.py b/app/routes/slack.py new file mode 100644 index 000000000..250c75d9c --- /dev/null +++ b/app/routes/slack.py @@ -0,0 +1,95 @@ +""" +Slack Integration Routes +Handles Slack commands and webhook processing. +""" + +from fastapi import APIRouter, Request, HTTPException, status +from slack_bolt.async_app import AsyncAck, AsyncRespond +from slack_sdk.web.async_client import AsyncWebClient as WebClient +from slack_bolt.adapter.fastapi.async_handler import AsyncSlackRequestHandler + +from app.models import CheckRequestIn, Client +from app.config_manager import ( + fetch_configs_for_request, + fetch_organization_configs_for_request, +) +from app.language_processor import fetch_text, apply_language_rules +from app.bolt import process_command_witty, get_bolt +from app.context import AppContext +from app.settings import get_settings + +router = APIRouter() + +# Slack Bolt app and handler are initialized at startup via init_slack(context) +bolt = None +bolt_handler = None + + +def init_slack(context: AppContext) -> None: + """Initialize Slack Bolt app with AppContext and register listeners.""" + global bolt, bolt_handler + bolt = get_bolt(get_settings(), context) + + @bolt.command("/witty") # type: ignore[attr-defined] + async def handle_command_witty( + body: dict, + ack: AsyncAck, + respond: AsyncRespond, + client: WebClient, + context: dict, + ): # pragma: no cover + await ack() + + app_context: AppContext | None = 
context.get("app_context") + if app_context is None: + await respond("Service not ready yet. Please try again shortly.") + return + + check_request_in = CheckRequestIn(client="slack:1.0.0", text=body["text"]) + text, language, limit_reached = fetch_text( + check_request_in, app_context.langs, app_context + ) + + if language is None: + await respond(f"Witty could not determine a language for '{text}'.") + return + + configs = {} + + try: + user = await client.users_info(user=body["user_id"]) + configs = await fetch_configs_for_request( + check_request_in, user.data["user"]["profile"]["email"], app_context + ) + except KeyError: + pass + + if configs == {} and app_context.settings.slack_organization_id: + configs = await fetch_organization_configs_for_request( + check_request_in, + app_context.settings.slack_organization_id, + app_context, + ) + + check_request_in.config.__setattr__("alternatives_max_count", None) + parsed_client = Client.parse(check_request_in.client) + results = await apply_language_rules( + parsed_client, check_request_in.config, configs, language, text, app_context + ) + + return await process_command_witty( + text, language, limit_reached, results, respond + ) + + # Create request handler after registering listeners + bolt_handler = AsyncSlackRequestHandler(bolt) + + +@router.post("/slack/commands") +async def post_slack_commands(request: Request): # pragma: no cover + if bolt_handler is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Slack not initialized", + ) + return await bolt_handler.handle(request) diff --git a/app/routes/utility.py b/app/routes/utility.py new file mode 100644 index 000000000..fa2e60805 --- /dev/null +++ b/app/routes/utility.py @@ -0,0 +1,192 @@ +"""Utility routes for basic API operations.""" + +import json +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, status, Request +from fastapi.encoders import jsonable_encoder +from fastapi.openapi.docs 
import get_swagger_ui_html +from fastapi.openapi.utils import get_openapi +from fastapi.exceptions import RequestValidationError +from starlette.responses import RedirectResponse + +from app.models import LangType +from app.dependencies import fetch_current_username, get_app_context +from app.context import AppContext +from app.settings import get_settings, Settings +from app.text_utils import parse_word_type + +router = APIRouter() + + +@router.get("/health") +async def get_health( + check_external: bool = False, context: AppContext = Depends(get_app_context) +) -> dict: + health = {} + + langs = { + LangType.EN: "Hello guys", + LangType.DE: "Hallo Kunde", + LangType.FR: "Je m'appelle Luc", + } + + for lang in context.langs: + try: + context.model.fetch_tokens(lang, langs[lang]) + health["model_" + lang] = True + except Exception: + health["model_" + lang] = False + + if check_external: + try: + languagetool_health = await context.http.fetch_json_get( + context.settings.languagetool_api + "/healthcheck", + {}, + {}, + "LanguageTool", + context.settings.languagetool_verify_ssl, + False, + ) + + health["spelling"] = languagetool_health == "OK" + except Exception: + health["spelling"] = False + + try: + health["config"] = context.redis.db.ping() + except Exception: + health["config"] = False + + content = jsonable_encoder(health) + + for key in health: + if not health[key]: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=content, + ) + + return content + + +@router.get("/lt", include_in_schema=not get_settings().is_prod) +def get_lt( + username: str = Depends(fetch_current_username), + context: AppContext = Depends(get_app_context), +) -> str: + return context.settings.languagetool_api + + +@router.get("/settings", include_in_schema=not get_settings().is_prod) +def get_app_settings( + username: str = Depends(fetch_current_username), + context: AppContext = Depends(get_app_context), +) -> Settings: + return context.settings + + 
+@router.get("/docs", include_in_schema=False) +def get_swagger_documentation( + username: str = Depends(fetch_current_username), +): # pragma: no cover + return get_swagger_ui_html(openapi_url="/openapi.json", title="docs") + + +@router.get("/openapi.json", include_in_schema=False) +def get_openapi_json(request: Request) -> dict: # pragma: no cover + app = request.app + return get_openapi(title=app.title, version=app.version, routes=app.routes) + + +@router.get("/", include_in_schema=False) +def get_root(context: AppContext = Depends(get_app_context)): + if ( + not context.settings.is_prod and context.settings.testing is False + ): # pragma: no cover + return RedirectResponse(url="/docs", status_code=302) + + return "Witty NLP API: https://witty.works" + + +@router.get( + "/save_openapi_json", + include_in_schema=not get_settings().is_prod, +) +def get_save_openapi_json( + request: Request, username: str = Depends(fetch_current_username) +): # pragma: no cover + openapi_data = request.app.openapi() + for path in openapi_data["paths"].copy(): + if "v2.0" not in path: + del openapi_data["paths"][path] + + with open("openapi.json", "w") as file: + json.dump(openapi_data, file, indent=4, sort_keys=True) + + +@router.get("/lemmatize") +async def get_lemmatize( + text: str, + lang: LangType, + all: bool = False, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +) -> str | tuple[str, ...] 
| None: + tokens = context.model.fetch_tokens(lang, text) + if all: + return tuple([i.lemma_ for i in tokens]) + + if len(tokens) != 1: + return None + + return tokens[0].lemma_ + + +@router.get("/tokenize") +async def get_tokenize( + text: str, + lang: LangType, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +) -> tuple[str, ...]: + return context.model.tokenize(text, lang) + + +@router.get("/parse-word-types") +async def get_parse_word_types( + text: str, + word_types: str, + lang: LangType, + context: AppContext = Depends(get_app_context), + username: str = Depends(fetch_current_username), +) -> list[dict[str, Any]]: + tokens = context.model.fetch_tokens(lang, text) + word_type_list = word_types.split("|") + + if len(tokens) != len(word_type_list): + raise RequestValidationError( + f"Word type '{word_types}' count does not match text token count '{len(tokens)}' for text '{text}'." + ) + + parsed_word_types = [] + for word_type in word_type_list: + parsed_word_type, lower_case, lemmatize = parse_word_type(word_type) + + if ( + parsed_word_type != "" + and parsed_word_type not in context.supported_word_types + ): + raise RequestValidationError( + f"Word type '{word_type}' within '{word_types}' contains unsupported word type '{parsed_word_type}'" + ) + + parsed_word_types.append( + { + "word_type": parsed_word_type, + "lower_case": lower_case, + "lemmatize": lemmatize, + } + ) + + return parsed_word_types diff --git a/app/rule_check.py b/app/rule_check.py index c4aef93f4..3a4652439 100644 --- a/app/rule_check.py +++ b/app/rule_check.py @@ -11,7 +11,6 @@ RuleLabelEnum, Alternative, GenderedRolesFormatType, - ResultOut, ) from app.helper import is_valid_text, is_addon_enabled, upperfirst from app.categories import ( @@ -32,6 +31,10 @@ from copy import deepcopy from spacy.tokens import Token, Doc, Span from logging import Logger +from app.rule_engine.utils import append_result +from app.rule_engine.matchers.pattern 
import is_phrase_match +from app.alternatives_engine import utils +from app.rule_engine import utils as rule_utils class RuleCheck: @@ -67,13 +70,6 @@ def __init__( self.adjectives = adjectives self.alternatives = alternatives - def is_target_noun(self, token: Token): - return ( - token.dep_.endswith("subj") - or token.dep_.endswith("obj") - or token.dep_.startswith("obl") - ) - def is_entity_type_mismatch(self, rule: Rule, token: Token): if rule.entity_type == EntityType.DEFAULT: return False @@ -147,7 +143,7 @@ async def fetch_rules( [token.lemma_], [{"word_type": "a", "lower_case": True, "lemmatize": True}], ["hidden_image"], - self.alternatives.get_adjective_alternatives_french( + utils.build_french_adjective_alternatives( male_form, female_form ), ) @@ -189,9 +185,7 @@ async def fetch_rules( [token.lemma_], [{"word_type": "a", "lower_case": True, "lemmatize": True}], ["hidden_image"], - self.alternatives.get_adjective_alternatives_french( - male_form, female_form - ), + utils.build_french_adjective_alternatives(male_form, female_form), ) rule.adapt_alternatives = True @@ -214,8 +208,8 @@ async def check_not_for_people( if rule.label_type != RuleLabelEnum.NOT_FOR_PEOPLE: return False - chunks = self.fetch_sentence_noun_chunks(tokens[token_index].sent) - token_chunk = self.find_token_chunk(chunks, token_index) + chunks = rule_utils.fetch_sentence_noun_chunks(tokens[token_index].sent) + token_chunk = rule_utils.find_token_chunk(chunks, token_index) if token_chunk is None: # No noun detected => assume false positive if len(chunks) == 0: @@ -255,19 +249,19 @@ def is_french_adjective_false_positive( source_noun = None for a in token.ancestors: - if self.is_target_noun(a): + if rule_utils.is_target_noun(a): source_noun = a break for atok in a.children: - if self.is_target_noun(atok): + if rule_utils.is_target_noun(atok): source_noun = atok break if source_noun is None: source_index = word_index = None for word in token.sent: - if not self.is_target_noun(word): + if 
not rule_utils.is_target_noun(word): continue if word.i < word.head.i: @@ -315,10 +309,8 @@ def is_french_adjective_false_positive( if false_positive_check is None: rule.dynamic.subcategory = "hidden_image" - rule.alternatives = ( - self.alternatives.get_adjective_alternatives_french( - male_form, female_form - ) + rule.alternatives = utils.build_french_adjective_alternatives( + male_form, female_form ) rule.adapt_alternatives = True else: @@ -369,7 +361,7 @@ async def is_french_noun_false_positive( or category_name in self.static_rules["male_specific_dimensions"] ): # false positive check - gender = self.get_token_gender(token) + gender = rule_utils.get_token_gender(token) if gender is None: result = await self.db.fetch_declensions( LangType.FR, WordType.NOUN, token.text, token @@ -827,10 +819,11 @@ async def generate_french_alternatives( 0 if article.lower() == "les" else article_index ) - collective_noun = self.alternatives.add_article( + collective_noun = utils.add_article( LangType.FR, collective_noun, - self.alternatives.get_article_by_index( + utils.get_article_by_index( + self.static_rules, LangType.FR, articles_list, collective_article_index, @@ -870,7 +863,7 @@ async def generate_french_alternatives( and rule.pattern.startswith("article|l") and not is_plural ): - gendered_article, _, _ = self.get_previous_article( + gendered_article, _, _ = self.alternatives.get_previous_article( token_index, tokens, LangType.FR ) gendered_article = gendered_article.lower() @@ -916,9 +909,9 @@ async def generate_french_alternatives( male_article + " " + gender_neutral_noun - + self.static_rules[LangType.FR]["noun_conjunction"][ - "singular" - ] + + utils.get_noun_conjunction( + self.static_rules, LangType.FR, True + ) + female_article + " " + gender_neutral_noun @@ -930,9 +923,9 @@ async def generate_french_alternatives( female_article + " " + gender_neutral_noun - + self.static_rules[LangType.FR]["noun_conjunction"][ - "singular" - ] + + utils.get_noun_conjunction( + 
self.static_rules, LangType.FR, True + ) + male_article + " " + gender_neutral_noun @@ -1028,7 +1021,10 @@ async def handle( skip_token = token_index + token._.token_index_offset start_token_index = token_index else: - skip_token, start_token_index, text = await self.is_phrase_match( + skip_token, start_token_index, text = await is_phrase_match( + self.model, + self.db, + self.logger, language.lang, token_index, tokens, @@ -1037,9 +1033,7 @@ async def handle( ) if self.is_german_pronoun_check_required(language.lang, token): - subcategory = self.german_pronoun_check( - config, rule, token - ) + subcategory = self.german_pronoun_check(config, rule, token) if not subcategory: continue rule.dynamic.subcategory = subcategory @@ -1131,49 +1125,45 @@ async def handle( label = token._.label if token._.label is not None else rule.label - list_full.append( - ResultOut.factory( - config, - client, - language, - text, - rule.text_id, - full_text, - offsets, - rule.dynamic.subcategory, - start, - None, - rule.alternatives, - None, - rule.explanation, - rule.url, - rule.icon, - label, - rule.source, - ) + append_result( + list_full, + config=config, + client=client, + language=language, + text=text, + text_id=rule.text_id, + full_text=full_text, + offsets=offsets, + subcategory=rule.dynamic.subcategory, + start=start, + alternatives=rule.alternatives, + label=None, + explanation=rule.explanation, + url=rule.url, + icon=rule.icon, + explanation_context=label, + source=rule.source, ) if token._.child_token: - list_full.append( - ResultOut.factory( - config, - client, - language, - token._.child_token.text, - token.lemma_, - full_text, - offsets, - rule.dynamic.subcategory, - token._.child_token.idx, - None, - rule.alternatives, - None, - rule.explanation, - rule.url, - rule.icon, - label, - rule.source, - ) + append_result( + list_full, + config=config, + client=client, + language=language, + text=token._.child_token.text, + text_id=token.lemma_, + full_text=full_text, + 
offsets=offsets, + subcategory=rule.dynamic.subcategory, + start=token._.child_token.idx, + alternatives=rule.alternatives, + label=None, + explanation=rule.explanation, + url=rule.url, + icon=rule.icon, + explanation_context=label, + source=rule.source, ) return skip_token @@ -1372,31 +1362,9 @@ async def check_person_noun( return skip # TODO cache on the sentence? - def fetch_sentence_noun_chunks(self, sent: Span) -> list[Span]: - chunks = [] - for chunk in sent.noun_chunks: - chunks.append(chunk) - - return chunks - - def find_token_chunk(self, chunks: list[Span], token_index: int): - for chunk in chunks: - if chunk.start <= token_index < chunk.end: - return chunk - if chunk.start > token_index: - break - - return None - - def get_token_gender(self, token: Token): - gender = token.morph.get("Gender") - if len(gender): - return gender[0] - - return None def is_gender_false_positive(self, token: Token) -> bool: - gender = self.get_token_gender(token) + gender = rule_utils.get_token_gender(token) if gender is None: return False @@ -1409,7 +1377,7 @@ def is_gender_false_positive(self, token: Token) -> bool: if lemma != token.lemma_.lower(): continue - gender = self.get_token_gender(token) + gender = rule_utils.get_token_gender(token) if not male_form_found and gender == "Masc": male_form_found = True elif not female_form_found and gender == "Fem": @@ -1439,213 +1407,3 @@ async def is_rule_false_positive( ) return result - - async def check_pattern( - self, - lang: LangType, - tokens: Doc, - pattern: list, - i_pattern_start: int, - offset: int, - ) -> bool | int: - count = 0 - for word_type in pattern: - allow_skip = word_type.endswith("*") - if i_pattern_start < 0: - return False - - if i_pattern_start >= len(tokens): - if allow_skip: - continue - - return False - - if allow_skip: - word_type = word_type.removesuffix("*") - while i_pattern_start >= 0 and await self.model.check_word_type( - lang, tokens[i_pattern_start], word_type, True, True - ): - i_pattern_start 
-= 1 - count += 1 - if ( - lang == LangType.FR - and i_pattern_start >= 0 - and tokens[i_pattern_start + 1].lemma_ == "la" - and tokens[i_pattern_start].lemma_ in ["à", "de"] - ): - i_pattern_start -= 1 - count += 1 - elif await self.model.check_word_type( - lang, tokens[i_pattern_start], word_type, True - ): - if ( - lang == LangType.FR - and i_pattern_start >= 1 - and tokens[i_pattern_start].lemma_ == "la" - and tokens[i_pattern_start - 1].lemma_ in ["à", "de"] - ): - i_pattern_start += 1 - count += 1 - - i_pattern_start += offset - count += 1 - else: - return False - - return count - - async def is_word_match( - self, - lang: LangType, - token: Token, - word: str, - word_type: dict | None, - suffix: str, - lemma: str | None = None, - ) -> bool: - if word_type is None: - word_type = { - "word_type": "", - "lemmatize": True, - "lower_case": True, - } - - lemma_ = token.lemma_ if lemma is None else lemma - token_word = lemma_ if word_type["lemmatize"] else token.text - - # ignore differences between ’ and ' - token_word = token_word.replace("’", "'") - word = word.replace("’", "'") - - if word_type["lower_case"]: - token_word = token_word.lower() - word = word.lower() - - if token_word != word and ( - not suffix or not token_word.lower().endswith(word.lower()) - ): - if ( - lemma is None - and word_type["lemmatize"] - and lemma_ in self.db.male_to_female_normativ - ): - return await self.is_word_match( - lang, - token, - word, - word_type, - suffix, - self.db.male_to_female_normativ[lemma_], - ) - return False - - return await self.model.check_word_type( - lang, token, word_type["word_type"], True - ) - - async def is_phrase_match( - self, - lang: LangType, - token_index: int, - tokens: Doc, - rule: Rule, - false_positive_matcher: list | None = None, - ) -> tuple[int | None, str | None]: - suffix = rule.type == RuleType.SUFFIX - - word_count = len(rule.words) - word_types_count = len(rule.word_types) - if word_count > 1: - suffix = False - - skip_token_index = 
token_index - text = "" - for word_index in range(word_count): - if word_index > 0: - text += word_token.whitespace_ - - try: - word_token = tokens[token_index + word_index] - except IndexError: - return None, None, None - - word_type = ( - rule.word_types[word_index] if word_index < word_types_count else None - ) - - if not await self.is_word_match( - lang, - word_token, - rule.words[word_index], - word_type, - suffix, - ): - return None, None, None - - text += word_token.text - - skip_token_index += 1 - - if false_positive_matcher is not None and self.model.is_false_positive_match( - false_positive_matcher, token_index, tokens, rule.lemma - ): - return None, None, None - - start_token_index = token_index - if rule.pattern is not None: - pattern = rule.pattern.split("|") - if pattern[0] == "*" or pattern[-1] == "*": - self.logger.error( - "Rule pattern may not start or end with '*' but is '%s', rule id %i, idx: '%s'", - rule.pattern, - rule.id, - tokens[token_index].idx, - ) - - return None, None, None - - token_count = word_count - prefix_tokens_match_count = 0 - lemma_position = pattern.index("l") - - if lemma_position > 0: - prefix_pattern = pattern[0:lemma_position] - prefix_pattern.reverse() - tokens_match_count = await self.check_pattern( - lang, tokens, prefix_pattern, token_index - 1, -1 - ) - if tokens_match_count is False: - return None, None, None - - prefix_tokens_match_count += tokens_match_count - - suffix_pattern = pattern[lemma_position + 1 :] - if len(suffix_pattern): - tokens_match_count = await self.check_pattern( - lang, tokens, suffix_pattern, token_index + token_count, 1 - ) - if tokens_match_count is False: - return None, None, None - - token_count += tokens_match_count - - if rule.is_pattern_match: - start_token_index -= prefix_tokens_match_count - text = "" - for k in range(prefix_tokens_match_count + token_count): - if k > 0: - text += tokens[start_token_index + k - 1].whitespace_ - - text += tokens[start_token_index + k].text - - 
skip_token_index = start_token_index + token_count + 1 - - if ( - start_token_index + tokens[start_token_index]._.token_index_offset - > skip_token_index - ): - skip_token_index = ( - start_token_index + tokens[start_token_index]._.token_index_offset - ) - - return skip_token_index, start_token_index, text diff --git a/app/rule_engine/matchers/pattern.py b/app/rule_engine/matchers/pattern.py new file mode 100644 index 000000000..c695564e3 --- /dev/null +++ b/app/rule_engine/matchers/pattern.py @@ -0,0 +1,229 @@ +"""Pattern and phrase matching helpers for RuleCheck. + +These helpers encapsulate async matching logic and depend on the provided +model, db, and logger instead of a specific owning class. +""" + +from typing import Tuple +from spacy.tokens import Doc, Token +from logging import Logger + +from app.models import LangType, Rule, RuleType + + +async def check_pattern( + model, + lang: LangType, + tokens: Doc, + pattern: list, + i_pattern_start: int, + offset: int, +) -> bool | int: + count = 0 + for word_type in pattern: + allow_skip = word_type.endswith("*") + if i_pattern_start < 0: + return False + + if i_pattern_start >= len(tokens): + if allow_skip: + continue + + return False + + if allow_skip: + word_type = word_type.removesuffix("*") + while i_pattern_start >= 0 and await model.check_word_type( + lang, tokens[i_pattern_start], word_type, True, True + ): + i_pattern_start -= 1 + count += 1 + if ( + lang == LangType.FR + and i_pattern_start >= 0 + and tokens[i_pattern_start + 1].lemma_ == "la" + and tokens[i_pattern_start].lemma_ in ["à", "de"] + ): + i_pattern_start -= 1 + count += 1 + elif await model.check_word_type( + lang, tokens[i_pattern_start], word_type, True + ): + if ( + lang == LangType.FR + and i_pattern_start >= 1 + and tokens[i_pattern_start].lemma_ == "la" + and tokens[i_pattern_start - 1].lemma_ in ["à", "de"] + ): + i_pattern_start += 1 + count += 1 + + i_pattern_start += offset + count += 1 + else: + return False + + return count + + 
+async def is_word_match( + model, + db, + lang: LangType, + token: Token, + word: str, + word_type: dict | None, + suffix: str, + lemma: str | None = None, +) -> bool: + if word_type is None: + word_type = { + "word_type": "", + "lemmatize": True, + "lower_case": True, + } + + lemma_ = token.lemma_ if lemma is None else lemma + token_word = lemma_ if word_type["lemmatize"] else token.text + + # ignore differences between ’ and ' + token_word = token_word.replace("’", "'") + word = word.replace("’", "'") + + if word_type["lower_case"]: + token_word = token_word.lower() + word = word.lower() + + if token_word != word and ( + not suffix or not token_word.lower().endswith(word.lower()) + ): + if ( + lemma is None + and word_type["lemmatize"] + and lemma_ in db.male_to_female_normativ + ): + return await is_word_match( + model, + db, + lang, + token, + word, + word_type, + suffix, + db.male_to_female_normativ[lemma_], + ) + return False + + return await model.check_word_type(lang, token, word_type["word_type"], True) + + +async def is_phrase_match( + model, + db, + logger: Logger, + lang: LangType, + token_index: int, + tokens: Doc, + rule: Rule, + false_positive_matcher: list | None = None, +) -> Tuple[int | None, int | None, str | None]: + suffix = rule.type == RuleType.SUFFIX + + word_count = len(rule.words) + word_types_count = len(rule.word_types) + if word_count > 1: + suffix = False + + skip_token_index = token_index + text = "" + for word_index in range(word_count): + if word_index > 0: + text += word_token.whitespace_ + + try: + word_token = tokens[token_index + word_index] + except IndexError: + return None, None, None + + word_type = ( + rule.word_types[word_index] if word_index < word_types_count else None + ) + + if not await is_word_match( + model, + db, + lang, + word_token, + rule.words[word_index], + word_type, + suffix, + ): + return None, None, None + + text += word_token.text + + skip_token_index += 1 + + if false_positive_matcher is not None and 
model.is_false_positive_match( + false_positive_matcher, token_index, tokens, rule.lemma + ): + return None, None, None + + start_token_index = token_index + if rule.pattern is not None: + pattern = rule.pattern.split("|") + if pattern[0] == "*" or pattern[-1] == "*": + logger.error( + "Rule pattern may not start or end with '*' but is '%s', rule id %i, idx: '%s'", + rule.pattern, + rule.id, + tokens[token_index].idx, + ) + + return None, None, None + + token_count = word_count + prefix_tokens_match_count = 0 + lemma_position = pattern.index("l") + + if lemma_position > 0: + prefix_pattern = pattern[0:lemma_position] + prefix_pattern.reverse() + tokens_match_count = await check_pattern( + model, lang, tokens, prefix_pattern, token_index - 1, -1 + ) + if tokens_match_count is False: + return None, None, None + + prefix_tokens_match_count += tokens_match_count + + suffix_pattern = pattern[lemma_position + 1 :] + if len(suffix_pattern): + tokens_match_count = await check_pattern( + model, lang, tokens, suffix_pattern, token_index + token_count, 1 + ) + if tokens_match_count is False: + return None, None, None + + token_count += tokens_match_count + + if rule.is_pattern_match: + start_token_index -= prefix_tokens_match_count + text = "" + for k in range(prefix_tokens_match_count + token_count): + if k > 0: + text += tokens[start_token_index + k - 1].whitespace_ + + text += tokens[start_token_index + k].text + + skip_token_index = start_token_index + token_count + 1 + + if ( + start_token_index + tokens[start_token_index]._.token_index_offset + > skip_token_index + ): + skip_token_index = ( + start_token_index + tokens[start_token_index]._.token_index_offset + ) + + return skip_token_index, start_token_index, text diff --git a/app/rule_engine/utils.py b/app/rule_engine/utils.py new file mode 100644 index 000000000..b673a4229 --- /dev/null +++ b/app/rule_engine/utils.py @@ -0,0 +1,84 @@ +"""General rule engine utilities (token helpers, chunking, result helpers). 
+ +These helpers are pure and do not depend on class state. +""" + +from typing import List, Optional +from spacy.tokens import Span, Token +from app.models import ( + ResultOut, + Config, + Client, + Language, + Alternative, + ResultSource, +) + + +def fetch_sentence_noun_chunks(sent: Span) -> List[Span]: + return [chunk for chunk in sent.noun_chunks] + + +def find_token_chunk(chunks: List[Span], token_index: int) -> Optional[Span]: + for chunk in chunks: + if chunk.start <= token_index < chunk.end: + return chunk + if chunk.start > token_index: + break + return None + + +def get_token_gender(token: Token) -> Optional[str]: + gender = token.morph.get("Gender") + if gender: + return gender[0] + return None + + +def is_target_noun(token: Token) -> bool: + dep = token.dep_ + return dep.endswith("subj") or dep.endswith("obj") or dep.startswith("obl") + + +def append_result( + out_list: list, + *, + config: Config, + client: Client, + language: Language, + text: str, + text_id: str, + full_text: str, + offsets: dict, + subcategory: str, + start: int, + end: Optional[int] = None, + alternatives: Optional[list[Alternative]] = None, + label: Optional[str] = None, + explanation: Optional[str] = None, + url: Optional[str] = None, + icon: Optional[str] = None, + explanation_context: Optional[str] = None, + source: Optional[ResultSource] = None, +) -> None: + out_list.append( + ResultOut.factory( + config, + client, + language, + text, + text_id, + full_text, + offsets, + subcategory, + start, + end, + alternatives, + label, + explanation, + url, + icon, + explanation_context, + source, + ) + ) diff --git a/app/rule_processors.py b/app/rule_processors.py new file mode 100644 index 000000000..714860cf4 --- /dev/null +++ b/app/rule_processors.py @@ -0,0 +1,445 @@ +""" +Rule Processing Functions +Core functions for processing Witty language rules. 
+""" + +import re +from inspect import currentframe + +from spacy.tokens import Doc + +from app.context import AppContext +from app.models import ( + Alternative, + Client, + Config, + GermanGenderEndingType, + Language, + LangType, + ResultOut, + Rule, +) +from app.categories import is_sub_category_enabled +from app.helper import is_valid_text +from app.text_utils import german_lemmatization + + +def check_continue( + list_full: list, + token_index: int, + new_token_index: int, + tokens: Doc, + func_name: str, + context: AppContext, +) -> bool: + if new_token_index == token_index: + return False + + if new_token_index < token_index: + cf = currentframe() + + text_id = list_full[-1].text_id if len(list_full) else "" + + context.logger.error( + "Incorrect new_token_index on line %i using '%s': expected %i < %i for '%s' versus '%s' for text_id '%s'", + cf.f_back.f_lineno, + func_name, + token_index, + new_token_index, + tokens[token_index].text, + tokens[new_token_index].text, + text_id, + ) + + return False + + return True + + +def is_bullet_point(line: str) -> bool: + line = line.strip() + if not line: + return False + + # Check for bullet characters + bullet_chars = {"-", "*", "•", "‣", "⁃", "⁌", "⁍", "◘", "◦", "⦾", "⦿"} + if line[0] in bullet_chars: + return True + + # Check for numbered list (e.g., "1.", "2)", "3:") + return bool(re.match(r"^\d+[).:]", line)) + + +async def german_gender_endings( + config: Config, + client: Client, + tokens: Doc, + offsets: dict, + language: Language, + text: str, + token_index: int, + list_full: list, + context: AppContext, +) -> int: + # shallow check to see if any of the delimiters is even contained + if not re.search("[/):_*I]", text): + return token_index + + subcategory = "d_and_i" + if is_sub_category_enabled(config.disabled_categories, subcategory): + word_types = ( + (-1, 1, config.german_gender_ending[0]) + if config.german_gender_ending.startswith("/") + else (None, None, config.german_gender_ending[0]) + ) + + 
endings = [ + Rule( + config.german_gender_ending + "", + LangType.DE, + config._gendereddenom_ending[config.german_gender_ending], + None, + config._gendereddenom_ending_word_type[config.german_gender_ending], + subcategory, + ), + ] + + if config.german_gender_ending in config._gendereddenom_ending_article: + endings.append( + Rule( + config.german_gender_ending + " article", + LangType.DE, + config._gendereddenom_ending_article[config.german_gender_ending], + None, + word_types, + subcategory, + ) + ) + + new_token_index = await context.regex_check.handle( + config, + client, + language, + text, + token_index, + tokens, + offsets, + list_full, + endings, + ) + + if check_continue( + list_full, token_index, new_token_index, tokens, "regex_match", context + ): + return new_token_index + + subcategory = "gendered_denominations_ending_advanced" + if is_sub_category_enabled( + config.disabled_categories, subcategory + ) and Config.gendered_roles_format_inclusive(config.gendered_roles_format): + endings = [] + for key, regexp in config._gendereddenom_ending.items(): + if config.german_gender_ending == key: + continue + + ending = Rule( + key + "", + LangType.DE, + regexp, + None, + config._gendereddenom_ending_word_type[key], + subcategory, + [Alternative(config.german_gender_ending)], + ) + + endings.append(ending) + + if ( + # GermanGenderEndingType.SLASH_DASH is redundant to GermanGenderEndingType.SLASH + key != GermanGenderEndingType.SLASH_DASH + # only check if relevant regexp is defined + and key in config._gendereddenom_ending_article + ): + word_types = ( + (-1, 2, key[0]) if key.startswith("/") else (None, None, key[0]) + ) + + ending = Rule( + key + "article", + LangType.DE, + config._gendereddenom_ending_article[key], + None, + word_types, + subcategory, + ) + + endings.append(ending) + + new_token_index = await context.regex_check.handle( + config, + client, + language, + text, + token_index, + tokens, + offsets, + list_full, + endings, + ) + + if 
check_continue( + list_full, token_index, new_token_index, tokens, "regex_match", context + ): + return new_token_index + + return token_index + + +async def witty_rules( + config: Config, + term_replacements: list[Rule], + client: Client, + tokens: Doc, + offsets: dict, + language: Language, + text: str, + context: AppContext, +) -> list: + false_positive_matcher = context.model.fetch_false_positive_matchers( + language.lang, tokens + ) + + list_full = [] + + new_token_index = 0 + token_count = len(tokens) + while new_token_index < token_count: + token_index = new_token_index + + token = tokens[token_index] + if token._.connected_token is not None: + new_token_index += 1 + continue + + if language.lang == LangType.DE: + token.lemma_ = await german_lemmatization(tokens, token_index, context) + + if len(term_replacements): + new_token_index = await context.rule_check.handle( + config, + client, + language, + text, + token_index, + tokens, + offsets, + list_full, + term_replacements, + ) + + if check_continue( + list_full, token_index, new_token_index, tokens, "rule_check", context + ): + continue + + if is_sub_category_enabled( + config.disabled_categories, "gender_specific_abbreviation" + ): + new_token_index = await context.regex_check.handle( + config, + client, + language, + text, + token_index, + tokens, + offsets, + list_full, + context.static_rules["m_f_regexes"], + ) + + if check_continue( + list_full, token_index, new_token_index, tokens, "regex_match", context + ): + continue + + new_token_index = context.emoji_check.handle( + config, + client, + language, + text, + token_index, + tokens, + offsets, + list_full, + ) + + if check_continue( + list_full, + token_index, + new_token_index, + tokens, + "detect_non_inclusive_emoji", + context, + ): + continue + + if token.text.startswith("#"): + new_token_index = await context.regex_check.handle( + config, + client, + language, + text, + token_index, + tokens, + offsets, + list_full, + 
context.static_rules[language.lang]["hashtags"], + ) + + if check_continue( + list_full, token_index, new_token_index, tokens, "regex_match", context + ): + continue + + valid_text = is_valid_text(language.lang, token.text) + if valid_text: + new_token_index = await context.rule_check.handle( + config, + client, + language, + text, + token_index, + tokens, + offsets, + list_full, + None, + false_positive_matcher, + ) + + if check_continue( + list_full, token_index, new_token_index, tokens, "rule_check", context + ): + continue + + if language.lang == LangType.DE: + new_token_index = await context.rule_check.handle( + config, + client, + language, + text, + token_index, + tokens, + offsets, + list_full, + None, + false_positive_matcher, + True, + ) + + if check_continue( + list_full, + token_index, + new_token_index, + tokens, + "rule_check", + context, + ): + continue + + if language.lang == LangType.DE: + new_token_index = await german_gender_endings( + config, + client, + tokens, + offsets, + language, + text, + token_index, + list_full, + context, + ) + + if check_continue( + list_full, token_index, new_token_index, tokens, "rule_check", context + ): + continue + + if len(token.text) > 18 and is_sub_category_enabled( + config.disabled_categories, "plain_language" + ): + subwords = ( + token.text.replace("/", "-") + .replace("@", "-") + .replace(":", "-") + .replace(".", "-") + .replace("_", "-") + .split("-") + ) + highlight = len(subwords) == 1 + for subword in subwords: + if len(subword) > 12: + highlight = True + break + + if highlight: + list_full.append( + ResultOut.factory( + config, + client, + language, + token.text, + token.text, + text, + offsets, + "plain_language", + token.idx, + explanation=language.translate("TOO_LONG_WORD"), + ) + ) + + new_token_index += 1 + + if is_sub_category_enabled(config.disabled_categories, "plain_language"): + sentence_word_limit = 30 + + for sent in tokens.sents: + if len(sent) <= sentence_word_limit: + continue + + 
sentences = {} + sentence_parts = [] + start = sent[0].idx + + lines = sent.text.split("\n") + for line in lines: + if is_bullet_point(line): + sentence = "\n".join(sentence_parts) + if sentence and sentence.count(" ") > sentence_word_limit: + sentences[start] = sentence + + start += len(sentence) + 1 + sentence_parts = [line] + else: + sentence_parts.append(line) + + sentence = "\n".join(sentence_parts) + if sentence and sentence.count(" ") >= sentence_word_limit: + sentences[start] = sentence + + for start in sentences: + list_full.append( + ResultOut.factory( + config, + client, + language, + sentences[start], + sentences[start], + text, + offsets, + "plain_language", + start, + explanation=language.translate("TOO_LONG_SENTENCE"), + ) + ) + + return list_full diff --git a/app/settings.py b/app/settings.py index 58acf692e..43d8e3b06 100644 --- a/app/settings.py +++ b/app/settings.py @@ -1,4 +1,5 @@ from typing import Optional +from functools import lru_cache import json import base64 from pydantic_settings import BaseSettings, SettingsConfigDict @@ -39,7 +40,7 @@ class Settings(BaseSettings): office_sso_client_id: Optional[str] = "" office_sso_expected_scope: Optional[str] = "" - sso_configs: dict = {} + sso_configs: dict[str, dict[str, Optional[str]]] = {} redis_host: Optional[str] = "" redis_port: Optional[str] = "" @@ -51,11 +52,12 @@ class Settings(BaseSettings): testing_email: Optional[str] = "" testing_rules: Optional[str] = "" testing_organization_rules: Optional[str] = "" + slack_enabled: bool = False slack_signing_secret: Optional[str] = "" slack_bot_token: Optional[str] = "" slack_organization_id: Optional[str] = "" alternatives_max_count: int = 5 - context_checker: dict = {} + context_checker: dict[str, dict[str, str]] = {} context_checker_url: Optional[str] = "" context_checker_api_key: Optional[str] = "" context_checker_url_de: Optional[str] = "" @@ -69,15 +71,15 @@ class Settings(BaseSettings): ] minimum_version_web_ext: Optional[str] = "" 
minimum_version_word_plugin: Optional[str] = "" - minimum_versions: dict = {} + minimum_versions: dict[str, str] = {} model_config = SettingsConfigDict(env_file=".env", extra="ignore") import_from_dump: bool = True log_missing_declension: bool = True + log_metrics: Optional[bool] = False aws_region_name: Optional[str] = "" aws_key: Optional[str] = "" aws_secret_key: Optional[str] = "" aws_model_id: Optional[str] = "mistral.mixtral-8x7b-instruct-v0:1" - log_metrics: Optional[bool] = False @staticmethod def factory(): @@ -154,3 +156,16 @@ def factory(): settings.redis_verify_ssl = False return settings + + +@lru_cache(maxsize=1) +def get_settings() -> Settings: + return Settings.factory() + + +def reset_settings_cache() -> None: + """Clear the cached settings instance (primarily for tests).""" + try: + get_settings.cache_clear() # type: ignore[attr-defined] + except Exception: + pass diff --git a/app/startup.py b/app/startup.py new file mode 100644 index 000000000..a9f57b948 --- /dev/null +++ b/app/startup.py @@ -0,0 +1,141 @@ +"""Application startup and lifecycle management.""" + +import os +import json +import logging +from contextlib import asynccontextmanager + +from fastapi import FastAPI + +from app.context import AppContext +from app.db import Db +from app.http import Http +from app.nouns import Nouns +from app.verbs import Verbs +from app.adjectives import Adjectives +from app.alternatives import Alternatives +from app.languagetool import LanguageTool +from app.prompt import Prompt +from app.llm_alternatives import LlmAlternatives +from app.rule_check import RuleCheck +from app.regex_check import RegexCheck +from app.emoji_check import EmojiCheck +from app.config_manager import parse_term_replacements + + +@asynccontextmanager +async def lifespan(app: FastAPI, context: AppContext): + # Startup + # Expose context on FastAPI app.state for DI-friendly access + app.state.context = context + if os.environ.get("BLACKFIRE_ENABLE_CONTINUOUS_PROFILING"): + try: + from 
blackfire_conprof.profiler import Profiler + + app_name = os.environ.get("PLATFORM_APPLICATION_NAME") + profiler = Profiler(application_name=app_name) + profiler.start() + + print(f"Profiler started for {app_name}") + except Exception: + pass + + # Initialize HTTP client + context.http = Http(context.settings, context.logger) + + # Configure SQLite logging + sqlite_logger = logging.getLogger("aiosqlite") + sqlite_logger.setLevel(logging.WARNING) + + # Initialize database + context.db = await Db.factory(context.settings, context.languages) + context.model.db = context.db + + # Load testing rules if configured + if context.settings.testing_rules: + rules = json.loads(context.settings.testing_rules) + + rules["term_replacements"] = parse_term_replacements( + rules["term_replacements"], context + ) + email = rules["email"] + context.redis.db.set(context.redis.get_user_id(email), json.dumps(rules)) + + if context.settings.testing_organization_rules: + organization_rules = json.loads(context.settings.testing_organization_rules) + + organization_rules["term_replacements"] = parse_term_replacements( + organization_rules["term_replacements"], context + ) + key = organization_rules["id"] + context.redis.db.set(key, json.dumps(organization_rules)) + + if context.settings.redis_log_emails: + log_emails = json.loads(context.settings.redis_log_emails) + for email in log_emails: + context.redis.db.lpush("debug_emails", email) + + # Initialize language processors + context.nouns = Nouns( + context.settings, + context.logger, + context.static_rules, + context.model, + context.db, + ) + context.verbs = Verbs( + context.settings, + context.logger, + context.static_rules, + context.model, + context.db, + ) + context.adjectives = Adjectives(context.settings, context.logger, context.db) + context.alternatives = Alternatives( + context.settings, + context.logger, + context.static_rules, + context.db, + context.model, + context.nouns, + context.verbs, + context.adjectives, + ) + 
context.languagetool = LanguageTool( + context.settings, + context.logger, + context.static_rules, + context.db, + context.categories, + context.http, + ) + context.prompt = Prompt(context.settings) + context.llm_alternatives = LlmAlternatives( + context.settings, context.alternatives, context.prompt + ) + context.rule_check = RuleCheck( + context.settings, + context.logger, + context.static_rules, + context.model, + context.db, + context.nouns, + context.verbs, + context.adjectives, + context.alternatives, + ) + context.regex_check = RegexCheck( + context.settings, context.logger, context.static_rules, context.nouns + ) + context.emoji_check = EmojiCheck( + context.settings, context.logger, context.static_rules + ) + + yield + + # Shutdown + await context.http.close() + await context.db.close() + # Clean up reference on app.state + if hasattr(app.state, "context"): + delattr(app.state, "context") diff --git a/app/text_utils.py b/app/text_utils.py new file mode 100644 index 000000000..ed537af91 --- /dev/null +++ b/app/text_utils.py @@ -0,0 +1,150 @@ +"""Text and word type parsing utilities.""" + +from spacy.tokens import Doc + +from app.context import AppContext +from app.models import LangType, WordType +from app.helper import remove_gender_ending + + +def parse_word_type(word_type: str, lower_case: bool = True) -> tuple[str, bool, bool]: + lemmatize = True + + if not word_type: + return "", lower_case, lemmatize + + match word_type[0]: + case "~": + # exact match + lower_case = True + lemmatize = False + word_type = word_type[1:] + case "=": + # exact match + lower_case = False + lemmatize = False + word_type = word_type[1:] + case "-": + # force lower case off + lower_case = False + lemmatize = True + word_type = word_type[1:] + + return word_type, lower_case, lemmatize + + +async def german_lemmatization( + tokens: Doc, token_index: int, context: AppContext +) -> str: + token = tokens[token_index] + word_type = await context.model.fetch_word_type(LangType.DE, 
token) + + match word_type: + case WordType.NOUN: + if ( + token.text != token.lemma_ + or not token.text[0].isupper() + or len(token.text) <= 3 + ): + return token.lemma_ + + word = remove_gender_ending(token.text) + + result = await context.nouns.german_noun_lookup(word, token) + if result is not None: + target = "male_form" if result["male_form"] else "base_form" + return result[target] + + case WordType.VERB: + verb_form = token.morph.get("VerbForm") + verb_form = verb_form[0] if len(verb_form) else "" + + if verb_form not in context.verb_form_map: + return token.lemma_ + + column_name = False + if isinstance(context.verb_form_map[verb_form], dict): + tense = token.morph.get("Tense") + tense = tense[0] if len(tense) else "" + person = token.morph.get("Person") + person = person[0] if len(person) else "" + + if ( + tense in context.verb_form_map[verb_form] + and person in context.verb_form_map[verb_form][tense] + ): + parameters = [token.text + "%"] + operator = "LIKE" + column_name = context.verb_form_map[verb_form][tense][person] + else: + if token_index > 0 and tokens[token_index - 1].text == "zu": + parameters = ["zu " + token.text] + prev = True + else: + parameters = [token.text] + prev = False + operator = "=" + column_name = context.verb_form_map[verb_form] + + if column_name: + table_name = context.declensions_config[LangType.DE][WordType.VERB][ + "name" + ] + query = f"SELECT base_form, {column_name} FROM {table_name} WHERE {column_name} {operator} ? 
LIMIT 1" + + rows = await context.db.fetch_rows(query, parameters) + if len(rows): + if operator == "LIKE": + token_index_offset = 1 + for sentence_token in token.sent: + if sentence_token.i > token.i: + token_index_offset += 1 + if ( + tokens[token_index].text + + " " + + sentence_token.text.lower() + == rows[0][1] + ): + if ( + sentence_token.text.lower() == "schwarz" + and token_index_offset > 2 + ): + sentence_token._.connected_token = token + token._.child_token = sentence_token + token._.label = sentence_token._.label = ( + tokens[token_index].text + + " .. " + + sentence_token.text.lower() + ) + else: + token._.token_index_offset = token_index_offset + + await context.db.fetch_declensions( + LangType.DE, WordType.VERB, rows[0][0], token + ) + + token._.form = column_name + if token_index_offset == 2: + token._.text = ( + tokens[token_index].text + + tokens[token_index].whitespace_ + + sentence_token.text + ) + return rows[0][0] + return token.lemma_ + + if prev: + token._.start = tokens[token_index - 1].idx + token._.text = ( + tokens[token_index - 1].text + + tokens[token_index - 1].whitespace_ + + tokens[token_index].text + ) + token._.form = column_name + + await context.db.fetch_declensions( + LangType.DE, WordType.VERB, rows[0][0] + ) + return rows[0][0] + + return token.lemma_ diff --git a/app/version_validators.py b/app/version_validators.py new file mode 100644 index 000000000..36acfedbf --- /dev/null +++ b/app/version_validators.py @@ -0,0 +1,52 @@ +"""Validation functions for API versions and client versions. + +Also exposes typed version constants. 
+""" + +from fastapi import HTTPException, status +from cmp_version import VersionString +from typing import Final, Literal + +from app.models import Client + + +# Typed API version constants +CHECK_API_VERSION: Final[Literal["2.4"]] = "2.4" +REPHRASE_API_VERSION: Final[Literal["1.0"]] = "1.0" + + +def rephrase_api_version(version: str) -> None: + if version != REPHRASE_API_VERSION: # pragma: no cover + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"API version '{version}' not supported, please use version '{REPHRASE_API_VERSION}'." + ), + ) + + +def check_api_version(version: str) -> None: + if version != CHECK_API_VERSION: # pragma: no cover + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"API version '{version}' not supported, please use version '{CHECK_API_VERSION}'." + ), + ) + + +def client_version( + client: Client, minimum_versions: dict[str, str] | None = None +) -> None: + if not minimum_versions: + return + + if client.name in minimum_versions and VersionString( + client.version or "0.0.0" + ) < VersionString( + minimum_versions[client.name] + ): # pragma: no cover + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Client version '{client.version}' not supported, please use at least '{minimum_versions[client.name]}'.", + ) diff --git a/tests/test_api.py b/tests/test_api.py index 4f93d96fe..b0e4f6453 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -6,6 +6,8 @@ from app.main import ( app, context, +) +from app.config_manager import ( fetch_configs_for_request, parse_term_replacements, ) @@ -538,7 +540,7 @@ def set_redis(): } user_object["term_replacements"] = parse_term_replacements( - user_object["term_replacements"] + user_object["term_replacements"], context ) context.redis.db.set( context.redis.get_user_id(user_object["email"]), json.dumps(user_object) @@ -740,7 +742,7 @@ def set_redis(): } user_object["term_replacements"] = parse_term_replacements( - 
user_object["term_replacements"] + user_object["term_replacements"], context ) context.redis.db.set( context.redis.get_user_id(user_object["email"]), json.dumps(user_object) @@ -820,7 +822,7 @@ def set_redis(): } organization_object["term_replacements"] = parse_term_replacements( - organization_object["term_replacements"] + organization_object["term_replacements"], context ) context.redis.db.set(organization_object["id"], json.dumps(organization_object)) @@ -1143,7 +1145,7 @@ def test_fetch_configs_for_request(event_loop, set_redis): } test_request = CheckRequestIn(**request_data) event_loop.run_until_complete( - fetch_configs_for_request(test_request, "test@gmail.com") + fetch_configs_for_request(test_request, "test@gmail.com", context) ) assert hasattr(test_request.config, "store_context") assert test_request.config.store_context is True @@ -1170,7 +1172,7 @@ def test_fetch_user_rules_suggestion(event_loop, set_redis): } test_request = CheckRequestIn(**request_data) event_loop.run_until_complete( - fetch_configs_for_request(test_request, "non_existant@gmail.com") + fetch_configs_for_request(test_request, "non_existant@gmail.com", context) ) assert test_request.config.store_context is True assert test_request.config.llm_alternatives is False @@ -1190,7 +1192,7 @@ def test_set_organization_rules(event_loop, set_redis): } test_request = CheckRequestIn(**request_data) event_loop.run_until_complete( - fetch_configs_for_request(test_request, "test@gmail.com") + fetch_configs_for_request(test_request, "test@gmail.com", context) ) assert test_request.config.store_context is True assert test_request.config.llm_alternatives is True @@ -1209,7 +1211,7 @@ def test_set_default_rules(event_loop): test_request = CheckRequestIn(**request_data) event_loop.run_until_complete( - fetch_configs_for_request(test_request, "non_existant@gmail.com") + fetch_configs_for_request(test_request, "non_existant@gmail.com", context) ) assert test_request.config.store_context is True assert 
test_request.config.llm_alternatives is False diff --git a/tests/test_slack.py b/tests/test_slack.py new file mode 100644 index 000000000..ee3cb1d15 --- /dev/null +++ b/tests/test_slack.py @@ -0,0 +1,91 @@ +import time +import hmac +import hashlib +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +def slack_signature(secret: str, body: str, ts: str | None = None) -> tuple[str, str]: + ts = ts or str(int(time.time())) + base = f"v0:{ts}:{body}".encode() + sig = "v0=" + hmac.new(secret.encode(), base, hashlib.sha256).hexdigest() + return sig, ts + + +def test_register_routes_without_slack(monkeypatch): + # Import locally to allow monkeypatching settings getter + import app.routes as routes_pkg + + class FakeSettings: + slack_enabled = False + + monkeypatch.setattr(routes_pkg, "get_settings", lambda: FakeSettings()) + + app = FastAPI() + routes_pkg.register_routes(app) + + client = TestClient(app) + # Slack commands endpoint should not be present + r = client.post("/slack/commands") + assert r.status_code in (404, 405) + + +def test_register_routes_with_slack_but_not_initialized(monkeypatch): + # Import locally to allow monkey-patching settings getter + import app.routes as routes_pkg + + class FakeSettings: + slack_enabled = True + + monkeypatch.setattr(routes_pkg, "get_settings", lambda: FakeSettings()) + + app = FastAPI() + routes_pkg.register_routes(app) + + client = TestClient(app) + # Router is included, but Bolt handler is not initialized -> 503 + r = client.post("/slack/commands") + assert r.status_code == 503 + + +@pytest.mark.asyncio +async def test_process_command_witty_blocks_smoke(monkeypatch): + from app.bolt import process_command_witty + from app.models import Language, ResultOut, ResultExplanation + + # Minimal objects to simulate one result + language = Language(locale="en-US", translations={}) + result = ResultOut( + text="term", + text_id="term", + category="general", + subcategory="hidden_image", + start=0, 
+ end=4, + alternatives=[], + label="Label", + explanation=ResultExplanation(text="Explain", url=None, icon=None), + proficiency_level="basic", + ) + + class DummyRespond: + def __init__(self): + self.calls = [] + + async def __call__(self, *, blocks): # Slack SDK passes named arg + self.calls.append(blocks) + + respond = DummyRespond() + + await process_command_witty( + text="Hello", + language=language, + limit_reached=False, + results=[result], + respond=respond, + ) + + assert len(respond.calls) == 1 + # Basic sanity: at least two sections (header + details) + assert len(respond.calls[0]) >= 2 From 5f940309f3939150a9775dc7dfce306fefc70c7c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 10 Nov 2025 15:00:49 +0000 Subject: [PATCH 2/3] Initial plan From 4f7ee55954a874aaca96533fe7181d64af1f5c78 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 10 Nov 2025 15:05:11 +0000 Subject: [PATCH 3/3] Fix mixed implicit and explicit returns in handle_command_witty Co-authored-by: lsmith77 <300279+lsmith77@users.noreply.github.com> --- app/routes/slack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/routes/slack.py b/app/routes/slack.py index 250c75d9c..6fa57fe62 100644 --- a/app/routes/slack.py +++ b/app/routes/slack.py @@ -43,7 +43,7 @@ async def handle_command_witty( app_context: AppContext | None = context.get("app_context") if app_context is None: await respond("Service not ready yet. Please try again shortly.") - return + return None check_request_in = CheckRequestIn(client="slack:1.0.0", text=body["text"]) text, language, limit_reached = fetch_text( @@ -52,7 +52,7 @@ async def handle_command_witty( if language is None: await respond(f"Witty could not determine a language for '{text}'.") - return + return None configs = {}