Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/test_topic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ def test_n_topics():
assert TopicModel(model, n_topics=20).n_topics == 20


def test_duck_typing():
class TrainedDummyModel():
def __init__(self):
self.n_topics = 5
self.components_ = np.array([[0,0,0,1], [1,0,0,0]])

def transform(self, text):
return text
dummy = TrainedDummyModel()
tmodel = TopicModel(dummy)

assert tmodel.n_topics == dummy.n_topics
assert tmodel.model.transform == dummy.transform
np.testing.assert_array_equal(tmodel.model.components_, dummy.components_)


def test_init_model():
expecteds = (NMF, LatentDirichletAllocation, TruncatedSVD)
models = ["nmf", "lda", "lsa"]
Expand Down
11 changes: 10 additions & 1 deletion textacy/tm/topic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
class TopicModel(object):
"""
Train and apply a topic model to vectorized texts using scikit-learn's
implementations of LSA, LDA, and NMF models. Inspect and visualize results.
implementations of LSA, LDA, and NMF models. Also any other topic model implementations that have
`component_`, `n_topics` and `transform` attributes. Inspect and visualize results.
Save and load trained models to and from disk.

Prepare a vectorized corpus (i.e. document-term matrix) and corresponding
Expand Down Expand Up @@ -103,9 +104,13 @@ class TopicModel(object):
- http://scikit-learn.org/stable/modules/generated/sklearn.decomposition.TruncatedSVD.html
"""

_required_trained_model_attr = {"transform", "components_", "n_topics"}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scanning through the code, it looks like models also need a .fit() method and n_components attribute. Is that doable?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree with the n_components, but would be good to leave .fit method a bit optional, essentially it allows to initialize with a trained model, so we don't need to call the training method. But up to you I guess.


def __init__(self, model, n_topics=10, **kwargs):
if isinstance(model, (NMF, LatentDirichletAllocation, TruncatedSVD)):
self.model = model
elif all(hasattr(model, required_attr) for required_attr in self._required_trained_model_attr):
self.model = model
else:
self.init_model(model, n_topics=n_topics, **kwargs)

Expand Down Expand Up @@ -365,6 +370,7 @@ def termite_plot(
rank_terms_by="topic_weight",
sort_terms_by="seriation",
save=False,
rc_params=None,
):
"""
Make a "termite" plot for assessing topic models using a tabular layout
Expand Down Expand Up @@ -392,6 +398,8 @@ def termite_plot(
the default ("seriation") groups similar terms together, which
facilitates cross-topic assessment
save (str): give the full /path/to/fname on disk to save figure
rc_params (dict, optional): allow passing parameters to rc_context in matplotlib.plyplot,
details in https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.rc_context.html

Returns:
``matplotlib.axes.Axes.axis``: Axis on which termite plot is plotted.
Expand Down Expand Up @@ -522,4 +530,5 @@ def termite_plot(
term_labels,
highlight_cols=highlight_cols,
save=save,
rc_params=rc_params,
)
7 changes: 7 additions & 0 deletions textacy/viz/termite.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def draw_termite_plot(
highlight_cols=None,
highlight_colors=None,
save=False,
rc_params=None,
):
"""
Make a "termite" plot, typically used for assessing topic models with a tabular
Expand All @@ -87,6 +88,8 @@ def draw_termite_plot(
of (light/dark) matplotlib-friendly colors used to highlight a single
column; if not specified (default), a good set of 6 pairs are used
save (str, optional): give the full /path/to/fname on disk to save figure
rc_params (dict, optional): allow passing parameters to rc_context in matplotlib.plyplot,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: plyplot => pyplot

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah.. missed this one... sorry

details in https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.rc_context.html

Returns:
:obj:`matplotlib.axes.Axes.axis`: Axis on which termite plot is plotted.
Expand Down Expand Up @@ -138,6 +141,10 @@ def draw_termite_plot(
raise ValueError(msg)
highlight_colors = {hc: COLOR_PAIRS[i] for i, hc in enumerate(highlight_cols)}

_rc_params = RC_PARAMS.copy()
if rc_params:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block overwrites parts of the global RC_PARAMS with whatever is passed in the function call, but it would be safer to modify a local version of the params. Could you do something like

_rc_params = RC_PARAMS.copy()
if rc_params:
    _rc_params.update(rc_params)

with plt.rc_context(_rc_params):
    ...

instead?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bdewilde yeah sure, it's better indeed :)

_rc_params.update(rc_params)

with plt.rc_context(RC_PARAMS):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost there! You forgot to change this line: RC_PARAMS => _rc_params

fig, ax = plt.subplots(figsize=(pow(n_cols, 0.8), pow(n_rows, 0.66)))

Expand Down