Skip to content

Commit

Permalink
Merge pull request #10 from a-belov/patch-3
Browse files Browse the repository at this point in the history
Update predictor.py (compatibility with Python 3.11)
  • Loading branch information
IlyaGusev authored May 22, 2024
2 parents 87d1cd6 + d50abed commit 74f5cf2
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions rnnmorph/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,29 @@
from rnnmorph.settings import MODELS_PATHS


def pymorphy2_hotfix():
    """
    Monkey-patch pymorphy2 so that it can be used on Python 3.11+.

    Replaces ``BaseAnalyzerUnit._get_param_names`` with an implementation
    based on ``inspect.getfullargspec``. The replacement returns the sorted
    names of the ``__init__`` arguments of an analyzer unit (without the
    leading ``self``), or an empty list when the class does not define its
    own ``__init__``.

    The patch is applied to the class object itself, so it affects every
    analyzer unit in the current process. Calling this function more than
    once is harmless: the attribute is simply overwritten with an
    equivalent implementation.

    :return: None.
    """
    # src: https://github.com/pymorphy2/pymorphy2/issues/160#issuecomment-1486657176
    # Imported locally so that pymorphy2 is only required when the hotfix
    # is actually applied.
    from inspect import getfullargspec
    from pymorphy2.units.base import BaseAnalyzerUnit

    def _get_param_names_fixed(klass):
        # A class that inherits object.__init__ has no parameters to report.
        if klass.__init__ is object.__init__:
            return []
        args = getfullargspec(klass.__init__).args
        # args[0] is ``self``; the remaining names are the unit parameters.
        return sorted(args[1:])

    # Install as a classmethod, matching how pymorphy2 declares the method.
    # As a plain function it would receive the *instance* as ``klass``: the
    # ``object.__init__`` guard above could then never fire, and calling the
    # method on the class itself would fail with a missing-argument error.
    setattr(BaseAnalyzerUnit, '_get_param_names', classmethod(_get_param_names_fixed))


class Predictor:
"""
Интерфейс POS-теггера.
"""

def predict(self, words: List[str], include_all_forms: bool) -> List[WordFormOut]:
"""
Предсказать теги для одного предложения.
:param words: массив слов (знаки препинания - отдельные токены).
:param include_all_forms: флаг, включающий все варианты разбора.
:return: массив форм с леммами, тегами и оригинальными словами.
Expand All @@ -35,7 +50,7 @@ def predict_sentences(self, sentences: List[List[str]], batch_size: int,
include_all_forms: bool) -> List[List[WordFormOut]]:
"""
Предсказать теги для массива предложений. В сетку как batch загружается.
:param sentences: массив предложений.
:param batch_size: размер батча.
:param include_all_forms: флаг, включающий все варианты разбора.
Expand All @@ -48,15 +63,16 @@ class RNNMorphPredictor(Predictor):
"""
POS-теггер на основе RNN.
"""

def __init__(self,
language="ru",
eval_model_config_path: str=None,
eval_model_weights_path: str=None,
gram_dict_input: str=None,
gram_dict_output: str=None,
word_vocabulary: str=None,
char_set_path: str=None,
build_config: str=None):
eval_model_config_path: str = None,
eval_model_weights_path: str = None,
gram_dict_input: str = None,
gram_dict_output: str = None,
word_vocabulary: str = None,
char_set_path: str = None,
build_config: str = None):
if eval_model_config_path is None:
eval_model_config_path = MODELS_PATHS[language]["eval_model_config"]
if eval_model_weights_path is None:
Expand All @@ -74,7 +90,11 @@ def __init__(self,

self.language = language
self.converter = converters.converter('opencorpora-int', 'ud14') if language == "ru" else None
self.morph = MorphAnalyzer() if language == "ru" else None
if language == "ru":
pymorphy2_hotfix()
self.morph = MorphAnalyzer()
else:
None
if self.language == "en":
nltk.download("wordnet")
nltk.download('averaged_perceptron_tagger')
Expand All @@ -87,12 +107,12 @@ def __init__(self,
self.model.prepare(gram_dict_input, gram_dict_output, word_vocabulary, char_set_path)
self.model.load_eval(self.build_config, eval_model_config_path, eval_model_weights_path)

def predict(self, words: List[str], include_all_forms: bool=False) -> List[WordFormOut]:
def predict(self, words: List[str], include_all_forms: bool = False) -> List[WordFormOut]:
words_probabilities = self.model.predict_probabilities([words], 1, self.build_config)[0]
return self.__get_sentence_forms(words, words_probabilities, include_all_forms)

def predict_sentences(self, sentences: List[List[str]], batch_size: int=64,
include_all_forms: bool=False) -> List[List[WordFormOut]]:
def predict_sentences(self, sentences: List[List[str]], batch_size: int = 64,
include_all_forms: bool = False) -> List[List[WordFormOut]]:
sentences_probabilities = self.model.predict_probabilities(sentences, batch_size, self.build_config)
answers = []
for words, words_probabilities in zip(sentences, sentences_probabilities):
Expand All @@ -103,7 +123,7 @@ def __get_sentence_forms(self, words: List[str], words_probabilities: List[List[
include_all_forms: bool) -> List[WordFormOut]:
"""
Получить теги и формы.
:param words: слова.
:param words_probabilities: вероятности тегов слов.
:param include_all_forms: флаг, включающий все варианты разбора.
Expand Down Expand Up @@ -153,10 +173,10 @@ def __compose_out_form(self, word: str, probabilities: List[float],
return result_form

def __get_lemma(self, word: str, pos_tag: str, gram: str, word_forms=None,
enable_normalization: bool=True):
enable_normalization: bool = True):
"""
Получить лемму.
:param word: слово.
:param pos_tag: часть речи.
:param gram: грамматическое значение.
Expand Down Expand Up @@ -201,7 +221,7 @@ def __get_lemma(self, word: str, pos_tag: str, gram: str, word_forms=None,
def __normalize_for_gikrya(form):
"""
Получение леммы по правилам, максимально близким к тем, которые в корпусе ГИКРЯ.
:param form: форма из pymorphy2.
:return: лемма.
"""
Expand Down

0 comments on commit 74f5cf2

Please sign in to comment.