# Train system and general refactoring #230
## Conversation
Review comment on `src/cnlpt/new_data/preprocess.py` (outdated):
```python
return [
    tokenized_input.word_ids(i) for i in range(len(tokenized_input.input_ids))
]
elif character_level:
```
I wrote all the character-level code here and elsewhere for some experiments using Google's CANINE model on one of Guergana's projects. If we want to keep the code, there's some cleaning up I can do, although I haven't used CANINE in a while personally.
etgld left a comment:
Lots of great clarity and efficiency improvements! I really liked the refactoring of some of the functionality from `train_system` into callbacks; those examples are really helpful.

I'm not sure if/when I'll get the chance to test any of this out, but from what I can tell, all of the functionality I typically use should still work.
Merged 1794a35 into Machine-Learning-for-Medical-Language:main
This is an attempt at refactoring the messier parts of the codebase. Quite a few changes; summary below. All the pre-refactoring code is still available in the `cnlpt.legacy` package.

## Refactored train system
The refactored train system lives in the `cnlpt.train_system` package. With the new setup, you can initialize the train system by creating a `CnlpTrainSystem` instance. To run the new train system, use `cnlpt train [ARGS]`.

### Initialization and training
A `CnlpTrainSystem` is created from model arguments, data arguments, and training arguments. Classmethods are also available to initialize a `CnlpTrainSystem` from `argv`, a config dictionary, or a JSON file.

The `__init__` method of `CnlpTrainSystem` configures logging (more info below) and validates the provided args, then sets up the tokenizer, dataset, and model for training. Training won't actually start until the `train()` method is called.
### Metrics and model saving
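The classmethod-constructor pattern described above can be sketched with a stand-in class. This is illustrative only: `TrainSystemSketch` and its config keys (`model`, `data`, `training`) are hypothetical, not the real `CnlpTrainSystem` API.

```python
import json
from dataclasses import dataclass, field

# Hypothetical stand-in mirroring the pattern described above: a train
# system built from three argument groups, with alternate classmethod
# constructors for a config dict or a JSON file.
@dataclass
class TrainSystemSketch:
    model_args: dict = field(default_factory=dict)
    data_args: dict = field(default_factory=dict)
    training_args: dict = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: dict) -> "TrainSystemSketch":
        # Assumed config layout: one sub-dict per argument group.
        return cls(
            model_args=config.get("model", {}),
            data_args=config.get("data", {}),
            training_args=config.get("training", {}),
        )

    @classmethod
    def from_json(cls, path: str) -> "TrainSystemSketch":
        with open(path) as f:
            return cls.from_dict(json.load(f))

system = TrainSystemSketch.from_dict(
    {"model": {"encoder": "roberta-base"}, "training": {"epochs": 3}}
)
print(system.training_args["epochs"])  # -> 3
```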
The `model_selection_score` and `model_selection_label` training arguments have been removed in favor of Trainer's built-in system for saving the best model. Use the training argument `--metric_for_best_model` to choose your selection metric. It defaults to average accuracy across all tasks, but other options are available:

- `loss`
- `avg_acc` (default)
- `avg_macro_f1`
- `avg_micro_f1`
- `TASKNAME.acc`
- `TASKNAME.macro_f1`
- `TASKNAME.micro_f1`
- `TASKNAME.LABELNAME.f1`
- `METRIC_1, METRIC_2, ... METRIC_N`
### Reworked predictions and analysis system
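The metric naming scheme above can be illustrated with a small parser. This is not cnlpt's actual parsing code, and the task name `negation` is just an example; it only shows how the dotted names decompose into task, label, and metric components.

```python
def parse_metric_spec(spec: str):
    """Split a comma-separated metric spec into (task, label, metric) tuples.

    Illustrative only -- not cnlpt's real parser. Aggregate metrics have
    no task or label component, so those slots are None.
    """
    aggregates = {"loss", "avg_acc", "avg_macro_f1", "avg_micro_f1"}
    parsed = []
    for name in (part.strip() for part in spec.split(",")):
        if name in aggregates:
            parsed.append((None, None, name))              # aggregate metric
            continue
        pieces = name.split(".")
        if len(pieces) == 2:                               # TASKNAME.metric
            parsed.append((pieces[0], None, pieces[1]))
        elif len(pieces) == 3:                             # TASKNAME.LABELNAME.f1
            parsed.append((pieces[0], pieces[1], pieces[2]))
        else:
            raise ValueError(f"unrecognized metric spec: {name!r}")
    return parsed

print(parse_metric_spec("avg_acc"))            # [(None, None, 'avg_acc')]
print(parse_metric_spec("negation.macro_f1"))  # [('negation', None, 'macro_f1')]
```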
This PR introduces a `CnlpPredictions` dataclass (in the `data` package) that stores information related to predictions made by the model on test data (with or without labels). These predictions can be generated with the `predict()` method of `CnlpTrainSystem`, and the dataclass has methods for JSON serialization. Using the `--do_predict` flag when training will automatically run predictions on the test set when training is complete and save them to a `predictions.json` file in your output directory.

There is also a new `cnlpt.data.analysis` module with a function that can convert a `CnlpPredictions` instance to a polars dataframe for analysis.

### Logging and live display
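The predict-serialize-reload round trip described above can be sketched with a minimal stand-in. The field layout here is hypothetical (the real `CnlpPredictions` stores richer per-task information); this only shows the shape of the JSON serialization workflow.

```python
import json
from dataclasses import dataclass, field, asdict

# Hypothetical stand-in for the CnlpPredictions dataclass described above.
@dataclass
class PredictionsSketch:
    task_names: list
    predictions: dict                            # task name -> predicted labels
    labels: dict = field(default_factory=dict)   # empty for unlabeled test data

    def to_json(self) -> str:
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, blob: str) -> "PredictionsSketch":
        return cls(**json.loads(blob))

# Example task name ("negation") and labels are made up for illustration.
preds = PredictionsSketch(
    task_names=["negation"],
    predictions={"negation": ["-1", "1", "-1"]},
)
restored = PredictionsSketch.from_json(preds.to_json())
print(restored.predictions["negation"])  # -> ['-1', '1', '-1']
```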
Rather than relying on stdout/stderr to document the training process, all relevant information is now logged to `train_system.log` in the configured output directory.

By moving everything to the logfile, we can reclaim console real estate for a much more interpretable live training progress display using `rich`.
## Refactored data processing
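The file-only logging setup described above follows a standard stdlib pattern: attach a `FileHandler` and disable propagation so nothing reaches the console, leaving it free for the live display. A minimal sketch (logger name, format, and message are made up for illustration):

```python
import logging
import pathlib
import tempfile

# Stand-in for the configured output directory.
output_dir = pathlib.Path(tempfile.mkdtemp())

# Route everything to train_system.log and keep it off the console,
# so the console can be used for a rich live display instead.
logger = logging.getLogger("train_system_sketch")
logger.setLevel(logging.INFO)
handler = logging.FileHandler(output_dir / "train_system.log")
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
logger.propagate = False  # don't echo to the root logger's console handler

logger.info("epoch 1 complete, avg_acc=0.87")
handler.flush()
print((output_dir / "train_system.log").read_text().strip())
```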
Most of the data processing code has also been refactored (i.e., `cnlp_processors` and `cnlp_data`). The new code, which is used by the new train system, lives in the `cnlpt.data` package.

The main goal of the data refactoring was to simplify a lot of code by packaging all the info related to each task into a new dataclass, `TaskInfo`. Basically all of our data processing previously required passing around a bunch of dicts mapping task names to different properties (task type, number of labels, label set, task index). Repackaging all that data on a per-task basis simplifies quite a lot.
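The per-task bundling described above can be sketched as a dataclass. The field names here are assumptions based on the properties listed (task type, number of labels, label set, task index), not the real `TaskInfo` definition.

```python
from dataclasses import dataclass

# Hypothetical field set -- bundles the per-task properties that were
# previously passed around in parallel name-keyed dicts.
@dataclass(frozen=True)
class TaskInfoSketch:
    name: str
    task_type: str   # e.g. "classification" or "tagging"
    index: int       # the task's position in the multi-task setup
    labels: tuple    # the task's label set

    @property
    def num_labels(self) -> int:
        return len(self.labels)

# One object per task replaces four separate dicts mapping task name ->
# task type, task name -> label set, and so on. (Task name is made up.)
task = TaskInfoSketch(
    name="negation", task_type="classification", index=0, labels=("-1", "1")
)
print(task.num_labels)  # -> 2
```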
## Other stuff

- Moved to a `src` layout
- Switched to `uv` (the issues between `uv` and `setuptools-scm` have been resolved in newer `uv` versions)

## TODO
When reworking the train system, I noticed that the old code only successfully sets the class weights for the CNN model; for the Hierarchical and CNLP models, the class weights are taken from `dataset.class_weights`, which (as far as I can tell) is always `None`. This is most likely a bug in the original train system, but since I didn't write that code, I'll wait for review before fixing it in case I'm missing something.

One chunk of code that's still missing in this refactor is the error and disagreement analysis in `cnlp_predict.py`. I don't have a good sense of how much of that code is still needed now that we can export a dataframe with much of the same information from the new `CnlpPredictions` dataclass via the `make_preds_df` function in `cnlpt.data.analysis`.

Although training seems to run the same with the arguments I've tried, it's possible I accidentally broke something for someone else's use case. I'm opening this PR early as a draft so that people can test it on their own tasks/data to make sure everything is still working. As a reminder, run the new train system with `cnlpt train [ARGS]`.