-
Notifications
You must be signed in to change notification settings - Fork 293
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Correcting imports, making mean isotonic support decreasing.
- Loading branch information
1 parent
c7f2010
commit 1a666b8
Showing
26 changed files
with
301 additions
and
285 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,113 +1,111 @@ | ||
import numpy as np | ||
import sys | ||
|
||
from scattertext import TermCategoryFrequencies, CorpusDF | ||
from scattertext.Corpus import Corpus | ||
from scattertext.ParsedCorpus import ParsedCorpus | ||
|
||
|
||
class CorpusShouldBeParsedCorpusException(Exception): | ||
pass | ||
pass | ||
|
||
|
||
class DocsAndLabelsFromCorpus: | ||
def __init__(self, corpus, alternative_text_field=None): | ||
''' | ||
Parameters | ||
---------- | ||
corpus, Corpus: Corpus to extract documents and labels from | ||
alternative_text_field, str or None: if str, corpus must be parsed corpus | ||
''' | ||
#assert (isinstance(corpus, (Corpus, ParsedCorpus, CorpusDF, TermCategoryFrequencies)) | ||
# or (issubclass(type(corpus), (Corpus, ParsedCorpus, CorpusDF, TermCategoryFrequencies)))) | ||
self._texts_to_display = None | ||
if alternative_text_field is not None: | ||
if not isinstance(corpus, ParsedCorpus): | ||
raise CorpusShouldBeParsedCorpusException( | ||
'Corpus type needs to be ParsedCorpus to use the alternative text field.') | ||
self._texts_to_display = corpus.get_field(alternative_text_field) | ||
self._use_non_text_features = False | ||
self._use_terms_for_extra_features = False | ||
self._corpus = corpus | ||
|
||
def use_non_text_features(self): | ||
self._use_non_text_features = True | ||
return self | ||
|
||
def use_terms_for_extra_features(self): | ||
self._use_terms_for_extra_features = True | ||
return self | ||
|
||
def get_labels_and_texts(self): | ||
# type: () -> dict | ||
texts = self._get_texts_to_display() | ||
to_ret = {'categories': self._corpus.get_categories(), | ||
'labels': self._corpus.get_doc_indices(), | ||
'texts': self._get_list_from_texts(texts)} | ||
if self._use_non_text_features: | ||
to_ret['extra'] = self._corpus.list_extra_features(not self._use_terms_for_extra_features) | ||
return to_ret | ||
|
||
def _get_list_from_texts(self, texts): | ||
if sys.version_info[0] == 2: | ||
return texts.astype(unicode).tolist() | ||
else: | ||
return texts.astype(str).tolist() | ||
|
||
def _get_texts_to_display(self): | ||
if self._there_are_no_alternative_texts_to_display(): | ||
return self._corpus.get_texts() | ||
else: | ||
return self._texts_to_display | ||
|
||
def _there_are_no_alternative_texts_to_display(self): | ||
return self._texts_to_display is None | ||
|
||
def get_labels_and_texts_and_meta(self, metadata): | ||
# type: (np.array) -> dict | ||
data = self.get_labels_and_texts() | ||
assert len(metadata) == len(data['labels']) | ||
data['meta'] = list(metadata) | ||
return data | ||
def __init__(self, corpus, alternative_text_field=None): | ||
''' | ||
Parameters | ||
---------- | ||
corpus, Corpus: Corpus to extract documents and labels from | ||
alternative_text_field, str or None: if str, corpus must be parsed corpus | ||
''' | ||
# assert (isinstance(corpus, (Corpus, ParsedCorpus, CorpusDF, TermCategoryFrequencies)) | ||
# or (issubclass(type(corpus), (Corpus, ParsedCorpus, CorpusDF, TermCategoryFrequencies)))) | ||
self._texts_to_display = None | ||
if alternative_text_field is not None: | ||
if not isinstance(corpus, ParsedCorpus): | ||
raise CorpusShouldBeParsedCorpusException( | ||
'Corpus type needs to be ParsedCorpus to use the alternative text field.') | ||
self._texts_to_display = corpus.get_field(alternative_text_field) | ||
self._use_non_text_features = False | ||
self._use_terms_for_extra_features = False | ||
self._corpus = corpus | ||
|
||
def use_non_text_features(self): | ||
self._use_non_text_features = True | ||
return self | ||
|
||
def use_terms_for_extra_features(self): | ||
self._use_terms_for_extra_features = True | ||
return self | ||
|
||
def get_labels_and_texts(self): | ||
# type: () -> dict | ||
texts = self._get_texts_to_display() | ||
to_ret = {'categories': self._corpus.get_categories(), | ||
'labels': self._corpus.get_doc_indices(), | ||
'texts': self._get_list_from_texts(texts)} | ||
if self._use_non_text_features: | ||
to_ret['extra'] = self._corpus.list_extra_features(not self._use_terms_for_extra_features) | ||
return to_ret | ||
|
||
def _get_list_from_texts(self, texts): | ||
if sys.version_info[0] == 2: | ||
return texts.astype(unicode).tolist() | ||
else: | ||
return texts.astype(str).tolist() | ||
|
||
def _get_texts_to_display(self): | ||
if self._there_are_no_alternative_texts_to_display(): | ||
return self._corpus.get_texts() | ||
else: | ||
return self._texts_to_display | ||
|
||
def _there_are_no_alternative_texts_to_display(self): | ||
return self._texts_to_display is None | ||
|
||
def get_labels_and_texts_and_meta(self, metadata): | ||
# type: (np.array) -> dict | ||
data = self.get_labels_and_texts() | ||
assert len(metadata) == len(data['labels']) | ||
data['meta'] = list(metadata) | ||
return data | ||
|
||
|
||
class DocsAndLabelsFromCorpusSample(DocsAndLabelsFromCorpus): | ||
def __init__(self, corpus, max_per_category, alternative_text_field=None, seed=None): | ||
DocsAndLabelsFromCorpus.__init__(self, corpus, alternative_text_field) | ||
self.max_per_category = max_per_category | ||
if seed is not None: | ||
np.random.seed(seed) | ||
|
||
def get_labels_and_texts(self, metadata=None): | ||
''' | ||
Parameters | ||
---------- | ||
metadata : (array like or None) | ||
Returns | ||
------- | ||
{'labels':[], 'texts': []} or {'labels':[], 'texts': [], 'meta': []} | ||
''' | ||
to_ret = {'categories': self._corpus.get_categories(), 'labels': [], 'texts': []} | ||
labels = self._corpus._y.astype(int) | ||
texts = self._get_texts_to_display() | ||
if self._use_non_text_features: | ||
to_ret['extra'] = [] | ||
extrafeats = self._corpus.list_extra_features() | ||
if metadata is not None: | ||
to_ret['meta'] = [] | ||
for label_i in range(len(self._corpus._category_idx_store)): | ||
label_indices = np.arange(0, len(labels))[labels == label_i] | ||
if self.max_per_category < len(label_indices): | ||
label_indices = np.random.choice(label_indices, self.max_per_category, replace=False) | ||
to_ret['labels'] += list([int(e) for e in labels[label_indices]]) | ||
to_ret['texts'] += list(texts[label_indices]) | ||
if metadata is not None: | ||
to_ret['meta'] += [metadata[i] for i in label_indices] | ||
if self._use_non_text_features: | ||
to_ret['extra'] += [extrafeats[i] for i in label_indices] | ||
|
||
return to_ret | ||
|
||
def get_labels_and_texts_and_meta(self, metadata): | ||
return self.get_labels_and_texts(metadata) | ||
def __init__(self, corpus, max_per_category, alternative_text_field=None, seed=None): | ||
DocsAndLabelsFromCorpus.__init__(self, corpus, alternative_text_field) | ||
self.max_per_category = max_per_category | ||
if seed is not None: | ||
np.random.seed(seed) | ||
|
||
def get_labels_and_texts(self, metadata=None): | ||
''' | ||
Parameters | ||
---------- | ||
metadata : (array like or None) | ||
Returns | ||
------- | ||
{'labels':[], 'texts': []} or {'labels':[], 'texts': [], 'meta': []} | ||
''' | ||
to_ret = {'categories': self._corpus.get_categories(), 'labels': [], 'texts': []} | ||
labels = self._corpus._y.astype(int) | ||
texts = self._get_texts_to_display() | ||
if self._use_non_text_features: | ||
to_ret['extra'] = [] | ||
extrafeats = self._corpus.list_extra_features() | ||
if metadata is not None: | ||
to_ret['meta'] = [] | ||
for label_i in range(len(self._corpus._category_idx_store)): | ||
label_indices = np.arange(0, len(labels))[labels == label_i] | ||
if self.max_per_category < len(label_indices): | ||
label_indices = np.random.choice(label_indices, self.max_per_category, replace=False) | ||
to_ret['labels'] += list([int(e) for e in labels[label_indices]]) | ||
to_ret['texts'] += list(texts[label_indices]) | ||
if metadata is not None: | ||
to_ret['meta'] += [metadata[i] for i in label_indices] | ||
if self._use_non_text_features: | ||
to_ret['extra'] += [extrafeats[i] for i in label_indices] | ||
|
||
return to_ret | ||
|
||
def get_labels_and_texts_and_meta(self, metadata): | ||
return self.get_labels_and_texts(metadata) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,48 @@ | ||
from typing import Optional, Callable | ||
from abc import ABCMeta | ||
from typing import Optional, Callable, Dict | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from scattertext import RankDifference, TermDocMatrix, CorpusBasedTermScorer | ||
from scattertext.TermDocMatrix import TermDocMatrix | ||
from scattertext.termscoring.RankDifference import RankDifference | ||
from scattertext.termscoring.CorpusBasedTermScorer import CorpusBasedTermScorer | ||
from scattertext.categorygrouping.characteristic_embedder_base import CategoryEmbedderABC | ||
|
||
from scattertext.util import inherits_from | ||
|
||
class RankEmbedder(CategoryEmbedderABC): | ||
def __init__(self, | ||
scorer_function: Optional[Callable[[np.array, np.array], np.array]] = None, | ||
term_scorer: Optional[CorpusBasedTermScorer] = None, | ||
rank_threshold: int = 10, | ||
term_scorer_kwargs: Optional[Dict] = None, | ||
*args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.scorer_function = RankDifference().get_scores if scorer_function is None else scorer_function | ||
self.term_scorer = term_scorer | ||
self.rank_threshold = rank_threshold | ||
self.term_scorer_kwargs = {} if term_scorer_kwargs is None else term_scorer_kwargs | ||
|
||
def embed_categories(self, corpus: TermDocMatrix, non_text: bool = False) -> np.array: | ||
tdf = corpus.get_freq_df(use_metadata=non_text, label_append='') | ||
term_freqs = tdf.sum(axis=1) | ||
score_df = pd.DataFrame({ | ||
category: pd.Series( | ||
self.__get_scores_for_category(str(category), tdf, term_freqs, non_text), | ||
self.__get_scores_for_category(str(category), tdf, term_freqs, non_text, corpus), | ||
index=corpus.get_terms(use_metadata=non_text) | ||
).sort_values(ascending=False).head(self.rank_threshold) | ||
for category in corpus.get_categories() | ||
}) | ||
return score_df.fillna(0).T.values | ||
|
||
def __get_scores_for_category(self, category, tdf, term_freqs, non_text): | ||
def __get_scores_for_category(self, category, tdf, term_freqs, non_text, corpus): | ||
if self.term_scorer is not None: | ||
scorer = self.term_scorer.set_categories(category_name=category) | ||
if inherits_from(self.term_scorer, 'CorpusBasedTermScorer') and type(self.term_scorer) == ABCMeta: | ||
scorer = self.term_scorer(corpus, **self.term_scorer_kwargs) | ||
else: | ||
scorer = self.term_scorer | ||
if non_text: | ||
scorer = scorer.use_metadata() | ||
scorer = scorer.set_categories(category_name=category) | ||
return scorer.get_scores() | ||
return self.scorer_function(tdf[str(category)], term_freqs - tdf[str(category)]) |
5 changes: 3 additions & 2 deletions
5
scattertext/categorytable/multi_category_association_scorer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.