Skip to content

Adding the missing integrator inputs to the Active Learners #320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions classifiers/active_learner/bayesian_optimization/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from util.configs import build_classifier_learner_config
from util.enums import State
from util.enums import State, RefineryDataType, BricksVariableType, SelectionType

from . import bayesian_optimization


Expand All @@ -20,6 +21,40 @@ def get_config():
"gdpr_compliant",
],
integrator_inputs={
"input": "coming soon"
"name": "MyBayesian",
"refineryDataType": RefineryDataType.TEXT.value,
"variables": {
"EMBEDDING": {
"selectionType": SelectionType.STRING.value,
"defaultValue": "text-classification-distilbert-base-uncased",
"description": "pick this from the options above",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
},
"MIN_CONFIDENCE": {
"selectionType": SelectionType.FLOAT.value,
"defaultValue": 0.8,
"addInfo": [
BricksVariableType.GENERIC_FLOAT.value
]
},
"ITERATIONS": {
"selectionType": SelectionType.INTEGER.value,
"defaultValue": 100,
"description": "this can be modified by the user",
"addInfo": [
BricksVariableType.GENERIC_INT.value
]
},
"LABELS": {
"selectionType": SelectionType.STRING.value,
"description": "optional, you can specify a list to filter the predictions (e.g. [\"label-a\", \"label-b\"])",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
}
}
}
)
)

Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@ from sklearn.tree import DecisionTreeClassifier
from typing import List
# you can find further models here: https://scikit-learn.org/stable/supervised_learning.html#supervised-learning

YOUR_EMBEDDING: str = "text-classification-distilbert-base-uncased"
YOUR_MIN_CONFIDENCE: float = 0.8
YOUR_LABELS: List[str] = None # optional, you can specify a list to filter the predictions
EMBEDDING: str = "text-classification-distilbert-base-uncased"
MIN_CONFIDENCE: float = 0.8
LABELS: List[str] = None # you can specify a list to filter the predictions (e.g. ["label-a", "label-b"])

class MyDT(LearningClassifier):

def __init__(self):
self.model = DecisionTreeClassifier()

@params_fit(
embedding_name = YOUR_EMBEDDING,
embedding_name = EMBEDDING,
train_test_split = 0.5 # we have this fixed at the moment, but you'll soon be able to specify this individually!
)
def fit(self, embeddings, labels):
self.model.fit(embeddings, labels)

@params_inference(
min_confidence = YOUR_MIN_CONFIDENCE,
label_names = YOUR_LABELS
min_confidence = MIN_CONFIDENCE,
label_names = LABELS
)
def predict_proba(self, embeddings):
return self.model.predict_proba(embeddings)
Expand Down
32 changes: 30 additions & 2 deletions classifiers/active_learner/decision_tree/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from util.configs import build_classifier_learner_config
from util.enums import State
from util.enums import State, RefineryDataType, BricksVariableType, SelectionType
from . import decision_tree


Expand All @@ -20,6 +20,34 @@ def get_config():
"gdpr_compliant",
],
integrator_inputs={
"input": "coming soon"
"name": "MyDT",
"refineryDataType": RefineryDataType.TEXT.value,
"variables": {
"EMBEDDING": {
"selectionType": SelectionType.STRING.value,
"defaultValue": "text-classification-distilbert-base-uncased",
"description": "pick this from the options above",
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
},
"MIN_CONFIDENCE": {
"selectionType": SelectionType.FLOAT.value,
"defaultValue": 0.8,
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_FLOAT.value
]
},
"LABELS": {
"selectionType": SelectionType.STRING.value,
"description": "optional, you can specify a list to filter the predictions (e.g. [\"label-a\", \"label-b\"])",
"optional": "true",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
}
}
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,4 @@ class MyGrid(LearningClassifier):

def predict_proba(self, embeddings):
return self.model.predict_proba(embeddings)

```
34 changes: 31 additions & 3 deletions classifiers/active_learner/grid_search/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from util.configs import build_classifier_learner_config
from util.enums import State
from util.enums import State, RefineryDataType, SelectionType, BricksVariableType
from . import grid_search


Expand All @@ -20,6 +20,34 @@ def get_config():
"gdpr_compliant",
],
integrator_inputs={
"input": "coming soon"
"name": "MyGrid",
"refineryDataType": RefineryDataType.TEXT.value,
"variables": {
"EMBEDDING": {
"selectionType": SelectionType.STRING.value,
"defaultValue": "text-classification-distilbert-base-uncased",
"description": "pick this from the options above",
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
},
"MIN_CONFIDENCE": {
"selectionType": SelectionType.FLOAT.value,
"defaultValue": 0.8,
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_FLOAT.value
]
},
"LABELS": {
"selectionType": SelectionType.STRING.value,
"description": "optional, you can specify a list to filter the predictions (e.g. [\"label-a\", \"label-b\"])",
"optional": "true",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
}
}
}
)
)
32 changes: 30 additions & 2 deletions classifiers/active_learner/logistic_regression/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from util.configs import build_classifier_learner_config
from util.enums import State
from util.enums import State, RefineryDataType, SelectionType, BricksVariableType
from . import logistic_regression


Expand All @@ -20,6 +20,34 @@ def get_config():
"gdpr_compliant",
],
integrator_inputs={
"input": "coming soon"
"name": "MyLR",
"refineryDataType": RefineryDataType.TEXT.value,
"variables": {
"EMBEDDING": {
"selectionType": SelectionType.STRING.value,
"defaultValue": "text-classification-distilbert-base-uncased",
"description": "pick this from the options above",
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
},
"MIN_CONFIDENCE": {
"selectionType": SelectionType.FLOAT.value,
"defaultValue": 0.8,
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_FLOAT.value
]
},
"LABELS": {
"selectionType": SelectionType.STRING.value,
"description": "optional, you can specify a list to filter the predictions (e.g. [\"label-a\", \"label-b\"])",
"optional": "true",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
}
}
}
)
34 changes: 31 additions & 3 deletions classifiers/active_learner/random_forest/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from util.configs import build_classifier_learner_config
from util.enums import State
from util.enums import State, RefineryDataType, SelectionType, BricksVariableType
from . import random_forest


Expand All @@ -20,6 +20,34 @@ def get_config():
"gdpr_compliant",
],
integrator_inputs={
"input": "coming soon"
"name": "MyRF",
"refineryDataType": RefineryDataType.TEXT.value,
"variables": {
"EMBEDDING": {
"selectionType": SelectionType.STRING.value,
"defaultValue": "text-classification-distilbert-base-uncased",
"description": "pick this from the options above",
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
},
"MIN_CONFIDENCE": {
"selectionType": SelectionType.FLOAT.value,
"defaultValue": 0.8,
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_FLOAT.value
]
},
"LABELS": {
"selectionType": SelectionType.STRING.value,
"description": "optional, you can specify a list to filter the predictions (e.g. [\"label-a\", \"label-b\"])",
"optional": "true",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
}
}
}
)
)
40 changes: 38 additions & 2 deletions classifiers/active_learner/random_search/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from util.configs import build_classifier_learner_config
from util.enums import State
from util.enums import State, RefineryDataType, SelectionType, BricksVariableType
from . import random_search


Expand All @@ -20,6 +20,42 @@ def get_config():
"gdpr_compliant",
],
integrator_inputs={
"input": "coming soon"
"name": "MyRandom",
"refineryDataType": RefineryDataType.TEXT.value,
"variables": {
"EMBEDDING": {
"selectionType": SelectionType.STRING.value,
"defaultValue": "text-classification-distilbert-base-uncased",
"description": "pick this from the options above",
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
},
"MIN_CONFIDENCE": {
"selectionType": SelectionType.FLOAT.value,
"defaultValue": 0.8,
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_FLOAT.value
]
},
"ITERATIONS": {
"selectionType": SelectionType.INTEGER.value,
"defaultValue": 100,
"description": "this can be modified by the user",
"addInfo": [
BricksVariableType.GENERIC_INT.value
]
},
"LABELS": {
"selectionType": SelectionType.STRING.value,
"description": "optional, you can specify a list to filter the predictions (e.g. [\"label-a\", \"label-b\"])",
"optional": "true",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
}
}
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ EMBEDDING: str = "text-extraction-distilbert-base-uncased" # pick this from the
MIN_CONFIDENCE: float = 0.8
LABELS: List[str] = None # optional, you can specify a list to filter the predictions (e.g. ["label-a", "label-b"])

class MyActiveLearner(LearningExtractor):
class MyCRF(LearningExtractor):

def __init__(self):
self.model = CRFTagger(
Expand All @@ -31,5 +31,4 @@ class MyActiveLearner(LearningExtractor):
)
def predict_proba(self, embeddings):
return self.model.predict_proba(embeddings)

```
32 changes: 30 additions & 2 deletions extractors/active_learner/crf_tagger/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from util.configs import build_extractor_learner_config
from util.enums import State
from util.enums import State, RefineryDataType, BricksVariableType, SelectionType
from . import crf_tagger


Expand All @@ -20,6 +20,34 @@ def get_config():
"gdpr_compliant",
],
integrator_inputs={
"input": "coming soon"
"name": "MyCRF",
"refineryDataType": RefineryDataType.TEXT.value,
"variables": {
"EMBEDDING": {
"selectionType": SelectionType.STRING.value,
"defaultValue": "text-classification-distilbert-base-uncased",
"description": "pick this from the options above",
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
},
"MIN_CONFIDENCE": {
"selectionType": SelectionType.FLOAT.value,
"defaultValue": 0.8,
"optional": "false",
"addInfo": [
BricksVariableType.GENERIC_FLOAT.value
]
},
"LABELS": {
"selectionType": SelectionType.STRING.value,
"description": "optional, you can specify a list to filter the predictions (e.g. [\"label-a\", \"label-b\"])",
"optional": "true",
"addInfo": [
BricksVariableType.GENERIC_STRING.value
]
}
}
}
)