Commit 32a2dc3

Fix merge conflicts with latest develop.
dafajon committed Apr 15, 2021
2 parents 90aadc3 + 2e3e504 commit 32a2dc3
Showing 39 changed files with 1,233 additions and 147 deletions.
2 changes: 1 addition & 1 deletion prod.requirements.txt
@@ -1,7 +1,7 @@
 loguru>=0.5.1
 click>=7.1.2

-smart-open>=2.1.0
+smart-open==2.0.0

 uvicorn>=0.11.8
 fastapi>=0.61.0
48 changes: 23 additions & 25 deletions sadedegel/bblock/doc.py
@@ -173,9 +173,9 @@ def raw_tf(self, drop_stopwords=False, lowercase=False, drop_suffix=False, drop_
         v = np.zeros(self.vocabulary.size_cs)

         if lowercase:
-            tokens = [tr_lower(t) for t in self.tokens]
+            tokens = [t.lower_ for t in self.tokens]
         else:
-            tokens = self.tokens
+            tokens = [t.word for t in self.tokens]

         counter = Counter(tokens)

@@ -304,7 +304,6 @@ def __init__(self, id_: int, text: str, doc, config: dict = {}):
         self.id = id_
         self.text = text

-        self._tokens = None
         self.document = doc
         self.config = doc.builder.config
         self._bert = None
@@ -332,7 +331,7 @@ def __init__(self, id_: int, text: str, doc, config: dict = {}):
f"Unknown term frequency method {self.tf_method}. Choose on of {','.join(TF_METHOD_VALUES)}")

@property
def avgdl(self) -> int:
def avgdl(self) -> float:
"""Average number of tokens per sentence"""
return self.config['default'].getfloat('avg_sentence_length')

@@ -361,17 +360,17 @@ def input_ids(self):
         return self.tokenizer.convert_tokens_to_ids(self.tokens_with_special_symbols)

     @cached_property
-    def tokens(self):
-        return self.tokenizer(self.text)
+    def tokens(self) -> List[Token]:
+        return [t for t in self.tokenizer(self.text)]

     @property
     def tokens_with_special_symbols(self):
-        return ['[CLS]'] + self.tokens + ['[SEP]']
+        return [Token('[CLS]')] + self.tokens + [Token('[SEP]')]

-    def rouge1(self, metric):
+    def rouge1(self, metric) -> float:
         return rouge1_score(
-            flatten([[tr_lower(token) for token in sent.tokens] for sent in self.document if sent.id != self.id]),
-            [tr_lower(t) for t in self.tokens], metric)
+            flatten([[t.lower_ for t in sent] for sent in self.document if sent.id != self.id]),
+            [t.lower_ for t in self], metric)

     @property
     def bm25(self) -> np.float32:
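
A note on the rouge1 change above: ROUGE-1 measures unigram overlap between a candidate and a reference. A minimal sketch over pre-lowercased token lists (a hypothetical helper for illustration, not sadedegel's actual rouge1_score):

from collections import Counter

def rouge1_sketch(candidate, reference, metric="f1"):
    # Clipped unigram overlap: each candidate token counts at most as
    # often as it appears in the reference.
    overlap = sum((Counter(candidate) & Counter(reference)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(candidate)
    recall = overlap / len(reference)
    if metric == "precision":
        return precision
    if metric == "recall":
        return recall
    return 2 * precision * recall / (precision + recall)  # f1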
@@ -418,8 +417,7 @@ def tf(self):
     def idf(self):
         v = np.zeros(len(self.vocabulary))

-        for token in self.tokens:
-            t = self.vocabulary[token]
+        for t in self.tokens:
             if not t.is_oov:
                 v[t.id] = t.idf

@@ -438,11 +436,10 @@ def __eq__(self, s: str):
         return self.text == s  # no need for type checking, will return false for non-strings

     def __getitem__(self, token_ix):
-        return Token(self.tokens[token_ix])
+        return self.tokens[token_ix]

     def __iter__(self):
-        for t in self.tokens:
-            yield Token(t)
+        yield from self.tokens


 class Document(TFImpl, IDFImpl, BM25Impl):
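
The __iter__ rewrite above swaps an explicit loop for generator delegation; since tokens now already holds Token objects, the two forms below yield identical sequences:

def iter_explicit(items):
    for t in items:
        yield t

def iter_delegated(items):
    yield from items  # equivalent, with one less loop to maintain

assert list(iter_explicit([1, 2, 3])) == list(iter_delegated([1, 2, 3]))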
@@ -459,19 +456,18 @@ def __init__(self, raw, builder):
         self.config = self.builder.config

     @property
-    def avgdl(self) -> int:
+    def avgdl(self) -> float:
         """Average number of tokens per document"""
         return self.config['default'].getfloat('avg_document_length')

-    @property
-    def tokens(self):
-        if self._tokens is None:
-            self._tokens = []
-            for s in self:
-                for t in s.tokens:
-                    self._tokens.append(t)
+    @cached_property
+    def tokens(self) -> List[str]:
+        tokens = []
+        for s in self:
+            for t in s.tokens:
+                tokens.append(t)

-        return self._tokens
+        return tokens

     @property
     def vocabulary(self):
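
Document.tokens above trades the hand-rolled self._tokens cache for cached_property, which computes the value on first access and stores it in the instance's __dict__. A minimal sketch of the pattern, using the third-party cached_property package whose import appears in the token.py diff below (the class here is hypothetical):

from cached_property import cached_property

class Report:
    @cached_property
    def tokens(self):
        print("computed once")
        return ["gibi", "bir"]

r = Report()
r.tokens  # prints "computed once" and caches the list
r.tokens  # served from the instance cache; no recomputation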
@@ -606,7 +602,9 @@ def __init__(self, **kwargs):

         tokenizer_str = normalize_tokenizer_name(self.config['default']['tokenizer'])

-        self.tokenizer = WordTokenizer.factory(tokenizer_str)
+        self.tokenizer = WordTokenizer.factory(tokenizer_str, emoji=self.config['tokenizer'].getboolean('emoji'),
+                                               hashtag=self.config['tokenizer'].getboolean('hashtag'),
+                                               mention=self.config['tokenizer'].getboolean('mention'))

         Token.set_vocabulary(self.tokenizer.vocabulary)

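
The factory call above forwards tokenizer feature flags read with configparser's getboolean, which accepts yes/no, true/false, on/off, and 1/0. A standalone illustration (the section and key names mirror the diff; the values are made up):

import configparser

config = configparser.ConfigParser()
config.read_string("""
[tokenizer]
emoji = yes
hashtag = false
mention = on
""")

assert config['tokenizer'].getboolean('emoji') is True
assert config['tokenizer'].getboolean('hashtag') is False
assert config['tokenizer'].getboolean('mention') is True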
File renamed without changes.
28 changes: 20 additions & 8 deletions sadedegel/bblock/token.py
@@ -2,6 +2,7 @@
 from math import log

 import numpy as np
+from cached_property import cached_property

 from .util import tr_lower, load_stopwords, deprecate, ConfigNotSet, VocabularyIsNotSet, WordVectorNotFound
 from .vocabulary import Vocabulary
@@ -25,13 +26,7 @@ def get_idf(self, method=IDF_SMOOTH, drop_stopwords=False, lowercase=False, drop
         else:
             v = np.zeros(self.vocabulary.size_cs)

-        if lowercase:
-            tokens = [tr_lower(t) for t in self.tokens]
-        else:
-            tokens = self.tokens
-
-        for token in tokens:
-            t = Token(token)
+        for t in self.tokens:
             if t.is_oov or (drop_stopwords and t.is_stopword) or (drop_suffix and t.is_suffix) or (
                     drop_punct and t.is_punct):
                 continue
@@ -106,7 +101,9 @@ def _create_token(cls, word: str):
         token.is_punct = all(unicodedata.category(c).startswith("P") for c in token.word)
         token.is_digit = token.word.isdigit()
         token.is_suffix = token.word.startswith('##')
-        token.shape = word_shape(token.word)
+        token.is_emoji = False
+        token.is_hashtag = False
+        token.is_mention = False

         return token

@@ -117,6 +114,17 @@ def __new__(cls, word: str):

         return cls.cache[word]

+    def __len__(self):
+        return len(self.word)
+
+    def __eq__(self, other):
+        if type(other) == str:
+            return self.word == other
+        elif type(other) == Token:
+            return self.word == other.word
+        else:
+            raise TypeError(f"Unknown comparison type with Token {type(other)}")
+
     @classmethod
     def set_vocabulary(cls, vocab: Vocabulary):
         Token.vocabulary = vocab
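
For context on the new __eq__: Token.__new__ (partially visible above) interns instances in a class-level cache, so equal words share one object. A stripped-down sketch of that interning pattern plus the equality added by this commit (a hypothetical class, not the full sadedegel Token):

class InternedToken:
    cache = {}

    def __new__(cls, word):
        if word not in cls.cache:
            obj = super().__new__(cls)
            obj.word = word
            cls.cache[word] = obj
        return cls.cache[word]

    def __eq__(self, other):
        if type(other) == str:
            return self.word == other
        elif type(other) == InternedToken:
            return self.word == other.word
        raise TypeError(f"Unknown comparison type {type(other)}")

    # Caveat: defining __eq__ sets __hash__ to None, making instances
    # unhashable unless __hash__ is restored (the real class may do this
    # elsewhere).
    __hash__ = object.__hash__

assert InternedToken("ve") is InternedToken("ve")  # interned singleton
assert InternedToken("ve") == "ve"                 # compares by surface form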
Expand Down Expand Up @@ -236,6 +244,10 @@ def vector(self) -> np.ndarray:
         else:
             raise WordVectorNotFound(self.word)

+    @cached_property
+    def shape(self) -> str:
+        return word_shape(self.word)
+
     def __str__(self):
         return self.word
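
shape above becomes a lazily computed, cached property instead of an attribute set eagerly in _create_token. Word shape maps characters to coarse classes; one common convention is sketched below (illustrative only; sadedegel's word_shape may compress runs or differ in details):

def word_shape_sketch(text: str) -> str:
    # X for uppercase, x for lowercase, d for digits; everything else kept.
    out = []
    for ch in text:
        if ch.isupper():
            out.append("X")
        elif ch.islower():
            out.append("x")
        elif ch.isdigit():
            out.append("d")
        else:
            out.append(ch)
    return "".join(out)

assert word_shape_sketch("Ankara06") == "Xxxxxxdd"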

12 changes: 8 additions & 4 deletions sadedegel/bblock/util.py
@@ -20,7 +20,10 @@


 def tr_lower(s: str) -> str:
-    return s.replace("I", "ı").replace("İ", "i").lower()
+    if "I" in s or "İ" in s:
+        return s.replace("I", "ı").replace("İ", "i").lower()
+    else:
+        return s.lower()


 def tr_upper(s: str) -> str:
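
The rewrite above adds a fast path for strings without either uppercase I variant. Turkish casing is why a bare str.lower is wrong here: dotless I lowercases to ı and dotted İ lowercases to i. Given the function as defined above:

assert tr_lower("DİYARBAKIR") == "diyarbakır"  # İ -> i, I -> ı
assert tr_lower("ISPARTA") == "ısparta"        # not "isparta"
assert tr_lower("ankara") == "ankara"          # fast path: plain lower()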
@@ -161,14 +164,15 @@ def load_stopwords(base_path=None):
     return stopwords


-def deprecate(message: str, eol_version: tuple):
+def deprecate(message: str, eol_version: tuple, post_message: str = None):
     current = tuple([int(v) for v in __version__.split('.')])

     if current >= eol_version:
-        console.print(f"[red]{message}[/red]")
+        console.print(f"[red]{message}[/red]. {post_message}")
         sys.exit(1)
     else:
-        console.print(f"[magenta]{message}[/magenta], will be dropped by {'.'.join(map(str, eol_version))}")
+        console.print(
+            f"{message}, will be [magenta]dropped[/magenta] by {'.'.join(map(str, eol_version))}. {post_message}")


 class ConfigNotSet(Exception):
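
deprecate compares the running version against eol_version via lexicographic tuple comparison, which orders parsed versions correctly even across multi-digit components (the version string below is hypothetical):

current = tuple(int(v) for v in "0.19.3".split('.'))

assert current == (0, 19, 3)
assert current >= (0, 18)      # element-wise comparison, then by length
assert (0, 9, 9) < (0, 10, 0)  # ints avoid the string trap where "0.9" > "0.10"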
(Diffs for the remaining changed files are not shown.)
