-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Extend BERT-based classification with customized layers #4553
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this all looks great! Would it be possible to add a short test?
@@ -90,20 +99,6 @@ def add_cmdline_args( | |||
""" | |||
super().add_cmdline_args(parser, partial_opt=partial_opt) | |||
parser = parser.add_argument_group("BERT Classifier Arguments") | |||
parser.add_argument( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
was this option just never used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right. I actually also found it in BertWrapper's add_common_args function, also never used, so I'll remove it from there
|
||
if ind < len(dimensions): | ||
raise Exception( | ||
"Output layer's dimension does not match number of classes. Found {dimensions[ind][1]}, expected {output_dimension}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: think you're missing f""
string here
"Output layer's dimension does not match number of classes. Found {dimensions[ind][1]}, expected {output_dimension}" | ||
) | ||
raise Exception( | ||
"Output layer's dimension does not match number of classes. Found {prev_dimension}, expected {output_dimension}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
parlai/agents/bert_ranker/helpers.py
Outdated
aggregation="first", | ||
bert_model: BertModel, | ||
output_dim: int = -1, | ||
classifier_layer: torch.nn.Module = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could you please move this arg to be the last one? so that prior calls to this __init__
don't fail
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed issues and added tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thank you for adding tests!
edit: approving assuming long_gpu_tests pass
I have changed the torch version in CircleCI config to make it work since CR was approved, so re-requesting approval for this change |
* Extend BERT-based classification with customized layers * fix bugs and add tests * increase lr to improve training stability * upgrading torch version * adjusting loss value
Patch description
Added functionality to specify custom decoder layers for BERT-based classification. Code is a modification of existing in external ParlAI functions.
Testing steps
Other information