Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
plugin system to include user-defined components (#416)
Browse files Browse the repository at this point in the history
Summary:
With this diff, PyText users can write tasks and components in their own directory and include them using --include <directory> without having to modify PyText sources to register their tasks, etc., which is very difficult when working with pip installed pytext-l=nlp. This works for all pytext CLI commands.

Examples in upcoming tutorials.

Pull Request resolved: #416

Reviewed By: snisarg

Differential Revision: D14534038

fbshipit-source-id: 898d15d284654685d2a1e92510fb5987c0293bc5
  • Loading branch information
Titousensei authored and facebook-github-bot committed Mar 20, 2019
1 parent d72f977 commit 4b7b48b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 6 deletions.
38 changes: 38 additions & 0 deletions pytext/builtin_task.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import glob
import importlib
import inspect
import os

from pytext.config.component import register_tasks
from pytext.task.disjoint_multitask import DisjointMultitask
from pytext.task.new_task import NewTask
from pytext.task.task import Task
from pytext.task.tasks import (
ContextualIntentSlotTask,
DocClassificationTask,
Expand All @@ -15,6 +22,37 @@
SeqNNTask,
WordTaggingTask,
)
from pytext.utils.documentation import eprint


USER_TASKS_DIR = "user_tasks"


def add_include(path):
"""
Import tasks (and associated components) from the folder name.
"""
eprint("Including:", path)
modules = glob.glob(os.path.join(path, "*.py"))
all = [
os.path.basename(f)[:-3].replace("/", ".")
for f in modules
if os.path.isfile(f) and not f.endswith("__init__.py")
]
for mod_name in all:
mod_path = path + "." + mod_name
eprint("... importing module:", mod_path)
my_module = importlib.import_module(mod_path)

for m in inspect.getmembers(my_module, inspect.isclass):
if m[1].__module__ != mod_path:
pass
elif Task in m[1].__bases__ or NewTask in m[1].__bases__:
eprint("... task:", m[1].__name__)
register_tasks(m[1])
else:
eprint("... importing:", m[1])
importlib.import_module(mod_path, m[1])


def register_builtin_tasks():
Expand Down
11 changes: 9 additions & 2 deletions pytext/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import click
import torch
from pytext import create_predictor
from pytext.builtin_task import add_include
from pytext.config import LATEST_VERSION, PyTextConfig
from pytext.config.component import register_tasks
from pytext.config.serialize import (
Expand All @@ -26,6 +27,7 @@
ROOT_CONFIG,
eprint,
find_config_class,
get_subclasses,
pretty_print_config_class,
replace_components,
)
Expand Down Expand Up @@ -135,7 +137,7 @@ def gen_config_impl(task_name, options):
elif len(replace_class_set) > 1:
raise Exception(f"Multiple component named {opt}: {replace_class_set}")
replace_class = next(iter(replace_class_set))
found = replace_components(root, opt, set(replace_class.__bases__))
found = replace_components(root, opt, get_subclasses(replace_class))
if found:
eprint("INFO - Applying option:", "->".join(reversed(found)), "=", opt)
obj = root
Expand All @@ -151,13 +153,14 @@ def gen_config_impl(task_name, options):


@click.group()
@click.option("--include", multiple=True)
@click.option("--config-file", default="")
@click.option("--config-json", default="")
@click.option(
"--config-module", default="", help="python module that contains the config object"
)
@click.pass_context
def main(context, config_file, config_json, config_module):
def main(context, config_file, config_json, config_module, include):
"""Configs can be passed by file or directly from json.
If neither --config-file or --config-json is passed,
attempts to read the file from stdin.
Expand All @@ -166,6 +169,9 @@ def main(context, config_file, config_json, config_module):
pytext train < demos/docnn.json
"""
for path in include or []:
add_include(path)

context.obj = Attrs()

def load_config():
Expand Down Expand Up @@ -282,6 +288,7 @@ def test(context, model_snapshot, test_path, use_cuda, use_tensorboard):
@click.pass_context
def train(context):
"""Train a model and save the best snapshot."""

config = context.obj.load_config()
print("\n===Starting training...")
metric_channels = []
Expand Down
24 changes: 20 additions & 4 deletions pytext/utils/documentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sys import modules, stderr
from typing import Union

from pytext.config.component import get_component_name
from pytext.config.component import Component, get_component_name
from pytext.config.pytext_config import ConfigBase
from pytext.models.module import Module

Expand Down Expand Up @@ -130,6 +130,19 @@ def pretty_print_config_class(obj):
print(f" {k} = null")


def get_subclasses(klass, stop_klass=Component):
ret = set()

def add_subclasses(k):
for b in getattr(k, "__bases__"):
if b != stop_klass:
ret.add(b)
add_subclasses(b)

add_subclasses(klass)
return ret


def find_config_class(class_name):
"""
Return the set of PyText classes matching that name.
Expand All @@ -145,7 +158,10 @@ def find_config_class(class_name):
for _, mod in list(modules.items()):
try:
for name, obj in getmembers(mod, isclass):
if name == class_name:
if name == class_name and any(
base.__module__.startswith("pytext.")
for base in get_subclasses(obj, object)
):
if not module_part or obj.__module__ == module_part:
ret.add(obj)
except ModuleNotFoundError:
Expand Down Expand Up @@ -173,7 +189,7 @@ def replace_components(root, component, base_class):
return found

# Not found in options, try to match base classes
# Except ConfigBase, which gives false matches
bases = list(filter(lambda x: x != ConfigBase, v_comp_obj.__bases__))
bases = get_subclasses(v_comp_obj)
bases.add(v_comp_obj)
if base_class & set(bases):
return [k]

0 comments on commit 4b7b48b

Please sign in to comment.