From 3396c4d71012b9042b5185d38f245e7c4d19dceb Mon Sep 17 00:00:00 2001 From: Mart Ratas Date: Thu, 6 Jul 2023 15:02:11 +0300 Subject: [PATCH 1/9] v1.7.1 Release PR (#331) * CU-8677ge6j8 Version identification and updating (#313) * Expose example model card version in metadata test * Add version detection along with tests * Move to a more comprehensive version string parser (regex) * Add more comprehensive versioning tests * Move MedCAT unzip to a separate method * Separate getting semantic version from string * Add new CDB with version information and use that with versioning tests * Add methods to get version info from CDB dump and model pack zip/folder * Exposing CDB file name and adding custom dev patch version support * Fix config.linking.filters.cuis - from empty dict to empty set * Add logging to versioning * Fix f-strings instead of (intended) r-strings * Add creating model pack archive to versioning CDB fix * Fix logger initialising * Making versioning a runnable module that allows fixing the config * Add docstrings to CLI methods * CU-8677ge6j8 Make explicit check regards to empty dict when fixing config * CU-8677ge6j8 Add tests regarding versioning changes * CU-8677ge6j8 Add missing return type hint * CU-8677ge6j8 Simplify action handling for CLI input * CU-8677ge6j8 Simplifying archive making method * Pin down transformers for the de-identification model (#314) * NO-TICKET pin down transformers for the de-id model * Added function to remove CUI from cdb (#316) * Added function to remove CUI from cdb * Unit test for remove_cui * CU-862jjprjw Fix github actions failures (#317) * Added function to remove CUI from cdb --------- Co-authored-by: antsh3k * CU-862jr8wkk Pin pydantic dependency to avoid conflicts with v2.0 (#318) * Bump django from 3.2.18 to 3.2.19 in /webapp/webapp Bumps [django](https://github.com/django/django) from 3.2.18 to 3.2.19. 
- [Commits](https://github.com/django/django/compare/3.2.18...3.2.19) --- updated-dependencies: - dependency-name: django dependency-type: direct:production ... Signed-off-by: dependabot[bot] * CU-863gntc58 Umlspt2ch (#322) * CU-863gntc58 Add parent to child relationship getter to UMLS preprocessing * CU-863gntc58 Only use ISA relationships * Make sure parents do not have themselves as children * CU-863gntc58 Only keep preferred names * CU-863gntc58 Fix typing issues * CU-863gntc58 Fix child-parent relationships being saved instea * Better system for avoiding parent-child being the same * Fix for Issue 325 (#326) * Issue-325 Add check for old/new spacy; fix code for nested entities * Issue-325 Fix a typing issue * Issue-325 Improve nested entity extraction in _doc_to_out; add type hint for individual entities * Issue-325 Remove unneccessary whitespace * Issue-325 Move spacy version detection from cat to utils.helpers * CU-86783u6d9 Add wrapper to simplify De-ID model usage (#324) * CU-2wgnqg5 Add javadoc to a method * CU-2wgnqg5 Fix issues with typing * CU-2wgnqg5 Add (potential) progress bar to regression testing * CU-2wgnqg5 Add runnable regression checker with command line arguments * CU-2wgnqg5 Add better help message for a CLI argument * CU-2wgnqg5 Fix import to use proper namespace * CU-2wgnqg5 Add parent-child functionality for filters * CU-2wgnqg5 Add cui and children option to the config example * Revert "CU-2wgnqg5 Fix import to use proper namespace" This reverts commit 882be443fd45a33ea708000014f74b3a6554c3ce. 
* CU-2wgnqg5 Add default / empty children to translation layer * CU-2wgnqg5 Remove use of deprecated warning method * CU-2wgnqg5 Add new default test case that checks for 'heart rate' and its children 4 deep * CU-2wgnqg5 Remove unneccessary TODO comment * CU-2wgnqg5 Add possibility of using result reporting for regression checks * CU-2wgnqg5 Fix issue with delegations not shown for reports * CU-2wgnqg5 Add possibility of using reports for CLI regression testing * CU-2wgnqg5 Fix minor typing issues * CU-2wgnqg5 Fix typo in default regression config * CU-2wgnqg5 Make sure imports work both when running directly as well as when using as part of the project * CU-2wgnqg5 Add a new test case with the ANY strategy * CU-2wgnqg5 Fixing imports so that absolute imports are used * CU-2wgnqg5 Add new package to setup.py * CU-2wgnqg5 Fix typing issues * CU-2wgnqg5 Fix report output formating * CU-2vzhd93 Remove logging tutorials (move to MedCATtutorials) * CU-2wgnqg5 Move to a simpler filter design * CU-2wgnqg5 Add (optional) per-phrase results to results/reporting * CU-2wgnqg5 Add per-phrase information toggle to CLI * CU-2wgnqg5 Fix method signature changes between inherited classes * CU-2q50k3c: add contact email address. 
* added latest release news / accepted paper * Update README.md * CU-2zj4czk Move to a class based linking filter approach * CU-2zj4czk Move to identifier based linking filter access * CU-2zj4czk Use MCT filters when training supervised * New UMLS Full Model * CU-2zj4czk Make sure excluded CUIs are always specified (even if by an empty set) * CU-2zj4czk Add possibility of creating a copy of linking filters * CU-2zj4czk Use copies of linking.filters in train_supervised and _print_stats * CU-2zj4czk Add linking.filters merging functionality * CU-2zj4czk Add parameter to retain MCT filters within train_supervised * CU-2zj4czk Rename filters variable within print_stats method for better consistency and readability * CU-2zj4czk Consolidate some duplicate code between train_supervised and _print_stats * CU-2zj4czk Fix multi-project detection * CU-2zj4czk Fix linking filter merging * CU-2zj4czk Add tests for retaining filters from MCT along with a test-trainer export * CU-2zj4czk Remove debug print outputs from some tests * CU-2wgnqg5 Separate some of the regression code into different modules * Add URL of paper for Dutch model (#275) * CU-2wgnqg5 Add serialisation code along with tests * CU-2wgnqg5 Fix regression checker and case serialisation and add tests * CU-2wgnqg5 Add conversion code from MCT export to regression YAML along with tests * CU-2wgnqg5 Fix minor import and typing issues * CU-2wgnqg5 Add runnable to convert from MedCATtrainer to regression YAML * CU-2wgnqg5 Add for number of cases read from MCT export * CU-2wgnqg5 Add context selectors for conversion from MCT * CU-2wgnqg5 Add use of context selector to converter * CU-2wgnqg5 Add use of context selector to runnable * CU-2wgnqg5 Fix issue with typing * CU-2wgnqg5 Add regression case based progress bar in case the total of sub-cases is unknown * CU-2wgnqg5 Make sure (and test) that only 1 replacement '%s' is in each phrase for regression tests * CU-2wgnqg5 Add test cases for '%' replacement in context and 
some minor optimisation * CU-2wgnqg5 Add option to not show empty cases in report * CU-2wgnqg5 Fix verbose output mode/logging * CU-2wgnqg5 Fix name clashes in test cases * CU-2wgnqg5 Make conversion filter for both CUI and NAME * CU-2wgnqg5 Use different approach for generating targets for regression cases * CU-2wgnqg5 Add warning when no parent-child information is present (but continue to run) * Fix issue with typing * Add TODO comment regarding more comprehensive reporting * Fix whitespace issue * CU-2wgnqg5 Translation layer now able to confirm if a set of CUIs has a parent or child of a specified one * CU-2wgnqg5 Add reasons for failure of a regression case * CU-2wgnqg5 Make hiding failures a possibility from the CLI * CU-2wgnqg5 Use better report output for failures with summary * CU-2wgnqg5 Fix typing issues * CU-2wgnqg5 Add description to failed cases where applicable * CU-2wgnqg5 Fix successes not being reported on * CU-2wgnqg5 Rename some fail reasons for better readability * CU-2wgnqg5 Add test cases for specifeid CUI and name if/when none are found from the CDB * CU-2wgnqg5 Add extra information (names) in case of failure becasue name not in CDB * CU-2wgnqg5 Make converter consolidate different test cases with identical filters (CUI and name) into one with multiple phrases * CU-2wgnqg5 Remove use of TargetInfo and using a tuple instead * CU-2wgnqg5 Fix remnant targetinfo * CU-2wgnqg5 Fix remnant targetinfo stuff * CU-2wgnqg5 Fix remnant targetinfo in docstrings * CU-2wgnqg5 Fix missing argumnet in docstrings * CU-2wgnqg5 Allow only reports in regression checker * CU-2wgnqg5 Add medcat.utils.regression level parent logger * CU-2wgnqg5 Use medcat.utils.regression parent logger for verbose output in regression checker * CU-2wgnqg5 Move from logger.warn to logger.warning * CU-2wgnqg5 Fix issue with wrong targets being generated * CU-2wgnqg5 Fix checking tests * CU-2wgnqg5 Add dunder init to test (utils) packages to make the tests within discoverable * 
CU-2wgnqg5 Fix serialisation tests (add missing argument) * CU-2wgnqg5 Fix regression results tests (change method owner) * CU-2wgnqg5 Fix regression results tests (make names ordered) * CU-2wgnqg5 Remove unnecessary print output in test * CU-2wgnqg5 Update conversion code to not use target info * CU-2wgnqg5 Attempt to fix automated build on github actions (bin sklearn version) * CU-2wgnqg5 Move from sklearn to scikit-learn dependency * CU-2wgnqg5 Separate some code in converting, add docs * CU-2wgnqg5 Make yaml dumping save for yaml representation of regression checker * CU-2wgnqg5 Add initial editing code with some simple tests * CU-2wgnqg5 Add possibility for combinations to ignore identicals * CU-2wgnqg5 Add docs to the editing/combining methods * CU-2wgnqg5 Add runnable python file for combining different regression YAMLs * CU-2wgnqg5 Minor codebase improvements * CU-2wgnqg5 Make FailReasons serializable * CU-2wgnqg5 Add json output to regression checking * Make stats reporting not have np.nan values on empty train count (#277) * CU-327vb66 make stats reporting not have np.nan values on empty train count * CU-327vb66 start using scikit-learn instead of deprecated sklearn * Bump django from 3.2.15 to 3.2.16 in /webapp/webapp Bumps [django](https://github.com/django/django) from 3.2.15 to 3.2.16. - [Release notes](https://github.com/django/django/releases) - [Commits](https://github.com/django/django/compare/3.2.15...3.2.16) --- updated-dependencies: - dependency-name: django dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Update ReadMe.md to show Licence change Updated News Section * CU-2wgnqg5 Add docstring to fail descriptor getter method * CU-2wgnqg5 Removed handled TODO * CU-33g09h4 Make strides towards PEP 257. 
Make all docstrings use triple double quotes; remove preceding whitespace from docstrings; remove raw-string docstrings where applicable; remove empty docstrings * CU-2zj4czk Add documentation regarding config.linking.filters * CU-2zj4czk Add test for leakage of extra_cui_filters * CU-33g09h4 Remove leftover whitespace from start of docstring * include joblib dep * CU-2zj4czk Add parameter to retain extra_cui_filters (instead of MCT filters). Make sure tests pass. * CU-33g09h4 Some docstring unification for config(s) * CU-33g09h4 Some docstring unification for pipe, meta_cat and vocab * CU-33g09h4 Some docstring unification for cdb * CU-33g09h4 Some docstring unification for cdb maker * CU-33g09h4 Some docstring unification for cdb and maker (Return: to Returns:) * CU-33g09h4 Some docstring unification for cat * CU-33g09h4 Fix typo in docstring * CU-33g09h4 Some docstring unification for utils * CU-33g09h4 Some docstring unification for tokenizers * CU-33g09h4 Some docstring unification for preprocessors * CU-33g09h4 Some docstring unification for NER parts * CU-33g09h4 Some docstring unification for NEO parts * CU-33g09h4 Some docstring unification for linking parts * CU-33g09h4 Some docstring unification for cogstack connection part * CU-33g09h4 Remove some leftover backticks from docstring types * CU-33g09h4 Remove some leftover 'Return:' -> 'Returns:' changes * CU-33g09h4 Fix typo in a return type name * CU-384mewq match post release branches in the production workflow (#283) * CU-346mpxm Add new JSON based (faster) serialization for CDB along with tests * CU-346mpxm Add new package to setup.py; add logger and docstrings to serializer; remove dead code and comments * CU-346mpxm Remove leftover codel; Fix type safety regarding optinal json path * CU-346mpxm Add logging on writing to serializer * CU-346mpxm Add logging on reading to serializer * CU-346mpxm Make deserializing consistent with previous CDB deserialising * CU-346mpxm Add JSON serialisation to CDB * 
CU-346mpxm Remove issue with circular imports * CU-346mpxm Make sure json files end with .json * CU-346mpxm Add json type format to modelpack creation * CU-346mpxm Add tests for json format modelpack creation * CU-346mpxm Add logging output to model pack creation and loading * CU-346mpxm Add model pack converter / runnable * Update README.md * CU-862hyd5wx Unify rosalind/vocab downloading in tests, identify and fail meaningfully in case of 503 * CU-862hyd5wx Remove unused imports in tests due to last commit * CU-862hyd5wx Add possibility of generating and using a simply vocab when Rosalind is down * CU-862hyd5wx Fix small typo in tests * Loosen dependency restrictions (#289) Signed-off-by: zethson Signed-off-by: zethson * bug found in snomed2OPCS func * markdown improvements * Mapping icd10 and opcs complete * get all children func added * pep8 fixes * Update README.md * Add confusion matrix to meta model evaluation * CU-862j0jcdu / CU-862j0jd2n Cdb json (#295) * CU-862j0jcdu Rename format parameter in model creation to specify it only applys to the CDB * CU-862j0jd2n Add addl_info to be JSON serialised when required * CU-862j0jd2n Add addl_info to docstring of CDB serializer * CU-38g55wn / CU-39cmv82 Support for python3.11 (and 3.10) (#285) * CU-38g55wn Move dependencies to (hopefully) support python 3.11 on Ubuntu * CU-38g55wn Attempt to fix dependencies for github dependency (gensim) * CU-38g55wn Attempt to fix dependencies for github dependency (gensim) x2 * CU-38g55wn Attempt to fix dependencies for github dependency (gensim) x3 * CU-38g55wn Attempt to fix dependencies for github dependency (gensim) x4 * CU-38g55wn Attempt to fix dependencies for github dependency (gensim) x5 - fix missing comma * CU-38g55wn Remove errorenous package from setup.py * CU-38g55wn Bump spacy version so as to (hopefully) fix pydantic issues * CU-38g55wn Bump spacy en_core_web_md version so as to (hopefully) fix requirements issues * CU-38g55wn Fix test typo that was fixed on newere 
en_core_web_md * CU-38g55wn Fix small issue in NER test * CU-38g55wn Fix small issue with NER test (int conversion) * CU-38g55wn Mark some places as ignore where newer mypy complains * CU-38g55wn Bump mypy dev requirement version * CU-38g55wn Add python 3.11 and 3.10 to workflow * CU-38g55wn Trying to install gensim over https rather tha ssh * CU-38g55wn Make python versions strings in GH worfklow so 3.10 doesn't get 'rounded' to 3.10 when read * CU-38g55wn Remove python 3.7 from workflow since it's not compatible with required versions of numpy and scipy * CU-38g55wn Universally fixing NER test regarding the 'movar~viruse' -> 'movar~virus' thing * CU-38g55wn Bump gensim version to 4.3.0 - the first to support 3.11 * CU-862hyd5wx Unify rosalind/vocab downloading in tests, identify and fail meaningfully in case of 503 * CU-862hyd5wx Remove unused imports in tests due to last commit * CU-862hyd5wx Add possibility of generating and using a simply vocab when Rosalind is down * CU-862hyd5wx Remove python 3.7 and add 3.10/3.11 to classifiers * CU-862hyd5wx Reorder python versions in GitHub workflow * CU-862hyd5wx Attempt to fix GHA by importing unittest.mock explicitly * CU-39cmvru Faster hashing (#286) * CU-39cmvru Add marking of CDB dirty if/when concepts change. Avoid calculating its hash separately if it hasn't been dirtied. Add tests to verify behaviour. * CU-39cmvru Add possibility to force recalculation of hash for CDB (inlcuding when getting hash for CAT) * CU-39cmvru Add possibility to force recalculation of hash for CDB through modelcat creation (new parameter, propageting through _versioning) * CU-39cmvru Remove previous hash from influencing hashing of CDB to produce consistent hash on every recalculation Add tests to make sure that is the case on the CDB level as well as the CAT/modelpack level. 
* CU-39cmvru Add logging around the (re)calclulation of the CDB hash * CU-39cmvru Fix typo in log message * CU-39cmvru Add test to make sure the CDB hash is saved to disk and loaded from disk * CU-39cmvru Add possibility to calculate hash upon saving of CDB if/when the hash is unknown (i.e when saving outside a model pack) * CU-39cmvru Add CDB dirty flag to all other methods that modify the CDB * Change confusion matrix to DF and add labels * Fix model config * CU-86777ey74 No elastic dependency (#298) * Removed elastic dependency * CU-86777ey74 Remove module that depends on elastic (cogstack/cogstack_conn) * CU-86777ey74 Remove medcat.cogstack package from setup.py packages * Docstring updated to google-style docstring * CU-2e77a2k Remove unused utility modules * CU-2e77a2k Remove deprecated utils * Bump django from 3.2.16 to 3.2.17 in /webapp/webapp Bumps [django](https://github.com/django/django) from 3.2.16 to 3.2.17. - [Release notes](https://github.com/django/django/releases) - [Commits](https://github.com/django/django/compare/3.2.16...3.2.17) --- updated-dependencies: - dependency-name: django dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] * CU-33g0f3w Read the docs build failures (#306) * CU-33g0f3w Pin aiohttp dependency version for docs * CU-33g0f3w Pin aiohttp dependency version for docs (#303) * CU-33g0f3w Pin aiohttp dependency version for docs in setup.py * Read the docs build failures (#304) * CU-33g0f3w Pin aiohttp dependency version for docs * CU-33g0f3w Pin aiohttp dependency version for docs in setup.py * CU-33g0f3w Pin blis dependency version for docs in setup.py * Add options for loading meta models and additional NERs (#300) * CU-8677aud63 add options for loading meta models and addl NERs * CU-8677aud63 reduce memory usage during test * Style fix * NO-TICKET reduce the false positives on pushing to test pypi (#307) * CU-862j5by9q Regression touchup - metadata and ability to split suites into categories (#301) * CU-862j5by9q Add metadata to regression suite, loaded from model card if/when specified. A model can be specified upon creation to get the model card from. * CU-862j5by9q Remove f-string from string with no placeholders * CU-862j5by9q Make regression case hashable * CU-862j5by9q Add category separation to regression test suite along with automated tests and test example * CU-862j5by9q Add missing docstringgs to category separation * CU-862j5by9q Add saving to category separator and a convenience method for separation based on regression test YAML file and categories YAML file * CU-862j5by9q Add missing docstrings to new methods * CU-862j5by9q Fix typo in class name * CU-862j5by9q Fix saving issue for separation results * CU-862j5by9q Add runnable category separator * CU-862j5by9q Separate some file location constants in separation tests * CU-862j5by9q Add test for separation that checks that no information gets lost (in the specific situation) * CU-862j5by9q Add an anything-goes category description * CU-862j5by9q Fix anything-goes option * CU-862j5by9q Add tests for anything-goes category description * CU-862j5by9q Add possibility of using an 
overflow category when separating regression suite * CU-862j5by9q Add use of the overflow category to the runnable * CU-862j5by9q Fix linting and typing issues * CU-862j5by9q Add test for each individual separated suite * CU-862j5by9q Fix minor abstract class issues * CU-862j5by9q Rename categoryseparation module as category_separation * CU-862j5by9q Add docstrings to category_separator * CU-8677craqe make transformer_ner continue processing other entities after the first non-matching * Bump django from 3.2.17 to 3.2.18 in /webapp/webapp Bumps [django](https://github.com/django/django) from 3.2.17 to 3.2.18. - [Release notes](https://github.com/django/django/releases) - [Commits](https://github.com/django/django/compare/3.2.17...3.2.18) --- updated-dependencies: - dependency-name: django dependency-type: direct:production ... Signed-off-by: dependabot[bot] * CU-862j7b9jc Mypy full release - 1.0.0 (#308) * CU-862j7b9jc Add abstract base class to regression converting strategy where necessary * CU-862j7b9jc Bump mypy to version 1.0.0 * CU-862j7b9jc Mypy abc hotfix (#311) * CU-862j7b9jc Fix issue with duplicate imports * CU-862j7b9jc Fix issue with no whitespace after keyword (E275) * CU-862j7b9jc Remove unnecessary brackets from if statement * CU-8677ge6j8 Version identification and updating (#313) * Expose example model card version in metadata test * Add version detection along with tests * Move to a more comprehensive version string parser (regex) * Add more comprehensive versioning tests * Move MedCAT unzip to a separate method * Separate getting semantic version from string * Add new CDB with version information and use that with versioning tests * Add methods to get version info from CDB dump and model pack zip/folder * Exposing CDB file name and adding custom dev patch version support * Fix config.linking.filters.cuis - from empty dict to empty set * Add logging to versioning * Fix f-strings instead of (intended) r-strings * Add creating model pack archive to 
versioning CDB fix * Fix logger initialising * Making versioning a runnable module that allows fixing the config * Add docstrings to CLI methods * CU-8677ge6j8 Make explicit check regards to empty dict when fixing config * CU-8677ge6j8 Add tests regarding versioning changes * CU-8677ge6j8 Add missing return type hint * CU-8677ge6j8 Simplify action handling for CLI input * CU-8677ge6j8 Simplifying archive making method * Pin down transformers for the de-identification model (#314) * NO-TICKET pin down transformers for the de-id model * Added function to remove CUI from cdb (#316) * Added function to remove CUI from cdb * Unit test for remove_cui * CU-862jjprjw Fix github actions failures (#317) * Added function to remove CUI from cdb --------- Co-authored-by: antsh3k * CU-862jr8wkk Pin pydantic dependency to avoid conflicts with v2.0 (#318) * Bump django from 3.2.18 to 3.2.19 in /webapp/webapp Bumps [django](https://github.com/django/django) from 3.2.18 to 3.2.19. - [Commits](https://github.com/django/django/compare/3.2.18...3.2.19) --- updated-dependencies: - dependency-name: django dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] * CU-863gntc58 Umlspt2ch (#322) * CU-863gntc58 Add parent to child relationship getter to UMLS preprocessing * CU-863gntc58 Only use ISA relationships * Make sure parents do not have themselves as children * CU-863gntc58 Only keep preferred names * CU-863gntc58 Fix typing issues * CU-863gntc58 Fix child-parent relationships being saved instea * Better system for avoiding parent-child being the same * CU-86783u6d9 Add wrapper to simplify De-ID model usage * CU-86783u6d9 Add wrapper to simplify De-ID model usage * CU-86783u6d9 Fix typoe (nod vs not) * CU-86783u6d9 Fix typo in docstring * CU-86783u6d9 Change loading method name to match CAT * CU-86783u6d9 Separate NER model from DeID model * Better separation of NER models from DeID models * CU-86783u6d9 Move deid method from helpers module to deid model and deprecated the use of the wrappers in the helpers module * Fix imports in deid model * Fix deid training method return value * CU-86783u6d9 Fix dunder call defaults for redaction * CU-86783u6d9 Add a few simple tests for the DeID model * CU-86783u6d9 Add redaction test for the DeID model * CU-86783u6d9 Add remove senitive data * CU-86783u6d9 Fix deid model validation * CU-86783u6d9 Add ChatGPT generated DeId trian data * CU-86783u6d9 Add Warning regarding deid training data * CU-86783u6d9 Fix model issue with multiple NER models * CU-86783u6d9 Fix merge conflict in docstring * CU-86783u6d9 Try and fix keyword argument duplication * CU-86783u6d9 Ignore mypy where needed * CU-86783u6d9 Fix issue with NER model being returned when loading a DeID model * CU-86783u6d9 Remove unused import * CU-86783u6d9 Update training data with some more examples * CU-86783u6d9 Add type hints and doc string to deid method * CU-86783u6d9 Add comment regarding deid_text method being outside the model class * CU-86783u6d9 Add missing return type * CU-86783u6d9 Expose get_entities in NER model * CU-86783u6d9 Expose dunder call in NER model * CU-86783u6d9 
Remove dunder call in override in deid model * CU-86783u6d9 Fix deid model tests * CU-86783u6d9 Fix a few typos in docstrings * CU-86783u6d9 Fix a method name in docstrings --------- Signed-off-by: dependabot[bot] Signed-off-by: zethson Co-authored-by: tomolopolis Co-authored-by: Zeljko Co-authored-by: Sander Tan Co-authored-by: Xi Bai <82581439+baixiac@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Anthony Shek <55877857+antsh3k@users.noreply.github.com> Co-authored-by: Lukas Heumos Co-authored-by: antsh3k Co-authored-by: James Brandreth Co-authored-by: Xi Bai * CU-862k1tt90 Fix circular imports by moving raw deid method back to helpers module (#328) * CU-862k1tt90 Fix circular imports by moving raw deid method back to helpers module * CU-862k1tt90 Fix missing import regarding deid * CU-862k1tt90 Remove unnecessary newline * Cu 863h30jyb separate train from data load (#329) * CU-863h30jyb Deprecated train_supervised method in favour of train_supervised_from_json method * CU-863h30jyb Shuffle around docstrings for supoervised training methods * CU-863h30jyb Create new train_supervised_raw method for raw data based training * CU-863h30jyb In MetaCat deprecate train method and replace with train_from_json method * CU-863h30jyb In MetaCat add train_raw method and move most of the training logic into that one * CU-863h30jyb Fix type hint * CU-86785yhfk Add method to populate cui2snames with data from cui2names (#327) * CU-86785yhfk Add method to populate cui2snames with data from cui2names * CU-86785yhfk Add test for cui2sname population method * Bump django from 3.2.19 to 3.2.20 in /webapp/webapp Bumps [django](https://github.com/django/django) from 3.2.19 to 3.2.20. - [Commits](https://github.com/django/django/compare/3.2.19...3.2.20) --- updated-dependencies: - dependency-name: django dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] * CU-346mpwz Improving memory usage of MedCAT models (#323) * CU-863gntc58 Add parent to child relationship getter to UMLS preprocessing * CU-863gntc58 Only use ISA relationships * Make sure parents do not have themselves as children * CU-863gntc58 Only keep preferred names * CU-346mpwz Add memory optimiser for CDB * CU-346mpwz Add name2 to memory optimiser for CDB * CU-346mpwz Add keys/items/values views to memory optimiser fake dicts * CU-346mpwz Fix keys/items/values views in memory optimiser fake dicts * CU-346mpwz Add option to optimise or not cui and/or name based dicts in memory optimiser * CU-346mpwz Make default memory optimiser omit name2... optimising; add comment regarding this in docstring * CU-346mpwz Remove unused/legacy code from memory optimiser * CU-346mpwz Add tests for memory optimiser * CU-346mpwz Add tests memory optimised CDB * CU-346mpwz Make dict names available within memory optimiser * CU-346mpwz Add separate tests for memory optimised CDB * CU-346mpwz Remove unused imports in memory optimiser * CU-346mpwz Move some encoding and decoing stuff within serialisation to their own module * CU-346mpwz Add tests for encoding/decoding stuff * CU-346mpwz Add encoding/decoding for delegating dict as well as postprocessing for delegation linking with json serialisation * CU-346mpwz Fix decision upon JSON deserialisation of CDB when loading model pack * CU-346mpwz Adapt serialisation tests to the potential one2many mappings * CU-346mpwz Add tests for memory optimisation, including JSON serialisation ones * CU-346mpwz Remove debug print statements * CU-346mpwz Remove debug methods from tests * CU-346mpwz Fix method signatures in encoding/decoding methods * CU-346mpwz Fix typing issue in serialiser when passing encoder * CU-346mpwz Relax typing restrictions for umls preprocessing / parent2child mapping * CU-346mpwz Remove some debug variables * CU-346mpwz Fix remnant merge conflict * CU-346mpwz Add item removal and 
popping to delegating dict * CU-346mpwz Add item removal and popping tests to delegating dict * CU-346mpwz Add item adding/setting tests to delegating dict * CU-346mpwz Fix typing issue (List vs list) * CU-346mpwz Add possibility of memory-optimising for snames as well * CU-346mpwz Add comment regarding memory-optimising for filtering by CUI to CDB * CU-346mpwz Add sname based memory optimisation tests * CU-346mpwz Add json serialisation capabilities to snames delegation * CU-346mpwz Make sname optimisation default for memory optimisation * CU-346mpwz Fix typo in serialisation tests * CU-346mpwz Add variable to keep track of current memory optimisation info to CDB * CU-346mpwz Add default cui2snames to sname optimisations; make sure sname optimisation dirties the CDB * CU-346mpwz Add method to undo CDB memory optimisation * CU-346mpwz Add tests for undoing CDB memory optimisation * CU-346mpwz Clear memory optimised parts if/when undoing optimisations * CU-346mpwz Remove accidentally added file/module * CU-346mpwz Add more straight forward optimisation part names; Fix memory optimisation part clearing * CU-346mpwz Add further tests for memory optimisation (dirty state, checking optimised parts) --------- Signed-off-by: dependabot[bot] Signed-off-by: zethson Co-authored-by: Xi Bai <82581439+baixiac@users.noreply.github.com> Co-authored-by: Anthony Shek <55877857+antsh3k@users.noreply.github.com> Co-authored-by: antsh3k Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: tomolopolis Co-authored-by: Zeljko Co-authored-by: Sander Tan Co-authored-by: Lukas Heumos Co-authored-by: James Brandreth Co-authored-by: Xi Bai --- .github/workflows/main.yml | 2 +- .github/workflows/production.yml | 2 +- examples/cdb_new.dat | Bin 0 -> 3366 bytes medcat/cat.py | 193 ++++++++++-- medcat/cdb.py | 69 ++++- medcat/meta_cat.py | 50 ++- medcat/utils/helpers.py | 20 ++ medcat/utils/memory_optimiser.py | 366 ++++++++++++++++++++++ 
medcat/utils/ner/deid.py | 114 +++++++ medcat/utils/ner/helpers.py | 36 ++- medcat/utils/ner/model.py | 109 +++++++ medcat/utils/preprocess_umls.py | 126 +++++++- medcat/utils/saving/coding.py | 146 +++++++++ medcat/utils/saving/serializer.py | 51 ++- medcat/utils/versioning.py | 314 +++++++++++++++++++ setup.py | 4 +- tests/resources/deid_train_data.json | 1 + tests/test_cdb.py | 17 + tests/utils/ner/__init__.py | 0 tests/utils/ner/test_deid.py | 118 +++++++ tests/utils/regression/test_metadata.py | 1 + tests/utils/saving/test_coding.py | 77 +++++ tests/utils/saving/test_serialization.py | 28 +- tests/utils/test_memory_optimiser.py | 375 +++++++++++++++++++++++ tests/utils/test_versioning.py | 163 ++++++++++ webapp/webapp/requirements.txt | 2 +- 26 files changed, 2299 insertions(+), 85 deletions(-) create mode 100644 examples/cdb_new.dat create mode 100644 medcat/utils/memory_optimiser.py create mode 100644 medcat/utils/ner/deid.py create mode 100644 medcat/utils/ner/model.py create mode 100644 medcat/utils/saving/coding.py create mode 100644 medcat/utils/versioning.py create mode 100644 tests/resources/deid_train_data.json create mode 100644 tests/utils/ner/__init__.py create mode 100644 tests/utils/ner/test_deid.py create mode 100644 tests/utils/saving/test_coding.py create mode 100644 tests/utils/test_memory_optimiser.py create mode 100644 tests/utils/test_versioning.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3c05835d0..c769dfc2e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -42,7 +42,7 @@ jobs: github.ref == 'refs/heads/master' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags') != true - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 concurrency: publish-to-test-pypi needs: [build] diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml index ad229ce9c..5088c1000 100644 --- a/.github/workflows/production.yml +++ b/.github/workflows/production.yml @@ -8,7 
+8,7 @@ on: jobs: build-n-publish-to-pypi: - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 concurrency: build-n-publish-to-pypi if: github.repository == 'CogStack/MedCAT' diff --git a/examples/cdb_new.dat b/examples/cdb_new.dat new file mode 100644 index 0000000000000000000000000000000000000000..27957d62bcdbb0013c393684f2036a9f4d7c8f4b GIT binary patch literal 3366 zcmbVP-ESR76~EWl_KjUTj$;#AB}i9-NRE&{Qz0d)FpepP+}Mq5C!wv{(cay;cZR#W zv+V5pOHG7?sA4fvEp>Q+C`3)vDCz@!gQouh|3HBU5aN|T0)A(9*ET%xz+TR5Q>iue53CuF zy@<0No@%LLZQsMbp<8d=dvnv8q19V&u54It&Zj(5JDg!7k2Arw)V4ijrXv}%kY^b- zvy}fdmnjcydmuNHUL32;%OH}eJWKM-X)qZmpYgqn#{mzRjJwJnnlWgcfe1<6%-;ML z9y4CiQrQy~aWYi9JoP=zS-ET<=BC->F;Bg)usBpb$l2kcZBLs~nsrqgd6_leygcI; z$lNRX`sJB&(7HSmajm@`sMegJEs2*2YsTj0BOdr(HV2wDuj$0|_gNsd*9oD$ucAc7 za7W>@29&U77#`9zXzp+yE{4P=*#;@_G?-W-WAWRelop&kR%l*6Pr_(#;_=kI7I7 zxZtF-JV|)UT=>ZK+8*X+JV|+yDjxxmaRJ)1x-S!^GnMSBGysvO>Lbu3I4$c*7XL_dUw39ln3Qa8+0^qb$i zF;O$;0)>s)sm<=pyEndXP2=v34fnq7-fMW@yt(P#TlOdymniYZ=y;b`kZaow1PX^a z!^%Ll2h8*-hd7s#62yEVOZSmZl#-m0?F?UF0{&`}ODj(eWvM4)L~TeVnT%Wl3}*WY z2M)ozpT+pe2o=akg0EsuWZAOF0xI-8p*f^*a<&95TcFw<536)P=+6<}Or#t- zgs4t+bI}WTy?qTRhJGG;8L|V24!sn~iH)xnDgn|SLV&u!x&$zkGHXTy8HRH>Bt&wZ zV}8nM`|)$x z@6WIEV0C4?tYV;3W0*_!QRUtU5%s|Pu5HJncjZcEH?3b5yMk?$!AL)2Y*eU?4O+jC zkC|ZXq_S(qK+_-Ej=6Z^%U~bCm;Nc2!=)z#_}$dKZ-w*2>^`I&DVugf0N9&8jg=o_#YFnnmHP;9zDN|MURLyA#?V*giN z_n249PJ@RX3S+h)eDqNYX(xJ^cl#F5*ZSrO!NI2RKP3W zvhCJ~hz;tfv(*(^>=3^+qwl`&%^obwUVHD`6%%@rYw~m$S_!qbs_&yQEY2;=EzDAi zK!FBs4q(F-f}FJEG^6Rzj!g5S;hkAspy%o}sOE-?7;-c|;&>3bi^Tut>HimxczV43 zfR@CkL>bi_P7EISdI!YvtRuV6c1wgJXa*(4BYHhH_1jkb*3?&QE}jtO6Y;rt zDt?!XXEgh?nmzxI*^B#Dd?r4J3x+ibr+~~{?Y+QxPRqx zsgUASf++DwJQ2^uvm^181eg9i*;I{iR4J%$kHlxR^TknaMoD6+*eQ`;ir;>J`PuFQcJ43Y0>p8@Zxd|E$lW5eHK4x%xap@}XyuEbwb^2Lc zxCR?wFWHfs^R%nTHCz{qEOdL#&|QZv@zgZksPzumFA#fGiJ@gz2DmeE0&yAsqQX}i zj*-Ke)ULZ0csLYK4#l%W@%&J{KocGDf*>qzpE$ROlJtYc1(NAvrV?MjPP4OgI8s>Q z3P?9DpyM&ct+BrKLwlqPTNEy-GWJ4e$|*k_8&R3FZpB}#al3HIjQCqcSx3kVv)kY+ 
zf^J5)9Qs>?euT<}B;eXlC0af@5-DUp-rA>IIjU$Y5lZ75%0vdOT=SqM;}&L0D}cF$ zVn{Qrp?rG#s0Xb~wRDyT`=B^0QOR2ok2GWp#UHV|#1PYBNR!9?) str: + """Attempt unpack the zip to a folder and get the model pack path. + + If the folder already exists, no unpacking is done. + + Args: + zip_path (str): The ZIP path + + Returns: + str: The model pack path + """ + base_dir = os.path.dirname(zip_path) + filename = os.path.basename(zip_path) + + foldername = filename.replace(".zip", '') + + model_pack_path = os.path.join(base_dir, foldername) + if os.path.exists(model_pack_path): + logger.info("Found an existing unziped model pack at: {}, the provided zip will not be touched.".format(model_pack_path)) + else: + logger.info("Unziping the model pack and loading models.") + shutil.unpack_archive(zip_path, extract_dir=model_pack_path) + return model_pack_path + @classmethod def load_model_pack(cls, zip_path: str, @@ -324,20 +352,12 @@ def load_model_pack(cls, from medcat.vocab import Vocab from medcat.meta_cat import MetaCAT - base_dir = os.path.dirname(zip_path) - filename = os.path.basename(zip_path) - foldername = filename.replace(".zip", '') - - model_pack_path = os.path.join(base_dir, foldername) - if os.path.exists(model_pack_path): - logger.info("Found an existing unziped model pack at: {}, the provided zip will not be touched.".format(model_pack_path)) - else: - logger.info("Unziping the model pack and loading models.") - shutil.unpack_archive(zip_path, extract_dir=model_pack_path) + model_pack_path = cls.attempt_unpack(zip_path) # Load the CDB cdb_path = os.path.join(model_pack_path, "cdb.dat") - has_jsons = len(glob.glob(os.path.join(model_pack_path, '*.json'))) >= len(SPECIALITY_NAMES) + nr_of_jsons_expected = len(SPECIALITY_NAMES) - len(ONE2MANY) + has_jsons = len(glob.glob(os.path.join(model_pack_path, '*.json'))) >= nr_of_jsons_expected json_path = model_pack_path if has_jsons else None logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format') cdb = 
CDB.load(cdb_path, json_path) @@ -823,6 +843,8 @@ def add_and_train_concept(self, for _cui in cuis: self.linker.context_model.train(cui=_cui, entity=spacy_entity, doc=spacy_doc, negative=True) # type: ignore + @deprecated(message="Use train_supervised_from_json to train based on data " + "loaded from a json file") def train_supervised(self, data_path: str, reset_cui_count: bool = False, @@ -842,9 +864,93 @@ def train_supervised(self, checkpoint: Optional[Checkpoint] = None, retain_filters: bool = False, is_resumed: bool = False) -> Tuple: - """TODO: Refactor, left from old - Run supervised training on a dataset from MedCATtrainer. Please take care that this is more a simulated - online training then supervised. + """Train supervised by reading data from a json file. + + Refer to `train_supervvised_from_json` and/or `train_supervised_raw` + for further details. + """ + return self.train_supervised_from_json(data_path, reset_cui_count, nepochs, + print_stats, use_filters, terminate_last, + use_overlaps, use_cui_doc_limit, test_size, + devalue_others, use_groups, never_terminate, + train_from_false_positives, extra_cui_filter, + retain_extra_cui_filter, checkpoint, + retain_filters, is_resumed) + + def train_supervised_from_json(self, + data_path: str, + reset_cui_count: bool = False, + nepochs: int = 1, + print_stats: int = 0, + use_filters: bool = False, + terminate_last: bool = False, + use_overlaps: bool = False, + use_cui_doc_limit: bool = False, + test_size: int = 0, + devalue_others: bool = False, + use_groups: bool = False, + never_terminate: bool = False, + train_from_false_positives: bool = False, + extra_cui_filter: Optional[Set] = None, + retain_extra_cui_filter: bool = False, + checkpoint: Optional[Checkpoint] = None, + retain_filters: bool = False, + is_resumed: bool = False) -> Tuple: + """ + Run supervised training on a dataset from MedCATtrainer in JSON format. + + Refer to `train_supervised_raw` for more details. 
+ """ + with open(data_path) as f: + data = json.load(f) + return self.train_supervised_raw(data, reset_cui_count, nepochs, + print_stats, use_filters, terminate_last, + use_overlaps, use_cui_doc_limit, test_size, + devalue_others, use_groups, never_terminate, + train_from_false_positives, extra_cui_filter, + retain_extra_cui_filter, checkpoint, + retain_filters, is_resumed) + + def train_supervised_raw(self, + data: Dict[str, List[Dict[str, dict]]], + reset_cui_count: bool = False, + nepochs: int = 1, + print_stats: int = 0, + use_filters: bool = False, + terminate_last: bool = False, + use_overlaps: bool = False, + use_cui_doc_limit: bool = False, + test_size: int = 0, + devalue_others: bool = False, + use_groups: bool = False, + never_terminate: bool = False, + train_from_false_positives: bool = False, + extra_cui_filter: Optional[Set] = None, + retain_extra_cui_filter: bool = False, + checkpoint: Optional[Checkpoint] = None, + retain_filters: bool = False, + is_resumed: bool = False) -> Tuple: + """Train supervised based on the raw data provided. + + The raw data is expected in the following format: + {'projects': + [ # list of projects + { # project 1 + 'name': '', + # list of documents + 'documents': [{'name': '', # document 1 + 'text': '', + # list of annotations + 'annotations': [{'start': -1, # annotation 1 + 'end': 1, + 'cui': 'cui', + 'value': ''}, ...], + }, ...] + }, ... + ] + } + + Please take care that this is more a simulated online training then supervised. When filtering, the filters within the CAT model are used first, then the ones from MedCATtrainer (MCT) export filters, @@ -853,8 +959,8 @@ def train_supervised(self, extra_cui_filter ⊆ MCT filter ⊆ Model/config filter. Args: - data_path (str): - The path to the json file that we get from MedCATtrainer on export. + data (Dict[str, List[Dict[str, dict]]]): + The raw data, e.g from MedCATtrainer on export. reset_cui_count (boolean): Used for training with weight_decay (annealing). 
Each concept has a count that is there from the beginning of the CDB, that count is used for annealing. Resetting the count will @@ -923,8 +1029,7 @@ def train_supervised(self, local_filters = self.config.linking.filters.copy_of() fp = fn = tp = p = r = f1 = examples = {} - with open(data_path) as f: - data = json.load(f) + cui_counts = {} if retain_filters: @@ -1489,6 +1594,43 @@ def _mp_cons(self, in_q: Queue, out_list: List, min_free_memory: int, lock: Lock logger.warning(str(e)) sleep(2) + def _add_nested_ent(self, doc: Doc, _ents: List[Span], _ent: Union[Dict, Span]) -> None: + # if the entities are serialised (PipeRunner.serialize_entities) + # then the entities are dicts + # otherwise they're Span objects + meta_anns = None + if isinstance(_ent, dict): + start = _ent['start'] + end =_ent['end'] + label = _ent['label'] + cui = _ent['cui'] + detected_name = _ent['detected_name'] + context_similarity = _ent['context_similarity'] + id = _ent['id'] + if 'meta_anns' in _ent: + meta_anns = _ent['meta_anns'] + else: + start = _ent.start + end = _ent.end + label = _ent.label + cui = _ent._.cui + detected_name = _ent._.detected_name + context_similarity = _ent._.context_similarity + if _ent._.has('meta_anns'): + meta_anns = _ent._.meta_anns + if HAS_NEW_SPACY: + id = _ent.id + else: + id = _ent.ent_id + entity = Span(doc, start, end, label=label) + entity._.cui = cui + entity._.detected_name = detected_name + entity._.context_similarity = context_similarity + entity._.id = id + if meta_anns is not None: + entity._.meta_anns = meta_anns + _ents.append(entity) + def _doc_to_out(self, doc: Doc, only_cui: bool, @@ -1499,16 +1641,9 @@ def _doc_to_out(self, if doc is not None: out_ent: Dict = {} if self.config.general.show_nested_entities: - _ents = [] + _ents: List[Span] = [] for _ent in doc._.ents: - entity = Span(doc, _ent['start'], _ent['end'], label=_ent['label']) - entity._.cui = _ent['cui'] - entity._.detected_name = _ent['detected_name'] - 
entity._.context_similarity = _ent['context_similarity'] - entity._.id = _ent['id'] - if 'meta_anns' in _ent: - entity._.meta_anns = _ent['meta_anns'] - _ents.append(entity) + self._add_nested_ent(doc, _ents, _ent) else: _ents = doc.ents # type: ignore diff --git a/medcat/cdb.py b/medcat/cdb.py index 60fa1aff6..44d4fd9dd 100644 --- a/medcat/cdb.py +++ b/medcat/cdb.py @@ -95,10 +95,11 @@ def __init__(self, config: Union[Config, None] = None) -> None: self._optim_params = None self.is_dirty = False self._hash: Optional[str] = None + self._memory_optimised_parts: Set[str] = set() def get_name(self, cui: str) -> str: """Returns preferred name if it exists, otherwise it will return - the logest name assigend to the concept. + the longest name assigned to the concept. Args: cui @@ -118,7 +119,7 @@ def update_cui2average_confidence(self, cui: str, new_sim: float) -> None: self.is_dirty = True def remove_names(self, cui: str, names: Dict) -> None: - """Remove names from an existing concept - efect is this name will never again be used to link to this concept. + """Remove names from an existing concept - effect is this name will never again be used to link to this concept. This will only remove the name from the linker (namely name2cuis and name2cuis2status), the name will still be present everywhere else. Why? Because it is bothersome to remove it from everywhere, but could also be useful to keep the removed names in e.g. cui2names. @@ -153,6 +154,43 @@ def remove_names(self, cui: str, names: Dict) -> None: self.name2cuis2status[name][_cui] = 'PD' self.is_dirty = True + def remove_cui(self, cui: str) -> None: + """This function takes a `CUI` as an argument and removes it from all the internal objects that reference it. 
+ Args: + cui + """ + if cui in self.cui2names: + del self.cui2names[cui] + if cui in self.cui2snames: + del self.cui2snames[cui] + if cui in self.cui2context_vectors: + del self.cui2context_vectors[cui] + if cui in self.cui2count_train: + del self.cui2count_train[cui] + if cui in self.cui2tags: + del self.cui2tags[cui] + if cui in self.cui2type_ids: + del self.cui2type_ids[cui] + if cui in self.cui2preferred_name: + del self.cui2preferred_name[cui] + if cui in self.cui2average_confidence: + del self.cui2average_confidence[cui] + for name, cuis in self.name2cuis.items(): + if cui in cuis: + cuis.remove(cui) + for name, cuis2status in self.name2cuis2status.items(): + if cui in cuis2status: + del cuis2status[cui] + if isinstance(self.snames, set): + # if this is a memory optimised CDB, this won't be a set + # but it also won't need to be changed since it + # relies directly on cui2snames + self.snames = set() + for cuis in self.cui2snames.values(): + self.snames |= cuis + self.name2count_train = {name: len(cuis) for name, cuis in self.name2cuis.items()} + self.is_dirty = True + def add_names(self, cui: str, names: Dict, name_status: str = 'A', full_build: bool = False) -> None: """Adds a name to an existing concept. @@ -500,6 +538,27 @@ def reset_training(self) -> None: self.reset_concept_similarity() self.is_dirty = True + def populate_cui2snames(self, force: bool = True) -> None: + """Populate the cui2snames dict if it's empty. + + If the dict is not empty and the population is not force, + nothing will happen. + + For now, this method simply populates all the names form + cui2names into cui2snames. + + Args: + force (bool, optional): Whether to force the (re-)population. Defaults to True. 
+ """ + if not force and self.cui2snames: + return + self.cui2snames.clear() # in case forced re-population + # run through cui2names + # and create new sets so that they can be independently modified + for cui, names in self.cui2names.items(): + self.cui2snames[cui] = set(names) # new set + self.is_dirty = True + def filter_by_cui(self, cuis_to_keep: Union[List[str], Set[str]]) -> None: """Subset the core CDB fields (dictionaries/maps). Note that this will potenitally keep a bit more CUIs then in cuis_to_keep. It will first find all names that link to the cuis_to_keep and then @@ -507,6 +566,10 @@ def filter_by_cui(self, cuis_to_keep: Union[List[str], Set[str]]) -> None: This also will not remove any data from cdb.addl_info - as this field can contain data of unknown structure. + As a side note, if the CDB has been memory-optimised, filtering will undo this memory optimisation. + This is because the dicts being involved will be rewritten. + However, the memory optimisation can be performed again afterwards. + Args: cuis_to_keep (List[str]): CUIs that will be kept, the rest will be removed (not completely, look above). 
@@ -570,6 +633,8 @@ def filter_by_cui(self, cuis_to_keep: Union[List[str], Set[str]]) -> None: self.cui2type_ids = new_cui2type_ids self.cui2preferred_name = new_cui2preferred_name self.is_dirty = True + # reset memory optimisation state + self._memory_optimised_parts.clear() def make_stats(self): stats = {} diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py index 374b55978..d92e6ea61 100644 --- a/medcat/meta_cat.py +++ b/medcat/meta_cat.py @@ -15,6 +15,7 @@ from medcat.pipeline.pipe_runner import PipeRunner from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase from medcat.utils.meta_cat.data_utils import Doc as FakeDoc +from medcat.utils.decorators import deprecated # It should be safe to do this always, as all other multiprocessing # will be finished before data comes to meta_cat @@ -98,6 +99,7 @@ def get_hash(self): hasher.update(self.config.get_hash()) return hasher.hexdigest() + @deprecated(message="Use `train_from_json` or `train_raw` instead") def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict: """Train or continue training a model give a json_path containing a MedCATtrainer export. It will continue training if an existing model is loaded or start new training if the model is blank/new. @@ -109,8 +111,19 @@ def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None In case we have aut_save_model (meaning during the training the best model will be saved) we need to set a save path. Defaults to `None`. """ - g_config = self.config.general - t_config = self.config.train + return self.train_from_json(json_path, save_dir_path) + + def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict: + """Train or continue training a model give a json_path containing a MedCATtrainer export. It will + continue training if an existing model is loaded or start new training if the model is blank/new. 
+ + Args: + json_path (Union[str, list]): + Path/Paths to a MedCATtrainer export containing the meta_annotations we want to train for. + save_dir_path (Optional[str]): + In case we have aut_save_model (meaning during the training the best model will be saved) + we need to set a save path. Defaults to `None`. + """ # Load the medcattrainer export if isinstance(json_path, str): @@ -131,6 +144,39 @@ def merge_data_loaded(base, other): for path in json_path: with open(path, 'r') as f: data_loaded = merge_data_loaded(data_loaded, json.load(f)) + return self.train_raw(data_loaded, save_dir_path) + + def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> Dict: + """Train or continue training a model given raw data. It will + continue training if an existing model is loaded or start new training if the model is blank/new. + + The raw data is expected in the following format: + {'projects': + [ # list of projects + { # project 1 + 'name': '', + # list of documents + 'documents': [{'name': '', # document 1 + 'text': '', + # list of annotations + 'annotations': [{'start': -1, # annotation 1 + 'end': 1, + 'cui': 'cui', + 'value': ''}, ...], + }, ...] + }, ... + ] + } + + Args: + data_loaded (Dict): + The raw data we want to train for. + save_dir_path (Optional[str]): + In case we have aut_save_model (meaning during the training the best model will be saved) + we need to set a save path. Defaults to `None`. 
+ """ + g_config = self.config.general + t_config = self.config.train # Create directories if they don't exist if t_config['auto_save_model']: diff --git a/medcat/utils/helpers.py b/medcat/utils/helpers.py index 04f92730b..f783a9b06 100644 --- a/medcat/utils/helpers.py +++ b/medcat/utils/helpers.py @@ -3,6 +3,8 @@ from medcat.preprocessing.cleaners import clean_name from medcat.utils.other import TPL_ENT, TPL_ENTS +from spacy import __version__ as spacy_version + import logging logger = logging.getLogger(__name__) @@ -517,3 +519,21 @@ def run_cv(cdb_path, data_path, vocab_path, cv=100, nepochs=16, test_size=0.1, l fns[key] = [fn.get(key, 0)] return fps, fns, tps, ps, rs, f1s, cui_counts, examples + + +def has_new_spacy() -> bool: + """Figures out whether or not a newer version of spacy is installed. + + This plays a role in how some parts of the Span needs to be interacted with. + + As of writing, the new version starts at v3.3.1. + + Returns: + bool: Whether new version was detected. + """ + major, minor, patch_plus = spacy_version.split('.') + major, minor = int(major), int(minor) + patch = int(patch_plus) + return (major > 3 or + (major == 3 and minor > 3) or + (major == 3 and minor == 3 and patch >= 1)) diff --git a/medcat/utils/memory_optimiser.py b/medcat/utils/memory_optimiser.py new file mode 100644 index 000000000..e8328734d --- /dev/null +++ b/medcat/utils/memory_optimiser.py @@ -0,0 +1,366 @@ +from typing import Any, Dict, KeysView, Iterator, List, Tuple, Union, Optional, Set + +from medcat.cdb import CDB +from medcat.utils.saving.coding import EncodeableObject, PartEncoder, PartDecoder, UnsuitableObject, register_encoder_decoder + + +CUI_DICT_NAMES_TO_COMBINE = [ + "cui2names", "cui2snames", "cui2context_vectors", + "cui2count_train", "cui2tags", "cui2type_ids", + "cui2preferred_name", "cui2average_confidence", +] +ONE2MANY = 'cui2many' + +NAME_DICT_NAMES_TO_COMBINE = [ + "cui2names", "name2cuis2status", "cui2preferred_name", +] +NAME2MANY = 'name2many' 
+ +DELEGATING_DICT_IDENTIFIER = '==DELEGATING_DICT==' + +DELEGATING_SET_IDENTIFIER = '==DELEGATING_SET==' + +# these will be used in CDB._memory_optimised_parts +CUIS_PART = 'CUIS' +NAMES_PART = 'NAMES' +SNAMES_PART = 'snames' + + +class _KeysView: + def __init__(self, keys: KeysView, parent: 'DelegatingDict'): + self._keys = keys + self._parent = parent + + def __iter__(self) -> Iterator[Any]: + for key in self._keys: + if key in self._parent: + yield key + + def __len__(self) -> int: + return len([_ for _ in self]) + + +class _ItemsView: + def __init__(self, parent: 'DelegatingDict') -> None: + self._parent = parent + + def __iter__(self) -> Iterator[Any]: + for key in self._parent: + yield key, self._parent[key] + + def __len__(self) -> int: + return len(self._parent) + + +class _ValuesView: + def __init__(self, parent: 'DelegatingDict') -> None: + self._parent = parent + + def __iter__(self) -> Iterator[Any]: + for key in self._parent: + yield self._parent[key] + + def __len__(self) -> int: + return len(self._parent) + + +class DelegatingDict: + + def __init__(self, delegate: Dict[str, List[Any]], nr: int, + nr_of_overall_items: int = 8) -> None: + self.delegate = delegate + self.nr = nr + self.nr_of_overall_items = nr_of_overall_items + + def _generate_empty_entry(self) -> List[Any]: + return [None for _ in range(self.nr_of_overall_items)] + + def __getitem__(self, key: str) -> Any: + val = self.delegate[key][self.nr] + if val is None: + raise KeyError + return val + + def get(self, key: str, default: Any) -> Any: + try: + return self[key] + except KeyError: + return default + + def __setitem__(self, key: str, value: Any) -> None: + if key not in self.delegate: + self.delegate[key] = self._generate_empty_entry() + self.delegate[key][self.nr] = value + + def __contains__(self, key: str) -> bool: + return key in self.delegate and self.delegate[key][self.nr] is not None + + def keys(self) -> _KeysView: + return _KeysView(self.delegate.keys(), self) + + def 
items(self) -> _ItemsView: + return _ItemsView(self) + + def values(self) -> _ValuesView: + return _ValuesView(self) + + def __iter__(self) -> Iterator[str]: + yield from self.keys() + + def __len__(self) -> int: + return len(self.keys()) + + def to_dict(self) -> dict: + return {'delegate': None, + 'nr': self.nr, + 'nr_of_overall_items': self.nr_of_overall_items} + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, DelegatingDict): + return False + return self.delegate == __value.delegate and self.nr == __value.nr + + def __hash__(self) -> int: + return hash((self.delegate, self.nr)) + + def __delitem__(self, key: str) -> None: + self[key] = None + + def pop(self, key: str, default: Optional[Any] = None) -> Any: + if key in self: + item = self[key] + else: + item = default + del self[key] + return item + + +class DelegatingValueSet: + + def __init__(self, delegate: Dict[str, Set[str]]) -> None: + self.delegate = delegate + + def update(self, other: Any) -> None: + # do nothing since the value will be updated in delegate + pass + + def __contains__(self, value: str) -> bool: + for cui_value in self.delegate.values(): + if value in cui_value: + return True + return False + + def to_dict(self) -> dict: + return {'delegate': None} + + +class DelegatingDictEncoder(PartEncoder): + + def try_encode(self, obj): + if isinstance(obj, DelegatingDict): + return {DELEGATING_DICT_IDENTIFIER: obj.to_dict()} + raise UnsuitableObject() + + +class DelegatingDictDecoder(PartDecoder): + + def try_decode(self, dct: dict) -> Union[dict, EncodeableObject]: + if DELEGATING_DICT_IDENTIFIER in dct: + info = dct[DELEGATING_DICT_IDENTIFIER] + delegate = info['delegate'] + nr = info['nr'] + overall = info['nr_of_overall_items'] + return DelegatingDict(delegate, nr, overall) + return dct + + +class DelegatingValueSetEncoder(PartEncoder): + + def try_encode(self, obj): + if isinstance(obj, DelegatingValueSet): + return {DELEGATING_SET_IDENTIFIER: obj.to_dict()} + raise 
UnsuitableObject() + + +class DelegatingValueSetDecoder(PartDecoder): + + def try_decode(self, dct: dict) -> Union[dict, EncodeableObject]: + if DELEGATING_SET_IDENTIFIER in dct: + info = dct[DELEGATING_SET_IDENTIFIER] + delegate = info['delegate'] + return DelegatingValueSet(delegate) + return dct + + +def attempt_fix_after_load(cdb: CDB): + _attempt_fix_after_load(cdb, ONE2MANY, CUI_DICT_NAMES_TO_COMBINE) + _attempt_fix_after_load(cdb, NAME2MANY, NAME_DICT_NAMES_TO_COMBINE) + + +def attempt_fix_snames_after_load(cdb: CDB, snames_attr_name: str = 'snames'): + snames = getattr(cdb, snames_attr_name) + if isinstance(snames, DelegatingValueSet) and snames.delegate is None: + snames = DelegatingValueSet(cdb.cui2snames) + setattr(cdb, snames_attr_name, snames) + + +# register encoder and decoders +register_encoder_decoder(encoder=DelegatingDictEncoder, + decoder=DelegatingDictDecoder, + loading_postprocessor=attempt_fix_after_load) +register_encoder_decoder(encoder=DelegatingValueSetEncoder, + decoder=DelegatingValueSetDecoder, + loading_postprocessor=attempt_fix_snames_after_load) + + +def _optimise(cdb: CDB, to_many_name: str, dict_names_to_combine: List[str]) -> None: + dicts = [getattr(cdb, dict_name) + for dict_name in dict_names_to_combine] + one2many, delegators = map_to_many(dicts) + for delegator, name in zip(delegators, dict_names_to_combine): + setattr(cdb, name, delegator) + setattr(cdb, to_many_name, one2many) + cdb.is_dirty = True + + +def _optimise_snames(cdb: CDB, cui2snames: str = 'cui2snames', + snames_attr: str = 'snames') -> None: + """Optimise the snames part of a CDB. + + Args: + cdb (CDB): The CDB to optimise snames on. + one2many_name (str): The cui2snames dict name to delegate to. Defaults to 'cui2snames'. + snames_attr (str, optional): The `snames` attribute name. Defaults to 'snames'. 
+ """ + delegate = getattr(cdb, cui2snames) + dvs = DelegatingValueSet(delegate) + setattr(cdb, snames_attr, dvs) + cdb.is_dirty = True + + +def perform_optimisation(cdb: CDB, optimise_cuis: bool = True, + optimise_names: bool = False, + optimise_snames: bool = True) -> None: + """Attempts to optimise the memory footprint of the CDB. + + This can perform optimisation for cui2<...> and name2<...> dicts. + However, by default, only cui2many optimisation will be done. + This is because at the time of writing, there were not enough name2<...> + dicts to be able to benefit from the optimisation. + + Does so by unifying the following dicts: + + cui2names (Dict[str, Set[str]]): + From cui to all names assigned to it. Mainly used for subsetting (maybe even only). + cui2snames (Dict[str, Set[str]]): + From cui to all sub-names assigned to it. Only used for subsetting. + cui2context_vectors (Dict[str, Dict[str, np.array]]): + From cui to a dictionary of different kinds of context vectors. Normally you would have here + a short and a long context vector - they are calculated separately. + cui2count_train (Dict[str, int]): + From CUI to the number of training examples seen. + cui2tags (Dict[str, List[str]]): + From CUI to a list of tags. This can be used to tag concepts for grouping of whatever. + cui2type_ids (Dict[str, Set[str]]): + From CUI to type id (e.g. TUI in UMLS). + cui2preferred_name (Dict[str, str]): + From CUI to the preferred name for this concept. + cui2average_confidence (Dict[str, str]): + Used for dynamic thresholding. Holds the average confidence for this CUI given the training examples. + + name2cuis (Dict[str, List[str]]): + Map fro concept name to CUIs - one name can map to multiple CUIs. + name2cuis2status (Dict[str, Dict[str, str]]): + What is the status for a given name and cui pair - each name can be: + P - Preferred, A - Automatic (e.g. let medcat decide), N - Not common. 
+ name2count_train (Dict[str, str]): + Counts how often did a name appear during training. + + It can also delegate the `snames` set to use the various sets in `cui2snames` instead. + + They will all be included in 1 dict with CUI keys and a list of values for each pre-existing dict. + + Args: + cdb (CDB): The CDB to modify. + optimise_cuis (bool, optional): Whether to optimise cui2<...> dicts. Defaults to True. + optimise_names (bool, optional): Whether to optimise name2<...> dicts. Defaults to False. + optimise_snames (bool, optional): Whether to optimise `snames` set. Defaults to True. + """ + # cui2<...> -> cui2many + if optimise_cuis: + _optimise(cdb, ONE2MANY, CUI_DICT_NAMES_TO_COMBINE) + cdb._memory_optimised_parts.add(CUIS_PART) + # name2<...> -> name2many + if optimise_names: + _optimise(cdb, NAME2MANY, NAME_DICT_NAMES_TO_COMBINE) + cdb._memory_optimised_parts.add(NAMES_PART) + if optimise_snames: + # check snames based on cui2sanmes + _optimise_snames(cdb) + cdb._memory_optimised_parts.add(SNAMES_PART) + + +def _attempt_fix_after_load(cdb: CDB, one2many_name: str, dict_names: List[str]): + if not hasattr(cdb, one2many_name): + return + one2many = getattr(cdb, one2many_name) + for dict_name in dict_names: + d = getattr(cdb, dict_name) + if not isinstance(d, DelegatingDict): + raise ValueError(f'Unknown type for {dict_name}: {type(d)}') + d.delegate = one2many + + +def _unoptimise(cdb: CDB, to_many_name: str, dict_names_to_combine: List[str]): + # remove one2many attribute + # the references still exist on each delegator + delattr(cdb, to_many_name) + + delegating_dicts: List[Dict[str, Any]] = [getattr(cdb, dict_name) + for dict_name in dict_names_to_combine] + for del_dict, dict_name in zip(delegating_dicts, dict_names_to_combine): + raw_dict = dict(del_dict.items()) + setattr(cdb, dict_name, raw_dict) + cdb.is_dirty = True + + +def _unoptimise_snames(cdb: CDB, cui2snames: str = 'cui2snames', + snames_attr: str = 'snames') -> None: + # rebuild snames + 
delegate: Dict[str, Set[str]] = getattr(cdb, cui2snames) + snames = set() + for values in delegate.values(): + snames.update(values) + setattr(cdb, snames_attr, snames) + cdb.is_dirty = True + + +def unoptimise_cdb(cdb: CDB): + """This undoes all the (potential) memory optimisations done in `perform_optimisation`. + + This method relies on `CDB._memory_optimised_parts` to be up to date. + + Args: + cdb (CDB): The CDB to work on. + """ + if CUIS_PART in cdb._memory_optimised_parts: + _unoptimise(cdb, ONE2MANY, CUI_DICT_NAMES_TO_COMBINE) + if NAMES_PART in cdb._memory_optimised_parts: + _unoptimise(cdb, NAME2MANY, NAME_DICT_NAMES_TO_COMBINE) + if SNAMES_PART in cdb._memory_optimised_parts: + _unoptimise_snames(cdb) + cdb._memory_optimised_parts.clear() + + +def map_to_many(dicts: List[Dict[str, Any]]) -> Tuple[Dict[str, List[Any]], List[DelegatingDict]]: + one2many: Dict[str, List[Any]] = {} + delegators: List[DelegatingDict] = [] + for nr, d in enumerate(dicts): + delegator = DelegatingDict( + one2many, nr, nr_of_overall_items=len(dicts)) + for key, value in d.items(): + if key not in one2many: + one2many[key] = delegator._generate_empty_entry() + one2many[key][nr] = value + delegators.append(delegator) + return one2many, delegators diff --git a/medcat/utils/ner/deid.py b/medcat/utils/ner/deid.py new file mode 100644 index 000000000..7c5d0231c --- /dev/null +++ b/medcat/utils/ner/deid.py @@ -0,0 +1,114 @@ +"""De-identification model. + +This describes a wrapper on the regular CAT model. +The idea is to simplify the use of a DeId-specific model. 
+ +It tackles two use cases +1) Creation of a deid model +2) Loading and use of a deid model + +I.e for use case 1: + +Instead of: +cat = CAT(cdb=ner.cdb, addl_ner=ner) + +You can use: +deid = DeIdModel.create(ner) + + +And for use case 2: + +Instead of: +cat = CAT.load_model_pack(model_pack_path) +anon_text = deid_text(cat, text) + +You can use: +deid = DeIdModel.load_model_pack(model_pack_path) +anon_text = deid.deid_text(text) + +Or if/when structured output is desired: +deid = DeIdModel.load_model_pack(model_pack_path) +anon_doc = deid(text) # the spacy document + +The wrapper also exposes some CAT parts directly: +- config +- cdb +""" +from typing import Union, Tuple, Any + +from medcat.cat import CAT +from medcat.utils.ner.model import NerModel + +from medcat.utils.ner.helpers import _deid_text as deid_text + + +class DeIdModel(NerModel): + """The DeID model. + + This wraps a CAT instance and simplifies its use as a + de-identification model. + + It provies methods for creating one from a TransformersNER + as well as loading from a model pack (along with some validation). + + It also exposes some useful parts of the CAT it wraps such as + the config and the concept database. + """ + + def __init__(self, cat: CAT) -> None: + self.cat = cat + + def train(self, json_path: Union[str, list, None], + *args, **kwargs) -> Tuple[Any, Any, Any]: + return super().train(json_path, *args, train_nr=0, **kwargs) # type: ignore + + def deid_text(self, text: str, redact: bool = False) -> str: + """Deidentify text and potentially redact information. + + Args: + text (str): The text to deidentify. + redact (bool): Whether to redact the information. + + Returns: + str: The deidentified text. + """ + return deid_text(self.cat, text, redact=redact) + + @classmethod + def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel': + """Load DeId model from model pack. + + The method first loads the CAT instance. 
+ + It then makes sure that the model pack corresponds to a + valid DeId model. + + Args: + model_pack_path (str): The model pack path. + + Raises: + ValueError: If the model pack does not correspond to a DeId model. + + Returns: + DeIdModel: The resulting DeI model. + """ + ner_model = NerModel.load_model_pack(model_pack_path) + cat = ner_model.cat + if not cls._is_deid_model(cat): + raise ValueError( + f"The model saved at {model_pack_path} is not a deid model " + f"({cls._get_reason_not_deid(cat)})") + model = cls(ner_model.cat) + return model + + @classmethod + def _is_deid_model(cls, cat: CAT) -> bool: + return not bool(cls._get_reason_not_deid(cat)) + + @classmethod + def _get_reason_not_deid(cls, cat: CAT) -> str: + if cat.vocab is not None: + return "Has vocab" + if len(cat._addl_ner) != 1: + return f"Incorrect number of addl_ner: {len(cat._addl_ner)}" + return "" diff --git a/medcat/utils/ner/helpers.py b/medcat/utils/ner/helpers.py index 65b8660e9..7dcada3dd 100644 --- a/medcat/utils/ner/helpers.py +++ b/medcat/utils/ner/helpers.py @@ -1,16 +1,48 @@ from medcat.utils.data_utils import count_annotations from medcat.cdb import CDB +from medcat.utils.decorators import deprecated -def deid_text(cat, text, redact=False): + +# For now, we will keep this method separate from the above class +# This is so that we wouldn't need to create a thorwaway object +# when calling the method from .helpers where it used to be. +# After the deprecated method in .helpers is removed, we can +# move this to a proper class method. +def _deid_text(cat, text: str, redact: bool = False) -> str: + """De-identify text. + + De-identified text. + If redaction is enabled, identifiable entities will be + replaced with starts (e.g `*****`). + Otherwise, the replacement will be the CUI or in other words, + the type of information that was hidden (e.g [PATIENT]). + + + Args: + cat (CAT): The CAT object to use for deid. + text (str): The input document. 
+ redact (bool, optional): Whether to redact. Defaults to False. + + Returns: + str: The de-identified document. + """ new_text = str(text) entities = cat.get_entities(text)['entities'] for ent in sorted(entities.values(), key=lambda ent: ent['start'], reverse=True): - r = "*"*(ent['end']-ent['start']) if redact else cat.cdb.get_name(ent['cui']) + r = "*"*(ent['end']-ent['start'] + ) if redact else cat.cdb.get_name(ent['cui']) new_text = new_text[:ent['start']] + f'[{r}]' + new_text[ent['end']:] return new_text +@deprecated("API now allows creating a DeId model (medcat.utils.ner.deid.DeIdModel). " + "It aims to simplify the usage of DeId models. " + "The use of this model is encouraged over the use of this method.") +def deid_text(*args, **kwargs) -> str: + return _deid_text(*args, **kwargs) + + def make_or_update_cdb(json_path, cdb=None, min_count=0): """Creates a new CDB or updates an existing one with new concepts if the cdb argument is provided. All concepts that are less frequent diff --git a/medcat/utils/ner/model.py b/medcat/utils/ner/model.py new file mode 100644 index 000000000..553fb4c65 --- /dev/null +++ b/medcat/utils/ner/model.py @@ -0,0 +1,109 @@ +from typing import Any, List, Tuple, Union, Optional + +from spacy.tokens import Doc + +from medcat.ner.transformers_ner import TransformersNER +from medcat.cat import CAT +from medcat.cdb import CDB +from medcat.config import Config + + +class NerModel: + + """The NER model. + + This wraps a CAT instance and simplifies its use as a + NER model. + + It provies methods for creating one from a TransformersNER + as well as loading from a model pack (along with some validation). + + It also exposes some useful parts of the CAT it wraps such as + the config and the concept database. + """ + + def __init__(self, cat: CAT) -> None: + self.cat = cat + + def train(self, json_path: Union[str, list, None], train_nr: int = 0, + *args, **kwargs) -> Tuple[Any, Any, Any]: + """Train the underlying transformers NER model. 
+ + All the extra arguments are passed to the TransformersNER train method. + + Args: + json_path (Union[str, list, None]): The JSON file path to read the training data from. + train_nr (int, optional): The number of the NER object in cat._addl_train to train. Defaults to 0. + + Returns: + Tuple[Any, Any, Any]: df, examples, dataset + """ + return self.cat._addl_ner[train_nr].train(json_path, *args, **kwargs) + + def __call__(self, text: Optional[str], *args, **kwargs) -> Optional[Doc]: + """Get the annotated document for text. + + Undefined arguments and keyword arguments get passed on to + the equivalent `CAT` method. + + Args: + text (Optional[str]): The input text. + + Returns: + Optional[Doc]: The annotated document. + """ + return self.cat(text, *args, **kwargs) + + def get_entities(self, text: str, *args, **kwargs) -> dict: + """Gets the entities recognized within a given text. + + The output format is identical to `CAT.get_entities`. + + Undefined arguments and keyword arguments get passed on to + CAT.get_entities. + + Args: + text (str): The input text. + + Returns: + dict: The output entities. + """ + return self.cat.get_entities(text, *args, **kwargs) + + @property + def config(self) -> Config: + return self.cat.config + + @property + def cdb(self) -> CDB: + return self.cat.cdb + + @classmethod + def create(cls, ner: Union[TransformersNER, List[TransformersNER]]) -> 'NerModel': + """Create a NER model with a TransformersNER + + Args: + ner (Union[TransformersNER, List[TransformersNER]]): The TransformersNER instance(s). + + Returns: + NerModel: The resulting model + """ + # expecting all to have the same CDB + cdb = ner.cdb if isinstance(ner, TransformersNER) else ner[0].cdb + cat = CAT(cdb=cdb, addl_ner=ner) + return cls(cat) + + @classmethod + def load_model_pack(cls, model_pack_path: str) -> 'NerModel': + """Load NER model from model pack. + + The method first wraps the loaded CAT instance. + + Args: + model_pack_path (str): The model pack path. 
+ + Returns: + NerModel: The resulting DeI model. + """ + cat = CAT.load_model_pack(model_pack_path) + return cls(cat) diff --git a/medcat/utils/preprocess_umls.py b/medcat/utils/preprocess_umls.py index 0b3505981..7c47f451a 100644 --- a/medcat/utils/preprocess_umls.py +++ b/medcat/utils/preprocess_umls.py @@ -1,6 +1,9 @@ from typing import List, Union import pandas as pd +import tqdm +import os +from typing import Dict _DEFAULT_COLUMNS: list = [ "CUI", @@ -20,7 +23,7 @@ "STR", "SRL", "SUPPRESS", - "CVF", + "CVF", ] _DEFAULT_SEM_TYPE_COLUMNS: list = [ @@ -32,12 +35,24 @@ "CVF", ] +_DEFAULT_MRHIER_COLUMNS: list = [ + "CUI", + "AUI", + "CXN", + "PAUI", + "SAB", + "RELA", + "PTR", + "HCD", + "CVF", +] + medcat_csv_mapper: dict = { 'CUI': 'cui', 'STR': 'name', 'SAB': 'ontologies', 'ISPREF': 'name_status', - 'TUI': 'type_ids', # from MRSTY.RRF + 'TUI': 'type_ids', # from MRSTY.RRF } @@ -57,11 +72,13 @@ class UMLS: def __init__(self, main_file_name: str, sem_types_file: str, allow_languages: list = ['ENG'], sep: str = '|'): self.main_file_name = main_file_name self.sem_types_file = sem_types_file - self.main_columns = list(_DEFAULT_COLUMNS) # copy - self.sem_types_columns = list(_DEFAULT_SEM_TYPE_COLUMNS) # copy + self.main_columns = list(_DEFAULT_COLUMNS) # copy + self.sem_types_columns = list(_DEFAULT_SEM_TYPE_COLUMNS) # copy + self.mrhier_columns = list(_DEFAULT_MRHIER_COLUMNS) # copy self.sep = sep # copy in case of default list - self.allow_langugages = list(allow_languages) if allow_languages else allow_languages + self.allow_langugages = list( + allow_languages) if allow_languages else allow_languages def to_concept_df(self) -> pd.DataFrame: """Create a concept DataFrame. 
@@ -72,7 +89,8 @@ def to_concept_df(self) -> pd.DataFrame: """ # target columns: # cui, name, name_status, ontologies, description_type_ids, type_ids - df = pd.read_csv(self.main_file_name, names=self.main_columns, sep=self.sep, index_col=False) + df = pd.read_csv(self.main_file_name, + names=self.main_columns, sep=self.sep, index_col=False) # filter languages if self.allow_langugages: @@ -82,7 +100,8 @@ def to_concept_df(self) -> pd.DataFrame: # get TUI - sem_types = pd.read_csv(self.sem_types_file, names=self.sem_types_columns, sep=self.sep, index_col=False) + sem_types = pd.read_csv( + self.sem_types_file, names=self.sem_types_columns, sep=self.sep, index_col=False) df = df.merge(sem_types) # rename columns @@ -109,7 +128,8 @@ def map_umls2snomed(self) -> pd.DataFrame: Returns: pd.DataFrame: Dataframe that contains the SCUI (source CUI) as well as the UMLS CUI for each applicable concept """ - df = pd.read_csv(self.main_file_name, names=self.main_columns, sep=self.sep, index_col=False, dtype={'SCUI': 'str'}) + df = pd.read_csv(self.main_file_name, names=self.main_columns, + sep=self.sep, index_col=False, dtype={'SCUI': 'str'}) # get only SNOMED-CT US based concepts that have a SNOMED-CT (source) CUI df = df[df.SAB == 'SNOMEDCT_US'][df.SCUI.notna()] # sort by SCUI @@ -154,7 +174,8 @@ def map_umls2source(self, sources: Union[str, List[str]]) -> pd.DataFrame: Returns: pd.DataFrame: DataFrame that has the target source codes """ - df = pd.read_csv(self.main_file_name, names=self.main_columns, sep=self.sep, index_col=False, dtype={'CODE': 'str'}) + df = pd.read_csv(self.main_file_name, names=self.main_columns, + sep=self.sep, index_col=False, dtype={'CODE': 'str'}) # get the specified source(s) if isinstance(sources, list): df = df[df.SAB.isin(sources)][df.CODE.notna()] @@ -166,6 +187,76 @@ def map_umls2source(self, sources: Union[str, List[str]]) -> pd.DataFrame: df = df[['CODE',] + [col for col in df.columns.values if col != 'CODE']] return df + def get_pt2ch(self) 
-> dict: + """Generates a parent to children dict. + + It goes through all the < # TODO + + The resulting dictionary maps a CUI to a list of CUIs that + consider that CUI as their parent. + + PS: + This expects the MRHIER.RRF file to also exist in the same folder + as the MRCONSO.RRF file. + + Raises: + ValueError: If the MRHIER.RRF file wasn't found + + Returns: + dict: The dictionary of parent CUI and their children. + """ + path = self.main_file_name.rsplit('/', 1)[0] + hier_file = f"{path}/MRHIER.RRF" + + if not os.path.exists(hier_file): + raise ValueError( + f'Expected MRHIER.RRF to exist within the same parent folder ({path})') + + conso_df = pd.read_csv(self.main_file_name, names=self.main_columns, + sep=self.sep, index_col=False) + + hier_df = pd.read_csv(hier_file, sep=self.sep, index_col=False, + header=None, names=self.mrhier_columns) + + # filter languages + if self.allow_langugages: + conso_df = conso_df[conso_df["LAT"].isin(self.allow_langugages)] + + # create a AUI -> CUI map + aui_cui = dict(zip(conso_df["AUI"], conso_df["CUI"])) + + # remove non-preferred from conso + conso_df = conso_df[conso_df['ISPREF'] == 'Y'] + + # filter ISA relationships + hier_df = hier_df[hier_df['RELA'] == 'isa'] + + # merge dataframes + merged_df = pd.merge(conso_df, hier_df, on=['AUI', 'CUI']) + + # only keep CUI and parent AUI + cui_parent = merged_df[['CUI', 'PAUI']] + # only include CUIs with a parent + cui_parent = cui_parent[cui_parent['PAUI'].notna()] + + # create dict + pt2ch: dict = {} + for _, row in tqdm.tqdm(cui_parent.iterrows(), total=len(cui_parent.index)): + cur_cui = row['CUI'] + paui = row['PAUI'] + parent_cui = aui_cui[paui] + # avoid self as parent/child + if parent_cui == cur_cui: + continue + if parent_cui not in pt2ch: + pt2ch[parent_cui] = set() + pt2ch[parent_cui].add(cur_cui) + # move from set to list for consistency with SNOMED + pt2ch: Dict[str, List[str]] = pt2ch # type: ignore + for k, v in pt2ch.items(): + pt2ch[k] = list(v) + return pt2ch 
+ if __name__ == '__main__': import sys @@ -187,3 +278,20 @@ def map_umls2source(self, sources: Union[str, List[str]]) -> pd.DataFrame: to_ICD10_man = umls.map_umls2source(sources=['ICD10']) print('As ICD-10(MAN):') print(to_ICD10_man.head()) + pt2ch = umls.get_pt2ch() + print('Get parent-child dict', len(pt2ch), + '' if len(pt2ch) > 1_000 else pt2ch) + all_vals = [len(v) for v in pt2ch.values()] + print('LEN of VALS:', sum(all_vals), 'max', + max(all_vals), 'min', min(all_vals), 'mean', sum(all_vals) / len(all_vals)) + import random + random_4_keys = random.sample(list(pt2ch.keys()), k=4) + + def _get_name(cui: str) -> str: + matches = df[df['cui'] == cui] + if len(matches.index) == 0: + return 'N/A' # UNKNOWN + return matches['name'].iloc[0] + print('FF RAW ', [f"{k}:{pt2ch[k]}" for k in random_4_keys]) + print('FIRST FEW', [ + (f"{_get_name(key)} ({key})", [f"{_get_name(child)} ({child})" for child in pt2ch[key]]) for key in random_4_keys]) diff --git a/medcat/utils/saving/coding.py b/medcat/utils/saving/coding.py new file mode 100644 index 000000000..c03e6816f --- /dev/null +++ b/medcat/utils/saving/coding.py @@ -0,0 +1,146 @@ +from typing import Any, Protocol, runtime_checkable, List, Union, Type, Optional, Callable + +import json + + +@runtime_checkable +class EncodeableObject(Protocol): + + def to_dict(self) -> dict: + """Converts the object to a dict. + + Returns: + dict: The dict to be serialised. + """ + + +class UnsuitableObject(ValueError): + pass + + +class PartEncoder(Protocol): + + def try_encode(self, obj: object) -> Any: + """Try to encode an object + + Args: + obj (object): The object to encode + + Raises: + UnsuitableObject: If the object is unsuitable for encoding. + + Returns: + Any: The encoded object + """ + + +SET_IDENTIFIER = '==SET==' + + +class SetEncoder(PartEncoder): + """JSONEncoder (and decoder) for sets. + + Generally, JSON doesn't support serializing of sets natively. 
+ This encoder adds a set identifier to the data when being serialized + and provides a method to read said identifier upon decoding.""" + + def try_encode(self, obj): + if isinstance(obj, set): + return {SET_IDENTIFIER: list(obj)} + raise UnsuitableObject() + + +class PartDecoder(Protocol): + + def try_decode(self, dct: dict) -> Union[dict, Any]: + """Try to decode the dictionary. + + Args: + dct (dict): The dict to decode. + + Returns: + Union[dict, Any]: The dict if unable to decode, the decoded object otherwise + """ + + +class SetDecoder(PartDecoder): + + def try_decode(self, dct: dict) -> Union[dict, set]: + """Decode sets from input dicts. + + Args: + dct (dict): The input dict + + Returns: + Union[dict, set]: The original dict if this was not a serialized set, the set otherwise + """ + if SET_IDENTIFIER in dct: + return set(dct[SET_IDENTIFIER]) + return dct + + +PostProcessor = Callable[[Any], None] # CDB -> None + +DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, ] +DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, ] +LOADING_POSTPROCESSORS: List[PostProcessor] = [] + + +def register_encoder_decoder(encoder: Optional[Type[PartEncoder]], + decoder: Optional[Type[PartDecoder]], + loading_postprocessor: Optional[PostProcessor]): + if encoder: + DEFAULT_ENCODERS.append(encoder) + if decoder: + DEFAULT_DECODERS.append(decoder) + if loading_postprocessor: + LOADING_POSTPROCESSORS.append(loading_postprocessor) + + +class CustomDelegatingEncoder(json.JSONEncoder): + + def __init__(self, delegates: List[PartEncoder], *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._delegates = delegates + + def default(self, obj): + for delegator in self._delegates: + try: + return delegator.try_encode(obj) + except UnsuitableObject: + pass + return json.JSONEncoder.default(self, obj) + + @classmethod + def def_inst(cls, *args, **kwargs) -> 'CustomDelegatingEncoder': + return cls([_cls() for _cls in DEFAULT_ENCODERS], *args, **kwargs) + + +class 
CustomDelegatingDecoder(json.JSONDecoder): + _def_inst: Optional['CustomDelegatingDecoder'] = None + + def __init__(self, delegates: List[PartDecoder]) -> None: + self._delegates = delegates + + def object_hook(self, dct: dict) -> Any: + for delegator in self._delegates: + ret_val = delegator.try_decode(dct) + if ret_val is not dct: + return ret_val + return dct + + @classmethod + def def_inst(cls) -> 'CustomDelegatingDecoder': + if cls._def_inst is None: + cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS]) + return cls._def_inst + + +def default_hook(dct: dict) -> Any: + cdd = CustomDelegatingDecoder.def_inst() + return cdd.object_hook(dct) + + +def default_postprocessing(cdb) -> None: + for pp in LOADING_POSTPROCESSORS: + pp(cdb) diff --git a/medcat/utils/saving/serializer.py b/medcat/utils/saving/serializer.py index c08124831..d82df751c 100644 --- a/medcat/utils/saving/serializer.py +++ b/medcat/utils/saving/serializer.py @@ -5,11 +5,13 @@ """ import os import logging -from typing import cast, Dict, Optional, Union +from typing import cast, Dict, Optional, Type import dill import json from medcat.config import Config +from medcat.utils.saving.coding import CustomDelegatingEncoder, default_hook, default_postprocessing + logger = logging.getLogger(__name__) @@ -17,35 +19,8 @@ __SPECIALITY_NAMES_NAME = set( ["name2cuis", "name2cuis2status", "name_isupper"]) __SPECIALITY_NAMES_OTHER = set(["snames", "addl_info"]) -SPECIALITY_NAMES = __SPECIALITY_NAMES_CUI | __SPECIALITY_NAMES_NAME | __SPECIALITY_NAMES_OTHER - - -class SetEncode(json.JSONEncoder): - """JSONEncoder (and decoder) for sets. - - Generally, JSON doesn't support serializing of sets natively. 
- This encoder adds a set identifier to the data when being serialized - and provides a method to read said identifier upon decoding.""" - SET_IDENTIFIER = '==SET==' - - def default(self, obj): - if isinstance(obj, set): - return {SetEncode.SET_IDENTIFIER: list(obj)} - return json.JSONEncoder.default(self, obj) - - @staticmethod - def set_decode(dct: dict) -> Union[dict, set]: - """Decode sets from input dicts. - - Args: - dct (dict): The input dict - - Returns: - Union[dict, set]: The original dict if this was not a serialized set, the set otherwise - """ - if SetEncode.SET_IDENTIFIER in dct: - return set(dct[SetEncode.SET_IDENTIFIER]) - return dct +ONE2MANY = set(['cui2many', 'name2many']) # these may or may not exist +SPECIALITY_NAMES = __SPECIALITY_NAMES_CUI | __SPECIALITY_NAMES_NAME | __SPECIALITY_NAMES_OTHER | ONE2MANY class JsonSetSerializer: @@ -75,7 +50,11 @@ def write(self, d: dict) -> None: logger.info('Writing data for "%s" into "%s"', self.name, self.file_name) with open(self.file_name, 'w') as f: - json.dump(d, f, cls=SetEncode) + # the def_inst method, when called, + # returns the right type of object anyway + + json.dump(d, f, cls=cast(Type[json.JSONEncoder], + CustomDelegatingEncoder.def_inst)) def read(self) -> dict: """Read the json file specified by this serializer. 
@@ -85,7 +64,8 @@ def read(self) -> dict: """ logger.info('Reading data for %s from %s', self.name, self.file_name) with open(self.file_name, 'r') as f: - data = json.load(f, object_hook=SetEncode.set_decode) + data = json.load( + f, object_hook=default_hook) return data @@ -168,6 +148,8 @@ def serialize(self, cdb, overwrite: bool = False) -> None: dill.dump(to_save, f) if self.jsons is not None: for name in SPECIALITY_NAMES: + if name not in cdb.__dict__: + continue # in case cui2many doesn't exit self.jsons[name].write(cdb.__dict__[name]) def deserialize(self, cdb_cls): @@ -199,5 +181,10 @@ def deserialize(self, cdb_cls): # if applicable if self.jsons is not None: for name in SPECIALITY_NAMES: + if not os.path.exists(self.jsons[name].file_name): + continue # in case of non-memory-optimised where cui2many doesn't exist cdb.__dict__[name] = self.jsons[name].read() + # if anything has + # been registered to postprocess the CDBs + default_postprocessing(cdb) return cdb diff --git a/medcat/utils/versioning.py b/medcat/utils/versioning.py new file mode 100644 index 000000000..539af0339 --- /dev/null +++ b/medcat/utils/versioning.py @@ -0,0 +1,314 @@ +from typing import Tuple, List +import re +import os +import shutil +import argparse +import logging + +import dill + +from medcat.cat import CAT + +logger = logging.getLogger(__name__) + +SemanticVersion = Tuple[int, int, int] + + +# Regex as per: +# https://semver.org/#is-there-a-suggested-regular-expression-regex-to-check-a-semver-string +SEMANTIC_VERSION_REGEX = (r"^(0|[1-9]\d*)" # major + r"\.(0|[1-9]\d*)" # .minor + # CHANGE FROM NORM - allowing dev before patch version number + # but NOT capturing the group + r"\.(?:dev)?" + r"(0|[1-9]\d*)" # .patch + # and then some trailing stuff + r"(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?" 
+ r"(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$") +SEMANTIC_VERSION_PATTERN = re.compile(SEMANTIC_VERSION_REGEX) + + +CDB_FILE_NAME = "cdb.dat" + + +def get_semantic_version(version: str) -> SemanticVersion: + """Get the semantiv version from the string. + + Args: + version (str): The version string. + + Raises: + ValueError: If the version string does not match the semantic versioning format. + + Returns: + SemanticVersion | Tuple[int, int, int]: The major, minor and patch version + """ + match = SEMANTIC_VERSION_PATTERN.match(version) + if not match: + raise ValueError(f"Unknown version string: {version}") + return int(match.group(1)), int(match.group(2)), int(match.group(3)) + + +def get_version_from_modelcard(d: dict) -> SemanticVersion: + """Gets the the major.minor.patch version from a model card. + + The version needs to be specified at: + model_card["MedCAT Version"] + The version is expected to be semantic (major.minor.patch). + + Args: + d (dict): The model card in dict format. + + Returns: + SemanticVersion | Tuple[int, int, int]: The major, minor and patch version + """ + version_str: str = d["MedCAT Version"] + return get_semantic_version(version_str) + + +def get_semantic_version_from_model(cat: CAT) -> SemanticVersion: + """Get the semantic version of a CAT model. + + This uses the `get_version_from_modelcard` method on the model's + model card. + + So it is equivalen to `get_version_from_modelcard(cat.get_model_card(as_dict=True))`. + + Args: + cat (CAT): The CAT model. + + Returns: + SemanticVersion | Tuple[int, int, int]: The major, minor and patch version + """ + return get_version_from_modelcard(cat.get_model_card(as_dict=True)) + + +def get_version_from_cdb_dump(cdb_path: str) -> SemanticVersion: + """Get the version from a CDB dump (cdb.dat). 
+ + The version information is expected in the following location: + cdb["config"]["version"]["medcat_version"] + + Args: + cdb_path (str): The path to cdb.dat + + Returns: + SemanticVersion | Tuple[int, int, int]: The major, minor and patch version + """ + with open(cdb_path, 'rb') as f: + d = dill.load(f) + config: dict = d["config"] + version = config["version"]["medcat_version"] + return get_semantic_version(version) + + +def get_version_from_modelpack_zip(zip_path: str, cdb_file_name=CDB_FILE_NAME) -> SemanticVersion: + """Get the semantic version from a MedCAT model pack zip file. + + This involves simply reading the config on file and reading the version information from there. + + The zip file is extracted if it has not yet been extracted. + + Args: + zip_path (str): The zip file path for the model pack. + cdb_file_name (str, optional): The CDB file name to use. Defaults to "cdb.dat". + + Returns: + SemanticVersion | Tuple[int, int, int]: The major, minor and patch version + """ + model_pack_path = CAT.attempt_unpack(zip_path) + return get_version_from_cdb_dump(os.path.join(model_pack_path, cdb_file_name)) + + +UPDATE_VERSION = (1, 3, 0) + + +class ConfigUpgrader: + """Config updater. + + Attempts to upgrade pre 1.3.0 medcat configs to the newer format. + + Args: + zip_path (str): The model pack zip path. + cdb_file_name (str, optional): The CDB file name. Defaults to "cdb.dat". + """ + + def __init__(self, zip_path: str, cdb_file_name: str = CDB_FILE_NAME) -> None: + self.model_pack_path = CAT.attempt_unpack(zip_path) + self.cdb_path = os.path.join(self.model_pack_path, cdb_file_name) + self.current_version = get_version_from_cdb_dump(self.cdb_path) + logger.debug("Loaded model from %s at version %s", + self.model_pack_path, self.current_version) + + def needs_upgrade(self) -> bool: + """Check if the specified modelpack needs an upgrade. + + It needs an upgrade if its version is less than 1.3.0. + + Returns: + bool: Whether or not an upgrade is needed. 
+ """ + return self.current_version < UPDATE_VERSION + + def _get_relevant_files(self, ignore_hidden: bool = True) -> List[str]: + """Get the list of relevant files with full path names. + + By default this will ignore hidden files (those that start with '.'). + + Args: + ignore_hidden (bool, optional): Whether to ignore hidden files. Defaults to True. + + Returns: + List[str]: The list of relevant file names to copy. + """ + return [os.path.join(self.model_pack_path, fn) # ignores hidden files + for fn in os.listdir(self.model_pack_path) if (ignore_hidden and not fn.startswith("."))] + + def _check_existance(self, files_to_copy: List[str], new_path: str, overwrite: bool): + if overwrite: + return # ignore all + if not os.path.exists(new_path): + os.makedirs(new_path) + return # all good, new folder + # check file existance in new (existing) path + for file_to_copy in files_to_copy: + new_file_name = os.path.join( + new_path, os.path.basename(file_to_copy)) + if os.path.exists(new_file_name): + raise ValueError(f"File already exists: {new_file_name}. " + "Pass overwrite=True to overwrite") + + def _copy_files(self, files_to_copy: List[str], new_path: str) -> None: + for file_to_copy in files_to_copy: + new_file_name = os.path.join( + new_path, os.path.basename(file_to_copy)) + if os.path.isdir(file_to_copy): + # if exists is OK since it should have been checked before + # if it was not to be overwritten + logger.debug("Copying folder %s to %s", + file_to_copy, new_file_name) + shutil.copytree(file_to_copy, new_file_name, + dirs_exist_ok=True) + else: + logger.debug("Copying file %s to %s", + file_to_copy, new_file_name) + shutil.copy(file_to_copy, new_file_name) + + def upgrade(self, new_path: str, overwrite: bool = False) -> None: + """Upgrade the model. + + The upgrade copies all the files from the original folder + to the new folder. + + After copying, it changes the config into the format + required by MedCAT after version 1.3.0. 
+ + Args: + new_path (str): The path for the new model pack folder. + overwrite (bool, optional): Whether to overwrite new path. Defaults to False. + + Raises: + ValueError: If one of the target files exists and cannot be overwritten. + ValueError: If model pack does not need an upgrade + """ + if not self.needs_upgrade(): + raise ValueError(f"Model pack does not need ugprade: {self.model_pack_path} " + f"since it's at version: {self.current_version}") + logger.info("Starting to upgrade %s at (version %s)", + self.model_pack_path, self.current_version) + files_to_copy = self._get_relevant_files() + self._check_existance(files_to_copy, new_path, overwrite) + logger.debug("Copying files from %s", self.model_pack_path) + self._copy_files(files_to_copy, new_path) + logger.info("Going to try and fix CDB") + self._fix_cdb(new_path) + self._make_archive(new_path) + + def _fix_cdb(self, new_path: str) -> None: + new_cdb_path = os.path.join(new_path, os.path.basename(self.cdb_path)) + with open(new_cdb_path, 'rb') as f: + data = dill.load(f) + # make the changes + + logger.debug("Fixing CDB issue #1 (linking.filters.cui)") + # Number 1 + # the linking.filters.cuis is set to "{}" + # which is assumed to be an empty set, but actually + # evaluates to an empty dict instead + cuis = data['config']['linking']['filters']['cuis'] + if cuis == {}: + # though it _should_ be the empty set + data['config']['linking']['filters']['cuis'] = set(cuis) + # save modified version + logger.debug("Saving CDB back into %s", new_cdb_path) + with open(new_cdb_path, 'wb') as f: + dill.dump(data, f) + + def _make_archive(self, new_path: str): + logger.debug("Taking data from %s and writing it to %s.zip", + new_path, new_path) + shutil.make_archive( + base_name=new_path, format='zip', base_dir=new_path) + + +def parse_args() -> argparse.Namespace: + """Parse the arguments from the CLI. + + Returns: + argparse.Namespace: The parsed arguments. 
+ """ + parser = argparse.ArgumentParser() + parser.add_argument( + "action", help="The action. Currently, only 'fix-config' is available.", choices=['fix-config'], type=str.lower) + parser.add_argument("modelpack", help="MedCAT modelpack zip path") + parser.add_argument("newpath", help="The path for the new modelpack") + parser.add_argument( + "--overwrite", help="Allow overvwriting existing files", action="store_true") + parser.add_argument( + "--silent", help="Disable logging", action="store_true") + parser.add_argument( + "--verbose", help="Show debug output", action="store_true") + return parser.parse_args() + + +def setup_logging(args: argparse.Namespace) -> None: + """Setup logging for the runnable based on CLI arguments. + + Args: + args (argparse.Namespace): The parsed arguments. + """ + if not args.silent: + logger.addHandler(logging.StreamHandler()) + if args.verbose: + logger.setLevel(logging.DEBUG) + + +def fix_config(args: argparse.Namespace) -> None: + """Perform the fix-config action based on the CLI arguments. + + Args: + args (argparse.Namespace): The parsed arguments. + """ + logger.debug("Setting up upgrader") + upgrader = ConfigUpgrader(args.modelpack) + logger.debug("Starting the upgrade process") + upgrader.upgrade(args.newpath, overwrite=args.overwrite) + + +def main() -> None: + """Run the CLI associated with this module. + + Raises: + ValueError: If an unknown action is provided. 
+ """ + args = parse_args() + setup_logging(args) + logger.debug("Will attempt to perform action %s", args.action) + if args.action == 'fix-config': + fix_config(args) + else: + raise ValueError(f"Unknown action: {args.action}") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index e0fd4ba1b..4e73b2f89 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ 'gensim>=4.3.0', # first to support 3.11 'spacy>=3.1.0', 'scipy~=1.9.2', # first to support 3.11 - 'transformers>=4.19.2', + 'transformers>=4.19.2,<4.22.0', # upper bound is needed for the de-id model until it is retrained 'torch>=1.13.0', # first to support 3.11 'tqdm>=4.27', 'scikit-learn>=1.1.3', # first to supporrt 3.11 @@ -39,7 +39,7 @@ 'xxhash>=3.0.0', # allow later versions, tested with 3.1.0 'blis>=0.7.5', # allow later versions, tested with 0.7.9 'click>=8.0.4', # allow later versions, tested with 8.1.3 - 'pydantic>=1.10.0', # for spacy compatibility + 'pydantic>=1.10.0,<2.0', # for spacy compatibility; avoid 2.0 due to breaking changes # the following are not direct dependencies of MedCAT but needed for docs/building # hopefully will no longer need the transitive dependencies 'aiohttp==3.8.3', # 3.8.3 is needed for compatibility with fsspec <- datasets <- medcat diff --git a/tests/resources/deid_train_data.json b/tests/resources/deid_train_data.json new file mode 100644 index 000000000..55310bd5d --- /dev/null +++ b/tests/resources/deid_train_data.json @@ -0,0 +1 @@ +{"projects": [{"name": "chatGPT-gen", "documents": [{"name": "doc_0", "text": "\nPatient Name: John Smith\nAddress: 15 Maple Avenue\nCity: New York\nCC: Chronic back pain\n\nHX: Mr. Smith is a 52-year-old male who has been experiencing chronic back pain for the past six months. The pain initially started after a lifting incident at work. He describes the pain as a dull ache in the lower back, which worsens with prolonged sitting or standing. He has tried over-the-counter pain medications with limited relief. Mr. 
Smith decided to seek medical attention due to the persistent nature of his symptoms.\n\nFHX: No significant family history of back pain or spinal conditions.\n\nSHX: Office worker. Non-smoker. Occasional alcohol consumption.\n\nPhysical examination revealed tenderness over the lumbar spine with no signs of neurological deficit. X-rays performed on 6/10/2023 showed degenerative changes in the lumbar spine, consistent with spondylosis.\n\nSeen by Dr. R. Johnson on 6/15/2023.\n\n", "annotations": [{"start": 15, "end": 25, "cui": "PATIENT", "value": " John Smith"}, {"start": 35, "end": 50, "cui": "HOSPITAL", "value": " 15 Maple Avenue"}, {"start": 57, "end": 65, "cui": "HOSPITAL", "value": " New York"}, {"start": 97, "end": 102, "cui": "PATIENT", "value": " Smith"}, {"start": 433, "end": 438, "cui": "PATIENT", "value": " Smith"}, {"start": 879, "end": 880, "cui": "DOCTOR", "value": " R"}, {"start": 882, "end": 889, "cui": "PATIENT", "value": " Johnson"}]}, {"name": "doc_1", "text": "\nPatient Name: Emily Davis\nAddress: 22 Willow Lane\nCity: Los Angeles\nCC: Allergic rhinitis\n\nHX: Miss Davis is a 28-year-old female who presents with symptoms of allergic rhinitis. She complains of frequent sneezing, nasal congestion, and itchy eyes, which have been bothering her for the past two years. Symptoms are worse during the spring and fall seasons and improve with over-the-counter antihistamines. Miss Davis seeks medical advice to explore other treatment options.\n\nFHX: No significant family history of allergic rhinitis or other allergic conditions.\n\nSHX: Office administrator. Non-smoker. No alcohol or drug use.\n\nNasal examination revealed pale, boggy nasal mucosa with clear nasal discharge. Skin prick testing conducted on 6/12/2023 demonstrated positive reactions to grass pollen and dust mites.\n\nSeen by Dr. S. 
Patel on 6/17/2023.\n\n", "annotations": [{"start": 15, "end": 26, "cui": "PATIENT", "value": " Emily Davis"}, {"start": 36, "end": 50, "cui": "HOSPITAL", "value": " 22 Willow Lane"}, {"start": 57, "end": 68, "cui": "HOSPITAL", "value": " Los Angeles"}, {"start": 101, "end": 106, "cui": "PATIENT", "value": " Davis"}, {"start": 413, "end": 418, "cui": "PATIENT", "value": " Davis"}, {"start": 827, "end": 828, "cui": "DOCTOR", "value": " S"}, {"start": 830, "end": 835, "cui": "PATIENT", "value": " Patel"}]}, {"name": "doc_2", "text": "\nPatient Name: Michael Johnson\nAddress: 10 Oak Street\nCity: Chicago\nCC: Acute bronchitis\n\nHX: Mr. Johnson is a 42-year-old male who presents with symptoms of acute bronchitis. He reports a cough productive of yellowish sputum, mild chest discomfort, and low-grade fever for the past five days. He denies any shortness of breath or wheezing. Mr. Johnson sought medical attention due to the persistence of symptoms and concern about the nature of his illness.\n\nFHX: No significant family history of respiratory conditions or chronic lung diseases.\n\nSHX: Construction worker. Non-smoker. Occasional alcohol consumption.\n\nPulmonary examination revealed scattered coarse breath sounds with no signs of consolidation. Chest X-ray performed on 6/13/2023 showed no evidence of pneumonia.\n\nSeen by Dr. L. 
Anderson on 6/16/2023.\n\n", "annotations": [{"start": 15, "end": 30, "cui": "PATIENT", "value": " Michael Johnson"}, {"start": 40, "end": 53, "cui": "HOSPITAL", "value": " 10 Oak Street"}, {"start": 60, "end": 67, "cui": "HOSPITAL", "value": " Chicago"}, {"start": 98, "end": 105, "cui": "PATIENT", "value": " Johnson"}, {"start": 345, "end": 352, "cui": "PATIENT", "value": " Johnson"}, {"start": 793, "end": 794, "cui": "DOCTOR", "value": " L"}, {"start": 796, "end": 804, "cui": "PATIENT", "value": " Anderson"}]}, {"name": "doc_3", "text": "\nPatient Name: Sarah Thompson\nAddress: 5 Elm Street\nCity: San Francisco\nCC: Migraine headaches\n\nHX: Miss Thompson is a 30-year-old female who complains of recurrent migraine headaches. She describes the headaches as pulsating, moderate to severe in intensity, lasting for several hours to a day. The headaches are usually accompanied by nausea, vomiting, and sensitivity to light and sound. Miss Thompson reports experiencing these episodes once or twice a month for the past two years. She seeks medical advice to explore treatment options and alleviate her symptoms.\n\nFHX: Maternal aunt had a history of migraines. No other significant family history of neurological conditions.\n\nSHX: Graphic designer. Non-smoker. Rare alcohol consumption.\n\nNeurological examination revealed no focal deficits. Miss Thompson's headache characteristics and frequency are consistent with a diagnosis of migraines.\n\nSeen by Dr. K. 
Roberts on 6/19/2023.\n\n", "annotations": [{"start": 15, "end": 29, "cui": "PATIENT", "value": " Sarah Thompson"}, {"start": 39, "end": 51, "cui": "HOSPITAL", "value": " 5 Elm Street"}, {"start": 58, "end": 71, "cui": "HOSPITAL", "value": " San Francisco"}, {"start": 105, "end": 113, "cui": "PATIENT", "value": " Thompson"}, {"start": 396, "end": 404, "cui": "PATIENT", "value": " Thompson"}, {"start": 802, "end": 810, "cui": "PATIENT", "value": " Thompson"}, {"start": 911, "end": 912, "cui": "DOCTOR", "value": " K"}, {"start": 914, "end": 921, "cui": "PATIENT", "value": " Roberts"}]}, {"name": "doc_4", "text": "\nPatient Name: David Wilson\nAddress: 3 Pine Street\nCity: Houston\nCC: Gastroesophageal reflux disease (GERD)\n\nHX: Mr. Wilson is a 48-year-old male who presents with symptoms of gastroesophageal reflux disease. He complains of frequent heartburn, regurgitation, and a bitter taste in his mouth, particularly after meals. Symptoms have been bothering him for the past six months, and he has noticed a decrease in his appetite and unintentional weight loss. Mr. Wilson seeks medical advice to manage his symptoms and address the weight loss.\n\nFHX: No significant family history of gastrointestinal conditions.\n\nSHX: Accountant. Non-smoker. Occasional alcohol consumption.\n\nAbdominal examination revealed epigastric tenderness. Upper endoscopy performed on 6/16/2023 demonstrated evidence of esophagitis and hiatal hernia.\n\nSeen by Dr. J. 
Anderson on 6/21/2023.\n\n", "annotations": [{"start": 15, "end": 27, "cui": "PATIENT", "value": " David Wilson"}, {"start": 37, "end": 50, "cui": "HOSPITAL", "value": " 3 Pine Street"}, {"start": 57, "end": 64, "cui": "HOSPITAL", "value": " Houston"}, {"start": 117, "end": 123, "cui": "PATIENT", "value": " Wilson"}, {"start": 458, "end": 464, "cui": "PATIENT", "value": " Wilson"}, {"start": 831, "end": 832, "cui": "DOCTOR", "value": " J"}, {"start": 834, "end": 842, "cui": "PATIENT", "value": " Anderson"}]}, {"name": "doc_5", "text": "\nPatient Name: Olivia Martinez\nAddress: 12 Rose Lane\nCity: Miami\nCC: Depression\n\nHX: Miss Martinez is a 36-year-old female who presents with symptoms of depression. She reports feeling persistent sadness, loss of interest in activities, decreased energy, changes in appetite and sleep patterns, and difficulty concentrating for the past six months. These symptoms have significantly affected her daily functioning and overall quality of life. Miss Martinez seeks medical assistance to address her depressive symptoms.\n\nFHX: No significant family history of mood disorders.\n\nSHX: Teacher. Non-smoker. No alcohol or drug use.\n\nPsychiatric evaluation revealed a depressed mood, anhedonia, and impaired concentration. Based on the clinical presentation, Miss Martinez meets the criteria for major depressive disorder.\n\nSeen by Dr. A. 
Ramirez on 6/23/2023.\n\n", "annotations": [{"start": 15, "end": 30, "cui": "PATIENT", "value": " Olivia Martinez"}, {"start": 40, "end": 52, "cui": "HOSPITAL", "value": " 12 Rose Lane"}, {"start": 59, "end": 64, "cui": "HOSPITAL", "value": " Miami"}, {"start": 90, "end": 98, "cui": "PATIENT", "value": " Martinez"}, {"start": 448, "end": 456, "cui": "PATIENT", "value": " Martinez"}, {"start": 755, "end": 763, "cui": "PATIENT", "value": " Martinez"}, {"start": 827, "end": 828, "cui": "DOCTOR", "value": " A"}, {"start": 830, "end": 837, "cui": "PATIENT", "value": " Ramirez"}]}, {"name": "doc_6", "text": "\nPatient Name: Daniel Lee\nAddress: 8 Maple Street\nCity: Seattle\nCC: Hypertension\n\nHX: Mr. Lee is a 58-year-old male who presents with elevated blood pressure readings during routine check-ups. He has a family history of hypertension and is concerned about his cardiovascular health. Mr. Lee has no associated symptoms but seeks medical advice to manage his blood pressure and reduce the risk of complications.\n\nFHX: Father and paternal grandfather had hypertension. No other significant family history of cardiovascular diseases.\n\nSHX: Engineer. Non-smoker. Occasional alcohol consumption.\n\nPhysical examination revealed blood pressure consistently above the normal range. Further investigations, including 24-hour ambulatory blood pressure monitoring, confirmed the diagnosis of essential hypertension.\n\nSeen by Dr. H. 
Johnson on 6/25/2023.\n\n", "annotations": [{"start": 15, "end": 25, "cui": "PATIENT", "value": " Daniel Lee"}, {"start": 35, "end": 49, "cui": "HOSPITAL", "value": " 8 Maple Street"}, {"start": 56, "end": 63, "cui": "HOSPITAL", "value": " Seattle"}, {"start": 90, "end": 93, "cui": "PATIENT", "value": " Lee"}, {"start": 287, "end": 290, "cui": "PATIENT", "value": " Lee"}, {"start": 817, "end": 818, "cui": "DOCTOR", "value": " H"}, {"start": 820, "end": 827, "cui": "PATIENT", "value": " Johnson"}]}, {"name": "doc_7", "text": "\nPatient Name: Sophia Adams\nAddress: 18 Cedar Avenue\nCity: Boston\nCC: Urinary tract infection (UTI)\n\nHX: Miss Adams is a 24-year-old female who complains of urinary frequency, urgency, and a burning sensation during urination. Symptoms started two days ago and have progressively worsened. She denies any hematuria or fever. Miss Adams seeks medical attention due to the persistence of symptoms and concern about a possible urinary tract infection.\n\nFHX: No significant family history of urinary tract infections.\n\nSHX: Marketing executive. Non-smoker. No alcohol or drug use.\n\nUrinalysis revealed pyuria and positive leukocyte esterase, indicating a urinary tract infection. A midstream urine culture confirmed the presence of Escherichia coli.\n\nSeen by Dr. M. Patel on 6/28/2023.\n\n", "annotations": [{"start": 15, "end": 27, "cui": "PATIENT", "value": " Sophia Adams"}, {"start": 37, "end": 52, "cui": "HOSPITAL", "value": " 18 Cedar Avenue"}, {"start": 59, "end": 65, "cui": "HOSPITAL", "value": " Boston"}, {"start": 110, "end": 115, "cui": "PATIENT", "value": " Adams"}, {"start": 330, "end": 335, "cui": "PATIENT", "value": " Adams"}, {"start": 759, "end": 760, "cui": "DOCTOR", "value": " M"}, {"start": 762, "end": 767, "cui": "PATIENT", "value": " Patel"}]}, {"name": "doc_8", "text": "\nPatient Name: Benjamin Thompson\nAddress: 25 Oak Street\nCity: Chicago\nCC: Seasonal allergies\n\nHX: Mr. 
Thompson is a 40-year-old male who presents with symptoms of seasonal allergies. He reports sneezing, itching, and a runny nose, particularly during the spring and summer months. Symptoms significantly interfere with his daily activities and sleep. Mr. Thompson seeks medical advice to manage his allergic symptoms.\n\nFHX: Mother had a history of seasonal allergies. No other significant family history of allergic conditions.\n\nSHX: IT specialist. Non-smoker. No alcohol or drug use.\n\nAllergy testing conducted on 6/26/2023 demonstrated positive reactions to grass pollen and tree pollen.\n\nSeen by Dr. E. Anderson on 6/30/2023.\n\n", "annotations": [{"start": 15, "end": 32, "cui": "PATIENT", "value": " Benjamin Thompson"}, {"start": 42, "end": 55, "cui": "HOSPITAL", "value": " 25 Oak Street"}, {"start": 62, "end": 69, "cui": "HOSPITAL", "value": " Chicago"}, {"start": 102, "end": 110, "cui": "PATIENT", "value": " Thompson"}, {"start": 355, "end": 363, "cui": "PATIENT", "value": " Thompson"}, {"start": 703, "end": 704, "cui": "DOCTOR", "value": " E"}, {"start": 706, "end": 714, "cui": "PATIENT", "value": " Anderson"}]}, {"name": "doc_9", "text": "\nPatient Name: Emma Davis\nAddress: 6 Willow Lane\nCity: Los Angeles\nCC: Anxiety\n\nHX: Miss Davis is a 32-year-old female who presents with symptoms of anxiety. She reports excessive worrying, restlessness, irritability, muscle tension, and difficulty concentrating. These symptoms have been present for the past six months and have\n\n\n", "annotations": [{"start": 15, "end": 25, "cui": "PATIENT", "value": " Emma Davis"}, {"start": 35, "end": 48, "cui": "HOSPITAL", "value": " 6 Willow Lane"}, {"start": 55, "end": 66, "cui": "HOSPITAL", "value": " Los Angeles"}, {"start": 89, "end": 94, "cui": "PATIENT", "value": " Davis"}]}, {"name": "doc_10", "text": "\nPatient Name: Alexander Johnson\nAddress: 9 Elm Street\nCity: San Francisco\nCC: Asthma\n\nHX: Mr. Johnson is a 28-year-old male who presents with symptoms of asthma. 
He complains of recurrent episodes of wheezing, shortness of breath, and chest tightness, particularly during physical activity and exposure to triggers such as dust and pollen. Symptoms have been present since childhood and have recently worsened. Mr. Johnson seeks medical assistance to manage his asthma symptoms and improve his quality of life.\n\nFHX: Mother and paternal uncle have a history of asthma. No other significant family history of respiratory conditions.\n\nSHX: Sales representative. Non-smoker. No alcohol or drug use.\n\nPulmonary function tests revealed airflow obstruction with significant reversibility after bronchodilator administration, confirming the diagnosis of asthma.\n\nSeen by Dr. N. Patel on 7/2/2023.\n\n", "annotations": [{"start": 15, "end": 32, "cui": "PATIENT", "value": " Alexander Johnson"}, {"start": 42, "end": 54, "cui": "HOSPITAL", "value": " 9 Elm Street"}, {"start": 61, "end": 74, "cui": "HOSPITAL", "value": " San Francisco"}, {"start": 95, "end": 102, "cui": "PATIENT", "value": " Johnson"}, {"start": 416, "end": 423, "cui": "PATIENT", "value": " Johnson"}, {"start": 869, "end": 870, "cui": "DOCTOR", "value": " N"}, {"start": 872, "end": 877, "cui": "PATIENT", "value": " Patel"}]}, {"name": "doc_11", "text": "\nPatient Name: Lily Wilson\nAddress: 4 Pine Street\nCity: Houston\nCC: Gastroenteritis\n\nHX: Miss Wilson is a 22-year-old female who presents with symptoms of gastroenteritis. She reports diarrhea, abdominal cramping, nausea, and vomiting, which started after consuming a meal at a local restaurant. Symptoms have been ongoing for the past 24 hours, and she is concerned about dehydration and the persistence of symptoms. Miss Wilson seeks medical advice for symptom relief and to ensure appropriate management.\n\nFHX: No significant family history of gastrointestinal conditions.\n\nSHX: Student. Non-smoker. No alcohol or drug use.\n\nPhysical examination revealed mild abdominal tenderness with no signs of peritonitis. 
Based on the clinical presentation and recent food exposure, the diagnosis of gastroenteritis is likely.\n\nSeen by Dr. K. Roberts on 7/5/2023.\n\n", "annotations": [{"start": 15, "end": 26, "cui": "PATIENT", "value": " Lily Wilson"}, {"start": 36, "end": 49, "cui": "HOSPITAL", "value": " 4 Pine Street"}, {"start": 56, "end": 63, "cui": "HOSPITAL", "value": " Houston"}, {"start": 94, "end": 100, "cui": "PATIENT", "value": " Wilson"}, {"start": 423, "end": 429, "cui": "PATIENT", "value": " Wilson"}, {"start": 832, "end": 833, "cui": "DOCTOR", "value": " K"}, {"start": 835, "end": 842, "cui": "PATIENT", "value": " Roberts"}]}, {"name": "doc_12", "text": "\nPatient Name: Noah Thompson\nAddress: 19 Cedar Avenue\nCity: Boston\nCC: Insomnia\n\nHX: Mr. Thompson is a 45-year-old male who complains of difficulty falling asleep and maintaining sleep. He reports frequent awakenings during the night and feeling unrefreshed upon waking up. These symptoms have been present for the past three months and significantly affect his daytime functioning. Mr. Thompson seeks medical assistance to address his insomnia and improve his sleep quality.\n\nFHX: No significant family history of sleep disorders.\n\nSHX: Financial analyst. Non-smoker. Occasional alcohol consumption.\n\nSleep diary records revealed prolonged sleep latency and frequent awakenings during the night. Based on the clinical presentation, Mr. Thompson meets the criteria for chronic insomnia disorder.\n\nSeen by Dr. S. 
Ramirez on 7/8/2023.\n\n", "annotations": [{"start": 15, "end": 28, "cui": "PATIENT", "value": " Noah Thompson"}, {"start": 38, "end": 53, "cui": "HOSPITAL", "value": " 19 Cedar Avenue"}, {"start": 60, "end": 66, "cui": "HOSPITAL", "value": " Boston"}, {"start": 89, "end": 97, "cui": "PATIENT", "value": " Thompson"}, {"start": 387, "end": 395, "cui": "PATIENT", "value": " Thompson"}, {"start": 737, "end": 745, "cui": "PATIENT", "value": " Thompson"}, {"start": 809, "end": 810, "cui": "DOCTOR", "value": " S"}, {"start": 812, "end": 819, "cui": "PATIENT", "value": " Ramirez"}]}, {"name": "doc_13", "text": "\nPatient Name: Chloe Adams\nAddress: 14 Cedar Avenue\nCity: Boston\nCC: Sinusitis\n\nHX: Miss Adams is a 26-year-old female who presents with symptoms of sinusitis. She reports nasal congestion, facial pressure, headache, and thick nasal discharge, which have been bothering her for the past week. Miss Adams tried over-the-counter nasal decongestants with minimal relief. She seeks medical assistance to manage her symptoms and prevent complications.\n\nFHX: No significant family history of sinusitis or chronic sinus conditions.\n\nSHX: Graphic designer. Non-smoker. No alcohol or drug use.\n\nNasal examination revealed erythematous nasal mucosa with purulent discharge. Based on the clinical presentation, Miss Adams is diagnosed with acute sinusitis.\n\nSeen by Dr. L. 
Anderson on 7/11/2023.\n\n", "annotations": [{"start": 15, "end": 26, "cui": "PATIENT", "value": " Chloe Adams"}, {"start": 36, "end": 51, "cui": "HOSPITAL", "value": " 14 Cedar Avenue"}, {"start": 58, "end": 64, "cui": "HOSPITAL", "value": " Boston"}, {"start": 89, "end": 94, "cui": "PATIENT", "value": " Adams"}, {"start": 298, "end": 303, "cui": "PATIENT", "value": " Adams"}, {"start": 705, "end": 710, "cui": "PATIENT", "value": " Adams"}, {"start": 759, "end": 760, "cui": "DOCTOR", "value": " L"}, {"start": 762, "end": 770, "cui": "PATIENT", "value": " Anderson"}]}, {"name": "doc_14", "text": "\nPatient Name: Grace Turner\nAddress: 11 Maple Avenue\nCity: New York\nCC: Rheumatoid arthritis\n\nHX: Miss Turner\n", "annotations": [{"start": 15, "end": 27, "cui": "PATIENT", "value": " Grace Turner"}, {"start": 37, "end": 52, "cui": "HOSPITAL", "value": " 11 Maple Avenue"}, {"start": 59, "end": 67, "cui": "HOSPITAL", "value": " New York"}, {"start": 103, "end": 109, "cui": "PATIENT", "value": " Turner"}]}, {"name": "doc_15", "text": "\nPatient Name: Ethan Harris\nAddress: 16 Pine Street\nCity: Houston\nCC: Gout\n\nHX: Mr. Harris is a 55-year-old male who presents with symptoms of gout. He reports sudden and severe joint pain, swelling, and redness in his right big toe. The symptoms started yesterday, and he has a history of similar episodes in the past. Mr. Harris seeks medical assistance to manage his acute gout attack and prevent future flares.\n\nFHX: No significant family history of gout or other rheumatic conditions.\n\nSHX: Retired. Non-smoker. Occasional alcohol consumption.\n\nPhysical examination revealed warmth, tenderness, and erythema in the affected joint. Based on the clinical presentation and history of recurrent episodes, Mr. Harris is diagnosed with acute gouty arthritis.\n\nSeen by Dr. M. 
Johnson on 7/14/2023.\n\n", "annotations": [{"start": 15, "end": 27, "cui": "PATIENT", "value": " Ethan Harris"}, {"start": 37, "end": 51, "cui": "HOSPITAL", "value": " 16 Pine Street"}, {"start": 58, "end": 65, "cui": "HOSPITAL", "value": " Houston"}, {"start": 84, "end": 90, "cui": "PATIENT", "value": " Harris"}, {"start": 324, "end": 330, "cui": "PATIENT", "value": " Harris"}, {"start": 710, "end": 716, "cui": "PATIENT", "value": " Harris"}, {"start": 771, "end": 772, "cui": "DOCTOR", "value": " M"}, {"start": 774, "end": 781, "cui": "PATIENT", "value": " Johnson"}]}, {"name": "doc_16", "text": "\nPatient Name: Mia Clark\nAddress: 7 Willow Lane\nCity: Los Angeles\nCC: Urinary incontinence\n\nHX: Miss Clark is a 62-year-old female who complains of urinary incontinence. She reports involuntary urine leakage, particularly with coughing, sneezing, and physical exertion. Symptoms have been present for the past six months and have progressively worsened. Miss Clark seeks medical advice to address her urinary incontinence and improve her quality of life.\n\nFHX: No significant family history of urinary incontinence or pelvic floor disorders.\n\nSHX: Retired. Non-smoker. No alcohol or drug use.\n\nPelvic examination revealed weakened pelvic floor muscles. Based on the clinical presentation, Miss Clark is diagnosed with stress urinary incontinence.\n\nSeen by Dr. E. 
Patel on 7/17/2023.\n\n", "annotations": [{"start": 15, "end": 24, "cui": "PATIENT", "value": " Mia Clark"}, {"start": 34, "end": 47, "cui": "HOSPITAL", "value": " 7 Willow Lane"}, {"start": 54, "end": 65, "cui": "HOSPITAL", "value": " Los Angeles"}, {"start": 101, "end": 106, "cui": "PATIENT", "value": " Clark"}, {"start": 359, "end": 364, "cui": "PATIENT", "value": " Clark"}, {"start": 694, "end": 699, "cui": "PATIENT", "value": " Clark"}, {"start": 760, "end": 761, "cui": "DOCTOR", "value": " E"}, {"start": 763, "end": 768, "cui": "PATIENT", "value": " Patel"}]}, {"name": "doc_17", "text": "\nPatient Name: Samuel Wright\nAddress: 20 Oak Street\nCity: Chicago\nCC: Osteoarthritis\n\nHX: Mr. Wright is a 70-year-old male who presents with symptoms of osteoarthritis. He reports joint pain, stiffness, and reduced range of motion in his knees and hands. Symptoms have been progressively worsening over the past year and significantly affect his daily activities. Mr. Wright seeks medical assistance to manage his osteoarthritis symptoms and improve his functional ability.\n\nFHX: No significant family history of musculoskeletal conditions.\n\nSHX: Retired. Non-smoker. No alcohol or drug use.\n\nPhysical examination revealed crepitus, bony enlargement, and limited range of motion in the affected joints. Based on the clinical presentation and imaging findings, Mr. Wright is diagnosed with osteoarthritis.\n\nSeen by Dr. R. 
Anderson on 7/20/2023.\n\n", "annotations": [{"start": 15, "end": 28, "cui": "PATIENT", "value": " Samuel Wright"}, {"start": 38, "end": 51, "cui": "HOSPITAL", "value": " 20 Oak Street"}, {"start": 58, "end": 65, "cui": "HOSPITAL", "value": " Chicago"}, {"start": 94, "end": 100, "cui": "PATIENT", "value": " Wright"}, {"start": 368, "end": 374, "cui": "PATIENT", "value": " Wright"}, {"start": 764, "end": 770, "cui": "PATIENT", "value": " Wright"}, {"start": 818, "end": 819, "cui": "DOCTOR", "value": " R"}, {"start": 821, "end": 829, "cui": "PATIENT", "value": " Anderson"}]}, {"name": "doc_18", "text": "\nPatient Name: Harper Turner\nAddress: 13 Maple Avenue\nCity: New York\nCC: Hypothyroidism\n\nHX: Miss Turner is a 30-year-old female who presents with symptoms of hypothyroidism. She reports fatigue, weight gain, cold intolerance, constipation, and dry skin. These symptoms have been present for the past six months and have gradually worsened. Miss Turner seeks medical assistance to evaluate her thyroid function and explore appropriate treatment options.\n\nFHX: No significant family history of thyroid disorders.\n\nSHX: Office manager. Non-smoker. No alcohol or drug use.\n\nLaboratory tests revealed elevated thyroid-stimulating hormone (TSH) levels and decreased free thyroxine (T4) levels, confirming the diagnosis of primary hypothyroidism.\n\nSeen by Dr. S. 
Johnson on 7/23/2023.\n\n", "annotations": [{"start": 15, "end": 28, "cui": "PATIENT", "value": " Harper Turner"}, {"start": 38, "end": 53, "cui": "HOSPITAL", "value": " 13 Maple Avenue"}, {"start": 60, "end": 68, "cui": "HOSPITAL", "value": " New York"}, {"start": 98, "end": 104, "cui": "PATIENT", "value": " Turner"}, {"start": 346, "end": 352, "cui": "PATIENT", "value": " Turner"}, {"start": 754, "end": 755, "cui": "DOCTOR", "value": " S"}, {"start": 757, "end": 764, "cui": "PATIENT", "value": " Johnson"}]}, {"name": "doc_19", "text": "\nPatient Name: Ava Lewis\nAddress: 10 Pine Street\nCity: Houston\n\n", "annotations": [{"start": 15, "end": 24, "cui": "PATIENT", "value": " Ava Lewis"}, {"start": 34, "end": 48, "cui": "HOSPITAL", "value": " 10 Pine Street"}]}, {"name": "doc_20", "text": "\nPatient Name: Henry Adams\nAddress: 5 Elm Street\nCity: San Francisco\nCC: Type 2 diabetes mellitus\n\nHX: Mr. Adams is a 50-year-old male who presents with symptoms of increased thirst, frequent urination, and unintentional weight loss. He reports feeling fatigued and has a family history of diabetes. Laboratory tests revealed elevated fasting blood glucose levels and HbA1c levels, indicating poor glycemic control. Mr. Adams seeks medical assistance to manage his diabetes and prevent complications.\n\nFHX: Father and two siblings have a history of type 2 diabetes.\n\nSHX: Teacher. Non-smoker. No alcohol or drug use.\n\nBased on the clinical presentation and laboratory findings, Mr. Adams is diagnosed with type 2 diabetes mellitus.\n\nSeen by Dr. N. 
Patel on 7/26/2023.\n\n", "annotations": [{"start": 15, "end": 26, "cui": "PATIENT", "value": " Henry Adams"}, {"start": 36, "end": 48, "cui": "HOSPITAL", "value": " 5 Elm Street"}, {"start": 55, "end": 68, "cui": "HOSPITAL", "value": " San Francisco"}, {"start": 107, "end": 112, "cui": "PATIENT", "value": " Adams"}, {"start": 420, "end": 425, "cui": "PATIENT", "value": " Adams"}, {"start": 682, "end": 687, "cui": "PATIENT", "value": " Adams"}, {"start": 745, "end": 746, "cui": "DOCTOR", "value": " N"}, {"start": 748, "end": 753, "cui": "PATIENT", "value": " Patel"}]}, {"name": "doc_21", "text": "\nPatient Name: Emily Wright\nAddress: 21 Oak Street\nCity: Chicago\nCC: Migraine headaches\n\nHX: Miss Wright is a 25-year-old female who presents with recurrent episodes of severe headache accompanied by nausea, vomiting, and sensitivity to light and sound. She reports experiencing these symptoms since adolescence and seeks medical assistance to manage her migraines and improve her quality of life.\n\nFHX: Mother has a history of migraines.\n\nSHX: Graphic designer. Non-smoker. Occasional alcohol consumption.\n\nThe clinical presentation and symptom pattern are consistent with a diagnosis of migraine headaches.\n\nSeen by Dr. E. Anderson on 7/29/2023.\n\n", "annotations": [{"start": 15, "end": 27, "cui": "PATIENT", "value": " Emily Wright"}, {"start": 37, "end": 50, "cui": "HOSPITAL", "value": " 21 Oak Street"}, {"start": 57, "end": 64, "cui": "HOSPITAL", "value": " Chicago"}, {"start": 98, "end": 104, "cui": "PATIENT", "value": " Wright"}, {"start": 622, "end": 623, "cui": "DOCTOR", "value": " E"}, {"start": 625, "end": 633, "cui": "PATIENT", "value": " Anderson"}]}, {"name": "doc_22", "text": "\nPatient Name: Oliver Mitchell\nAddress: 15 Cedar Avenue\nCity: Boston\nCC: Plantar fasciitis\n\nHX: Mr. Mitchell is a 42-year-old male who presents with heel pain that is worse in the morning and improves with activity. 
He reports experiencing pain for the past three months, particularly after prolonged periods of standing or walking. Mr. Mitchell seeks medical assistance to alleviate his foot pain and restore his normal daily activities.\n\nFHX: No significant family history of foot or musculoskeletal conditions.\n\nSHX: IT specialist. Non-smoker. No alcohol or drug use.\n\nPhysical examination revealed tenderness and pain along the plantar fascia. Based on the clinical presentation, Mr. Mitchell is diagnosed with plantar fasciitis.\n\nSeen by Dr. L. Patel on 8/2/2023.\n\n", "annotations": [{"start": 15, "end": 30, "cui": "PATIENT", "value": " Oliver Mitchell"}, {"start": 40, "end": 55, "cui": "HOSPITAL", "value": " 15 Cedar Avenue"}, {"start": 62, "end": 68, "cui": "HOSPITAL", "value": " Boston"}, {"start": 100, "end": 108, "cui": "PATIENT", "value": " Mitchell"}, {"start": 337, "end": 345, "cui": "PATIENT", "value": " Mitchell"}, {"start": 688, "end": 696, "cui": "PATIENT", "value": " Mitchell"}, {"start": 747, "end": 748, "cui": "DOCTOR", "value": " L"}, {"start": 750, "end": 755, "cui": "PATIENT", "value": " Patel"}]}, {"name": "doc_23", "text": "\nPatient Name: Victoria Turner\nAddress: 17 Maple Avenue\nCity: New York\nCC: Chronic obstructive pulmonary disease (COPD)\n\nHX: Miss Turner is a 60-year-old female who presents with symptoms of chronic cough, sputum production, and shortness of breath, particularly during physical exertion. She reports a history of smoking for 30 years. Pulmonary function tests revealed airflow limitation and reduced forced expiratory volume. Miss Turner seeks medical assistance to manage her COPD symptoms and optimize her respiratory function.\n\nFHX: No significant family history of respiratory conditions.\n\nSHX: Retired. Former smoker. No alcohol or drug use.\n\nBased on the clinical presentation, smoking history, and pulmonary function test results, Miss Turner is diagnosed with chronic obstructive pulmonary disease.\n\nSeen by Dr. S. 
Johnson on 8/5/2023.\n\n", "annotations": [{"start": 15, "end": 30, "cui": "PATIENT", "value": " Victoria Turner"}, {"start": 40, "end": 55, "cui": "HOSPITAL", "value": " 17 Maple Avenue"}, {"start": 62, "end": 70, "cui": "HOSPITAL", "value": " New York"}, {"start": 130, "end": 136, "cui": "PATIENT", "value": " Turner"}, {"start": 432, "end": 438, "cui": "PATIENT", "value": " Turner"}, {"start": 744, "end": 750, "cui": "PATIENT", "value": " Turner"}, {"start": 821, "end": 822, "cui": "DOCTOR", "value": " S"}, {"start": 824, "end": 831, "cui": "PATIENT", "value": " Johnson"}]}, {"name": "doc_24", "text": "\nPatient Name: Oliver Parker\nAddress: 22 Oak Street\nCity: Chicago\nCC: Allergic rhinitis\n\nHX: Mr. Parker is a 32-year-old male who presents with symptoms of allergic rhinitis. He reports sneezing, nasal congestion, itching, and a runny nose, particularly during the spring and fall seasons. Symptoms significantly interfere with his daily activities\n\n", "annotations": [{"start": 15, "end": 28, "cui": "PATIENT", "value": " Oliver Parker"}, {"start": 38, "end": 51, "cui": "HOSPITAL", "value": " 22 Oak Street"}, {"start": 58, "end": 65, "cui": "HOSPITAL", "value": " Chicago"}, {"start": 97, "end": 103, "cui": "PATIENT", "value": " Parker"}]}, {"name": "doc_25", "text": "\nPatient Name: Isabella Cooper\nAddress: 12 Willow Lane\nCity: Los Angeles\nCC: Anxiety disorder\n\nHX: Miss Cooper is a 27-year-old female who presents with symptoms of anxiety. She reports excessive worry, restlessness, irritability, muscle tension, and difficulty sleeping. These symptoms have been present for the past year and have progressively worsened. Miss Cooper seeks medical assistance to address her anxiety symptoms and improve her overall well-being.\n\nFHX: No significant family history of anxiety disorders.\n\nSHX: Accountant. Non-smoker. Occasional alcohol consumption.\n\nPsychiatric evaluation revealed symptoms consistent with generalized anxiety disorder. 
Miss Cooper is experiencing significant distress and impairment in multiple areas of her life.\n\nSeen by Dr. E. Ramirez on 8/8/2023.\n\n", "annotations": [{"start": 15, "end": 30, "cui": "PATIENT", "value": " Isabella Cooper"}, {"start": 40, "end": 54, "cui": "HOSPITAL", "value": " 12 Willow Lane"}, {"start": 61, "end": 72, "cui": "HOSPITAL", "value": " Los Angeles"}, {"start": 104, "end": 110, "cui": "PATIENT", "value": " Cooper"}, {"start": 361, "end": 367, "cui": "PATIENT", "value": " Cooper"}, {"start": 674, "end": 680, "cui": "PATIENT", "value": " Cooper"}, {"start": 777, "end": 778, "cui": "DOCTOR", "value": " E"}, {"start": 780, "end": 787, "cui": "PATIENT", "value": " Ramirez"}]}, {"name": "doc_26", "text": "\nPatient Name: Jacob Martinez\nAddress: 18 Elm Street\nCity: San Francisco\nCC: Hypertensive crisis\n\nHX: Mr. Martinez is a 60-year-old male with a known history of hypertension. He presents with severe headache, chest pain, and shortness of breath. He reports missing his antihypertensive medication for the past three days. Upon measurement, his blood pressure is significantly elevated. Mr. Martinez seeks urgent medical attention to manage his hypertensive crisis.\n\nFHX: Father had a history of hypertension and stroke.\n\nSHX: Retired. Non-smoker. Occasional alcohol consumption.\n\nPhysical examination and blood pressure measurements confirm the diagnosis of hypertensive crisis. Immediate interventions are initiated to lower blood pressure and prevent complications.\n\nSeen by Dr. H. 
Johnson on 8/11/2023.\n\n", "annotations": [{"start": 15, "end": 29, "cui": "PATIENT", "value": " Jacob Martinez"}, {"start": 39, "end": 52, "cui": "HOSPITAL", "value": " 18 Elm Street"}, {"start": 59, "end": 72, "cui": "HOSPITAL", "value": " San Francisco"}, {"start": 106, "end": 114, "cui": "PATIENT", "value": " Martinez"}, {"start": 390, "end": 398, "cui": "PATIENT", "value": " Martinez"}, {"start": 781, "end": 782, "cui": "DOCTOR", "value": " H"}, {"start": 784, "end": 791, "cui": "PATIENT", "value": " Johnson"}]}, {"name": "doc_27", "text": "\nPatient Name: Ava Foster\nAddress: 14 Pine Street\nCity: Houston\nCC: Peptic ulcer disease\n\nHX: Miss Foster is a 35-year-old female who presents with symptoms of abdominal pain, particularly in the upper abdomen. She reports a burning sensation and occasional nausea. Symptoms worsen after meals. Miss Foster seeks medical assistance to evaluate her abdominal pain and determine the underlying cause.\n\nFHX: No significant family history of gastrointestinal conditions.\n\nSHX: Marketing executive. Non-smoker. Occasional alcohol consumption.\n\nGastroscopy reveals a duodenal ulcer. Helicobacter pylori testing is performed, and the results confirm the presence of H. pylori infection.\n\nSeen by Dr. M. Johnson on 8/14/2023.\n\n", "annotations": [{"start": 15, "end": 25, "cui": "PATIENT", "value": " Ava Foster"}, {"start": 35, "end": 49, "cui": "HOSPITAL", "value": " 14 Pine Street"}, {"start": 56, "end": 63, "cui": "HOSPITAL", "value": " Houston"}, {"start": 99, "end": 105, "cui": "PATIENT", "value": " Foster"}, {"start": 300, "end": 306, "cui": "PATIENT", "value": " Foster"}, {"start": 693, "end": 694, "cui": "DOCTOR", "value": " M"}, {"start": 696, "end": 703, "cui": "PATIENT", "value": " Johnson"}]}, {"name": "doc_28", "text": "\nPatient Name: William Turner\nAddress: 11 Cedar Avenue\nCity: Boston\nCC: Major depressive disorder\n\nHX: Mr. Turner is a 38-year-old male who presents with symptoms of depression. 
He reports a persistent depressed mood, loss of interest in activities, feelings of worthlessness, changes in appetite, and difficulty concentrating. These symptoms have been present for the past six months and significantly impair his daily functioning. Mr. Turner seeks medical assistance to address his depressive symptoms.\n\nFHX: No significant family history of mood disorders.\n\nSHX: Software engineer. Non-smoker. No alcohol or drug use.\n\nPsychiatric evaluation reveals symptoms consistent with major depressive disorder. Mr. Turner exhibits significant distress and impairment in multiple areas of his life.\n\nSeen by Dr. L. Anderson on 8/17/2023.\n\n", "annotations": [{"start": 15, "end": 29, "cui": "PATIENT", "value": " William Turner"}, {"start": 39, "end": 54, "cui": "HOSPITAL", "value": " 11 Cedar Avenue"}, {"start": 61, "end": 67, "cui": "HOSPITAL", "value": " Boston"}, {"start": 107, "end": 113, "cui": "PATIENT", "value": " Turner"}, {"start": 437, "end": 443, "cui": "PATIENT", "value": " Turner"}, {"start": 709, "end": 715, "cui": "PATIENT", "value": " Turner"}, {"start": 805, "end": 806, "cui": "DOCTOR", "value": " L"}, {"start": 808, "end": 816, "cui": "PATIENT", "value": " Anderson"}]}, {"name": "doc_29", "text": "\nPatient Name: Sophia Reed\nAddress: 9 Willow Lane\nCity: Los Angeles\nCC: Iron-deficiency anemia\n\nHX: Miss Reed is a 29-year-old female who presents with symptoms of fatigue, weakness, and shortness of breath. She reports heavy menstrual bleeding and follows a vegetarian diet. 
Miss Reed seeks medical assistance to evaluate her symptoms and determine the cause of her anemia.\n\nFHX: No significant\n\n", "annotations": [{"start": 15, "end": 26, "cui": "PATIENT", "value": " Sophia Reed"}, {"start": 36, "end": 49, "cui": "HOSPITAL", "value": " 9 Willow Lane"}, {"start": 56, "end": 67, "cui": "HOSPITAL", "value": " Los Angeles"}, {"start": 105, "end": 109, "cui": "PATIENT", "value": " Reed"}, {"start": 281, "end": 285, "cui": "PATIENT", "value": " Reed"}]}, {"name": "doc_30", "text": "\nName: Olivia Davis\nAddress: 12 Elm Street\nCity: Springfield\nCC: Chronic back pain.\n\nHX: Ms. Davis is a 45-year-old female who presents with chronic lower back pain for the past six months. The pain is described as dull and aching, primarily localized to the lumbar region. It worsens with prolonged sitting or physical activity. She has tried over-the-counter pain medications with limited relief.\n\nFHX: No family history of chronic back pain or spinal disorders.\n\nSHX: Office worker. Non-smoker. Rare alcohol consumption.\n\nOn examination, there is tenderness on palpation over the lumbar spine. Range of motion is slightly restricted. No neurological deficits are noted.\n\nSeen by Dr. R. Martinez on 10/15/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Olivia Davis"}, {"start": 29, "end": 42, "cui": "HOSPITAL", "value": " 12 Elm Street"}, {"start": 49, "end": 60, "cui": "HOSPITAL", "value": " Springfield"}, {"start": 93, "end": 98, "cui": "PATIENT", "value": " Davis"}, {"start": 686, "end": 687, "cui": "DOCTOR", "value": " R"}, {"start": 689, "end": 697, "cui": "PATIENT", "value": " Martinez"}]}, {"name": "doc_31", "text": "\nName: Ethan Thompson\nAddress: 18 Oak Avenue\nCity: Riverside\nCC: Abdominal pain.\n\nHX: Mr. Thompson is a 32-year-old male presenting with intermittent abdominal pain for the past two weeks. The pain is localized to the right lower quadrant and is associated with occasional nausea. 
It is not aggravated by food intake. No changes in bowel movements or urinary symptoms.\n\nFHX: No significant family history of abdominal disorders.\n\nSHX: Office worker. Non-smoker. Occasional alcohol consumption.\n\nAbdominal examination reveals tenderness and mild guarding in the right lower quadrant. No rebound tenderness or palpable masses are noted.\n\nSeen by Dr. S. Reynolds on 10/18/2023.\n\n", "annotations": [{"start": 7, "end": 21, "cui": "PATIENT", "value": " Ethan Thompson"}, {"start": 31, "end": 44, "cui": "HOSPITAL", "value": " 18 Oak Avenue"}, {"start": 51, "end": 60, "cui": "HOSPITAL", "value": " Riverside"}, {"start": 90, "end": 98, "cui": "PATIENT", "value": " Thompson"}, {"start": 648, "end": 649, "cui": "DOCTOR", "value": " S"}, {"start": 651, "end": 659, "cui": "PATIENT", "value": " Reynolds"}]}, {"name": "doc_32", "text": "\nName: Sophia Walker\nAddress: 9 Maple Lane\nCity: Willowville\nCC: Fatigue and weakness.\n\nHX: Ms. Walker is a 52-year-old female who presents with complaints of persistent fatigue and weakness for the past two months. She reports feeling tired even after a good night's sleep and experiences difficulty in performing routine tasks. No specific triggers or alleviating factors identified.\n\nFHX: No family history of chronic fatigue or neuromuscular disorders.\n\nSHX: Homemaker. Non-smoker. No alcohol consumption.\n\nPhysical examination reveals generalized weakness without focal neurological deficits. No abnormal findings on cardiovascular or respiratory examination.\n\nSeen by Dr. L. 
Carter on 10/21/2023.\n\n", "annotations": [{"start": 7, "end": 20, "cui": "PATIENT", "value": " Sophia Walker"}, {"start": 30, "end": 42, "cui": "HOSPITAL", "value": " 9 Maple Lane"}, {"start": 49, "end": 60, "cui": "HOSPITAL", "value": " Willowville"}, {"start": 96, "end": 102, "cui": "PATIENT", "value": " Walker"}, {"start": 678, "end": 679, "cui": "DOCTOR", "value": " L"}, {"start": 681, "end": 687, "cui": "PATIENT", "value": " Carter"}]}, {"name": "doc_33", "text": "\nName: Benjamin Harris\nAddress: 5 Pine Street\nCity: Meadowville\nCC: Headache and dizziness.\n\nHX: Mr. Harris is a 38-year-old male presenting with recurrent headaches and dizziness for the past month. The headaches are described as throbbing in nature and occur mostly in the afternoon. Dizziness is experienced upon standing up quickly or with sudden head movements.\n\nFHX: No significant family history of migraines or vestibular disorders.\n\nSHX: Office worker. Non-smoker. Rare alcohol consumption.\n\nNeurological examination is unremarkable. No abnormal findings on visual acuity, coordination, or gait.\n\nSeen by Dr. M. Rodriguez on 10/24/2023.\n\n", "annotations": [{"start": 7, "end": 22, "cui": "PATIENT", "value": " Benjamin Harris"}, {"start": 32, "end": 45, "cui": "HOSPITAL", "value": " 5 Pine Street"}, {"start": 52, "end": 63, "cui": "HOSPITAL", "value": " Meadowville"}, {"start": 101, "end": 107, "cui": "PATIENT", "value": " Harris"}, {"start": 618, "end": 619, "cui": "DOCTOR", "value": " M"}, {"start": 621, "end": 630, "cui": "PATIENT", "value": " Rodriguez"}]}, {"name": "doc_34", "text": "\nName: Lily Green\nAddress: 23 Cedar Road\nCity: Woodville\nCC: Allergic rhinitis.\n\nHX: Ms. Green is a 28-year-old female presenting with symptoms of sneezing, nasal congestion, and itchy, watery eyes for the past few weeks. Symptoms are worse in the morning and improve throughout the day. 
She reports a history of seasonal allergies.\n\nFHX: No significant family history of allergic rhinitis or respiratory disorders.\n\nSHX: Teacher. Non-smoker. No alcohol consumption.\n\nPhysical examination reveals clear nasal discharge, congestion, and allergic shiners. No signs of respiratory distress.\n\nSeen by Dr. K. Mitchell on\n\n", "annotations": [{"start": 7, "end": 17, "cui": "PATIENT", "value": " Lily Green"}, {"start": 27, "end": 40, "cui": "HOSPITAL", "value": " 23 Cedar Road"}, {"start": 47, "end": 56, "cui": "HOSPITAL", "value": " Woodville"}, {"start": 89, "end": 94, "cui": "PATIENT", "value": " Green"}, {"start": 601, "end": 602, "cui": "DOCTOR", "value": " K"}, {"start": 604, "end": 612, "cui": "PATIENT", "value": " Mitchell"}]}, {"name": "doc_35", "text": "\nName: Henry Foster\nAddress: 14 Willow Street\nCity: Meadowville\nCC: Cough and shortness of breath.\n\nHX: Mr. Foster is a 62-year-old male presenting with a persistent cough and shortness of breath for the past two weeks. The cough is productive of yellowish sputum. He reports feeling breathless even with minimal exertion. No fever or chest pain.\n\nFHX: No significant family history of respiratory disorders.\n\nSHX: Retired. Former smoker (quit 10 years ago). No alcohol consumption.\n\nChest auscultation reveals decreased breath sounds and scattered crackles. No wheezing or dullness on percussion.\n\nSeen by Dr. S. Adams on 10/27/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Henry Foster"}, {"start": 29, "end": 45, "cui": "HOSPITAL", "value": " 14 Willow Street"}, {"start": 52, "end": 63, "cui": "HOSPITAL", "value": " Meadowville"}, {"start": 108, "end": 114, "cui": "PATIENT", "value": " Foster"}, {"start": 611, "end": 612, "cui": "DOCTOR", "value": " S"}, {"start": 614, "end": 619, "cui": "PATIENT", "value": " Adams"}]}, {"name": "doc_36", "text": "\nName: Emily Evans\nAddress: 8 Cherry Lane\nCity: Riverside\nCC: Sleep disturbances.\n\nHX: Ms. 
Evans is a 35-year-old female presenting with complaints of sleep disturbances for the past month. She reports difficulty falling asleep and frequent awakenings during the night. No daytime sleepiness or snoring. No significant life stressors identified.\n\nFHX: No family history of sleep disorders or psychiatric conditions.\n\nSHX: Office worker. Non-smoker. Rare alcohol consumption.\n\nNo significant findings on physical examination. Normal mental status and intact concentration.\n\nSeen by Dr. L. Carter on 10/30/2023.\n\n", "annotations": [{"start": 7, "end": 18, "cui": "PATIENT", "value": " Emily Evans"}, {"start": 28, "end": 41, "cui": "HOSPITAL", "value": " 8 Cherry Lane"}, {"start": 48, "end": 57, "cui": "HOSPITAL", "value": " Riverside"}, {"start": 91, "end": 96, "cui": "PATIENT", "value": " Evans"}, {"start": 585, "end": 586, "cui": "DOCTOR", "value": " L"}, {"start": 588, "end": 594, "cui": "PATIENT", "value": " Carter"}]}, {"name": "doc_37", "text": "\nName: Samuel Hayes\nAddress: 11 Elm Street\nCity: Springfield\nCC: Abnormal mole.\n\nHX: Mr. Hayes is a 42-year-old male who noticed an abnormal mole on his back. The mole has increased in size and has an irregular border. He reports occasional itching but no pain or bleeding.\n\nFHX: No significant family history of skin cancer or melanoma.\n\nSHX: Construction worker. Non-smoker. Occasional alcohol consumption.\n\nSkin examination reveals a dark, asymmetrical mole with irregular borders and uneven coloration. No palpable lymph nodes in the surrounding area.\n\nSeen by Dr. R. 
Martinez on 11/2/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Samuel Hayes"}, {"start": 29, "end": 42, "cui": "HOSPITAL", "value": " 11 Elm Street"}, {"start": 49, "end": 60, "cui": "HOSPITAL", "value": " Springfield"}, {"start": 89, "end": 94, "cui": "PATIENT", "value": " Hayes"}, {"start": 569, "end": 570, "cui": "DOCTOR", "value": " R"}, {"start": 572, "end": 580, "cui": "PATIENT", "value": " Martinez"}]}, {"name": "doc_38", "text": "\nName: Isabella Simmons\nAddress: 6 Oak Avenue\nCity: Willowville\nCC: Joint pain and swelling.\n\nHX: Ms. Simmons is a 55-year-old female presenting with joint pain and swelling in her hands and knees for the past three months. The pain is worse in the morning and improves with movement. No history of trauma or recent infections.\n\nFHX: No significant family history of autoimmune disorders or arthritis.\n\nSHX: Teacher. Non-smoker. Rare alcohol consumption.\n\nJoint examination reveals swelling and tenderness in the proximal and distal interphalangeal joints and knees. No erythema or warmth.\n\nSeen by Dr. S. Reynolds on 11/5/2023.\n\n", "annotations": [{"start": 7, "end": 23, "cui": "PATIENT", "value": " Isabella Simmons"}, {"start": 33, "end": 45, "cui": "HOSPITAL", "value": " 6 Oak Avenue"}, {"start": 52, "end": 63, "cui": "HOSPITAL", "value": " Willowville"}, {"start": 102, "end": 109, "cui": "PATIENT", "value": " Simmons"}, {"start": 603, "end": 604, "cui": "DOCTOR", "value": " S"}, {"start": 606, "end": 614, "cui": "PATIENT", "value": " Reynolds"}]}, {"name": "doc_39", "text": "\nName: Daniel Thompson\nAddress: 9 Maple Lane\nCity: Willowville\nCC: Epigastric pain and heartburn.\n\nHX: Mr. Thompson is a 48-year-old male presenting with epigastric pain and heartburn for the past two weeks. The pain is described as a burning sensation and is aggravated by spicy foods and lying down after meals. 
No vomiting or black, tarry stools.\n\nFHX: No significant family history of gastrointestinal disorders.\n\nSHX: Office worker. Non-smoker. Rare alcohol consumption.\n\nAbdominal examination reveals epigastric tenderness on palpation. No rebound tenderness or organomegaly.\n\nSeen by Dr. L. Carter on 11/8/2023.\n\n", "annotations": [{"start": 7, "end": 22, "cui": "PATIENT", "value": " Daniel Thompson"}, {"start": 32, "end": 44, "cui": "HOSPITAL", "value": " 9 Maple Lane"}, {"start": 51, "end": 62, "cui": "HOSPITAL", "value": " Willowville"}, {"start": 107, "end": 115, "cui": "PATIENT", "value": " Thompson"}, {"start": 595, "end": 596, "cui": "DOCTOR", "value": " L"}, {"start": 598, "end": 604, "cui": "PATIENT", "value": " Carter"}]}, {"name": "doc_40", "text": "\nName: Emily Turner\nAddress: 15 Pine Street\nCity: Meadowville\nCC: Fatigue and weight gain.\n\nHX: Ms. Turner is a 30-year-old female presenting with persistent fatigue and unexplained weight gain over the past six months. She reports feeling tired despite getting adequate sleep and has noticed a significant increase in her weight without changes in her diet or exercise routine.\n\nFHX: No significant family history of endocrine disorders or autoimmune conditions.\n\nSHX: Office worker. Non-smoker. Rare alcohol consumption.\n\nPhysical examination reveals no specific abnormalities. No edema or thyroid enlargement palpable.\n\nSeen by Dr. M. 
Rodriguez on 11/11/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Emily Turner"}, {"start": 29, "end": 43, "cui": "HOSPITAL", "value": " 15 Pine Street"}, {"start": 50, "end": 61, "cui": "HOSPITAL", "value": " Meadowville"}, {"start": 100, "end": 106, "cui": "PATIENT", "value": " Turner"}, {"start": 635, "end": 636, "cui": "DOCTOR", "value": " M"}, {"start": 638, "end": 647, "cui": "PATIENT", "value": " Rodriguez"}]}, {"name": "doc_41", "text": "\nName: Oliver Clark\nAddress: 7 Cedar Road\nCity: Woodville\nCC: Swollen lymph nodes.\n\nHX: Mr. Clark is a 44-year-old male presenting with enlarged lymph nodes in his neck and groin for the past two weeks. The lymph nodes are painless and progressively increasing in size. No fever or night sweats reported.\n\nFHX: No significant family history of lymphatic disorders or malignancies.\n\nSHX: Teacher. Non-smoker. Rare alcohol consumption.\n\nLymph node examination reveals palpable, enlarged lymph nodes in the neck and groin regions. No other abnormal findings.\n\nSeen by Dr. K. Mitchell on 11/14/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Oliver Clark"}, {"start": 29, "end": 41, "cui": "HOSPITAL", "value": " 7 Cedar Road"}, {"start": 48, "end": 57, "cui": "HOSPITAL", "value": " Woodville"}, {"start": 92, "end": 97, "cui": "PATIENT", "value": " Clark"}, {"start": 569, "end": 570, "cui": "DOCTOR", "value": " K"}, {"start": 572, "end": 580, "cui": "PATIENT", "value": " Mitchell"}]}, {"name": "doc_42", "text": "\nName: Ava Patterson\nAddress: 13 Willow Street\nCity: Meadowville\nCC: Irregular menstrual cycles.\n\nHX: Ms. Patterson is a 27-year-old female presenting with irregular menstrual cycles for the past six months. She reports unpredictable timing, varying durations, and occasional heavy bleeding during her periods. 
No significant pain or other associated symptoms.\n\nFHX: No significant family history of gynecological disorders or hormonal imbalances.\n\nSHX: Office worker. Non-smoker. No alcohol consumption.\n\nPelvic examination reveals no palpable masses or tenderness. Normal external genitalia and vaginal walls.\n\nSeen by Dr. S. Adams on 11/17/2023.\n\n", "annotations": [{"start": 7, "end": 20, "cui": "PATIENT", "value": " Ava Patterson"}, {"start": 30, "end": 46, "cui": "HOSPITAL", "value": " 13 Willow Street"}, {"start": 53, "end": 64, "cui": "HOSPITAL", "value": " Meadowville"}, {"start": 106, "end": 115, "cui": "PATIENT", "value": " Patterson"}, {"start": 625, "end": 626, "cui": "DOCTOR", "value": " S"}, {"start": 628, "end": 633, "cui": "PATIENT", "value": " Adams"}]}, {"name": "doc_43", "text": "\nName: Noah Turner\nAddress: 16 Oak Avenue\nCity: Riverside\nCC: Frequent urination and increased thirst.\n\nHX: Mr. Turner is a 58-year-old male presenting with frequent urination and increased thirst for the past month. He reports waking up multiple times during the night to urinate and feeling constantly thirsty throughout the day. No significant weight changes or other urinary symptoms.\n\nFHX: No significant family history of diabetes or renal disorders.\n\nSHX: Retired. Non-smoker. Occasional alcohol consumption.\n\nNo specific findings on physical examination. No edema or signs of dehydration.\n\nSeen by Dr. L. 
Carter on 11/20/2023.\n\n", "annotations": [{"start": 7, "end": 18, "cui": "PATIENT", "value": " Noah Turner"}, {"start": 28, "end": 41, "cui": "HOSPITAL", "value": " 16 Oak Avenue"}, {"start": 48, "end": 57, "cui": "HOSPITAL", "value": " Riverside"}, {"start": 112, "end": 118, "cui": "PATIENT", "value": " Turner"}, {"start": 610, "end": 611, "cui": "DOCTOR", "value": " L"}, {"start": 613, "end": 619, "cui": "PATIENT", "value": " Carter"}]}, {"name": "doc_44", "text": "\nName: Mia Mitchell\nAddress: 10 Cherry Lane\nCity: Riverside\nCC: Skin rash and itching.\n\nHX: Ms. Mitchell is a 36-year-old female presenting with a skin rash and intense itching for the past week. The rash is characterized by red, raised bumps and appears primarily on her arms and legs. It worsens at night and with exposure to heat.\n\nFHX: No significant family history of skin conditions or allergies.\n\nSHX: Office worker. Non-smoker. Rare alcohol consumption.\n\nSkin examination reveals multiple erythematous papules and plaques with excoriation marks. No signs of infection.\n\nSeen by Dr. R. Martinez on 11/23/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Mia Mitchell"}, {"start": 29, "end": 43, "cui": "HOSPITAL", "value": " 10 Cherry Lane"}, {"start": 50, "end": 59, "cui": "HOSPITAL", "value": " Riverside"}, {"start": 96, "end": 104, "cui": "PATIENT", "value": " Mitchell"}, {"start": 590, "end": 591, "cui": "DOCTOR", "value": " R"}, {"start": 593, "end": 601, "cui": "PATIENT", "value": " Martinez"}]}, {"name": "doc_45", "text": "\nName: Ethan Johnson\nAddress: 11 Maple Lane\nCity: Willowville\nCC: Abdominal bloating and constipation.\n\nHX: Mr. Johnson is a 50-year-old male presenting with complaints of abdominal bloating and constipation for the past two months. He reports feeling full quickly after eating and experiences infrequent bowel movements. 
No significant changes in diet or exercise.\n\nFHX: No significant family history of gastrointestinal disorders.\n\nSHX: Construction worker. Non-smoker. Occasional alcohol consumption.\n\nAbdominal examination reveals distension and mild tenderness on palpation. No masses or organomegaly appreciated.\n\nSeen by Dr. S. Reynolds on 11/26/2023.\n\n", "annotations": [{"start": 7, "end": 20, "cui": "PATIENT", "value": " Ethan Johnson"}, {"start": 30, "end": 43, "cui": "HOSPITAL", "value": " 11 Maple Lane"}, {"start": 50, "end": 61, "cui": "HOSPITAL", "value": " Willowville"}, {"start": 112, "end": 119, "cui": "PATIENT", "value": " Johnson"}, {"start": 632, "end": 633, "cui": "DOCTOR", "value": " S"}, {"start": 635, "end": 643, "cui": "PATIENT", "value": " Reynolds"}]}, {"name": "doc_46", "text": "\nName: Sophia Nelson\nAddress: 17 Elm Street\nCity: Springfield\nCC: Anxiety and panic attacks.\n\nHX: Ms. Nelson is a 33-year-old female presenting with symptoms of anxiety and recurrent panic attacks for the past six months. She describes episodes of sudden fear, rapid heartbeat, shortness of breath, and sweating. No specific triggers identified.\n\nFHX: No significant family history of anxiety or psychiatric disorders.\n\nSHX: Office worker. Non-smoker. Rare alcohol consumption.\n\nNormal findings on physical examination. No signs of distress during the evaluation.\n\nSeen by Dr. M. Rodriguez on 11/29/2023.\n\n", "annotations": [{"start": 7, "end": 20, "cui": "PATIENT", "value": " Sophia Nelson"}, {"start": 30, "end": 43, "cui": "HOSPITAL", "value": " 17 Elm Street"}, {"start": 50, "end": 61, "cui": "HOSPITAL", "value": " Springfield"}, {"start": 102, "end": 108, "cui": "PATIENT", "value": " Nelson"}, {"start": 577, "end": 578, "cui": "DOCTOR", "value": " M"}, {"start": 580, "end": 589, "cui": "PATIENT", "value": " Rodriguez"}]}, {"name": "doc_47", "text": "\nName: Olivia Clark\nAddress: 9 Cedar Road\nCity: Woodville\nCC: Knee pain and swelling.\n\nHX: Ms. 
Clark is a 42-year-old female presenting with pain and swelling in her right knee for the past month. She reports that the symptoms started gradually and worsen with prolonged activity or climbing stairs. No history of trauma or previous knee issues.\n\nFHX: No significant family history of joint disorders or arthritis.\n\nSHX: Teacher. Non-smoker. No alcohol consumption.\n\nOn examination, there is swelling and tenderness in the right knee joint. Limited range of motion due to pain.\n\nSeen by Dr. K. Mitchell on 12/2/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Olivia Clark"}, {"start": 29, "end": 41, "cui": "HOSPITAL", "value": " 9 Cedar Road"}, {"start": 48, "end": 57, "cui": "HOSPITAL", "value": " Woodville"}, {"start": 95, "end": 100, "cui": "PATIENT", "value": " Clark"}, {"start": 591, "end": 592, "cui": "DOCTOR", "value": " K"}, {"start": 594, "end": 602, "cui": "PATIENT", "value": " Mitchell"}]}, {"name": "doc_48", "text": "\nName: Benjamin Anderson\nAddress: 12 Oak Avenue\nCity: Riverside\nCC: Sore throat and difficulty swallowing.\n\nHX: Mr. Anderson is a 28-year-old male presenting with a sore throat and difficulty swallowing for the past week. He reports pain and discomfort with swallowing, especially with solid foods. No fever, cough, or other respiratory symptoms.\n\nFHX: No significant family history of throat infections or inflammatory conditions.\n\nSHX: Office worker. Non-smoker. Occasional alcohol consumption.\n\nThroat examination reveals erythema and swelling of the posterior pharynx. No tonsillar enlargement or exudate.\n\nSeen by Dr. L. 
Carter on 12/5/2023.\n\n", "annotations": [{"start": 7, "end": 24, "cui": "PATIENT", "value": " Benjamin Anderson"}, {"start": 34, "end": 47, "cui": "HOSPITAL", "value": " 12 Oak Avenue"}, {"start": 54, "end": 63, "cui": "HOSPITAL", "value": " Riverside"}, {"start": 116, "end": 124, "cui": "PATIENT", "value": " Anderson"}, {"start": 623, "end": 624, "cui": "DOCTOR", "value": " L"}, {"start": 626, "end": 632, "cui": "PATIENT", "value": " Carter"}]}, {"name": "doc_49", "text": "\nName: Lily Cooper\nAddress: 14 Cherry Lane\nCity: Riverside\nCC: Frequent headaches.\n\nHX: Ms. Cooper is a 25-year-old female presenting with recurrent headaches for the past three months. The headaches occur several times a week and are described as throbbing in nature. No specific triggers or associated symptoms identified.\n\nFHX: No significant family history of migraines or neurological disorders.\n\nSHX: Teacher. Non-smoker. Rare alcohol consumption.\n\nNormal neurological examination. No focal deficits or abnormalities.\n\nSeen by Dr. R. Martinez on 12/8/2023.\n\n", "annotations": [{"start": 7, "end": 18, "cui": "PATIENT", "value": " Lily Cooper"}, {"start": 28, "end": 42, "cui": "HOSPITAL", "value": " 14 Cherry Lane"}, {"start": 49, "end": 58, "cui": "HOSPITAL", "value": " Riverside"}, {"start": 92, "end": 98, "cui": "PATIENT", "value": " Cooper"}, {"start": 537, "end": 538, "cui": "DOCTOR", "value": " R"}, {"start": 540, "end": 548, "cui": "PATIENT", "value": " Martinez"}]}, {"name": "doc_50", "text": "\nName: Sophia Williams\nAddress: 15 Elm Street\nCity: Springfield\nCC: Fatigue and muscle weakness.\n\nHX: Ms. Williams is a 42-year-old female presenting with persistent fatigue and muscle weakness for the past two months. She reports feeling tired even after getting sufficient rest and experiences difficulty performing daily activities. 
No significant weight changes or other associated symptoms.\n\nFHX: No significant family history of muscular disorders or autoimmune conditions.\n\nSHX: Office worker. Non-smoker. Occasional alcohol consumption.\n\nPhysical examination reveals decreased muscle strength and generalized weakness. No specific findings on neurological evaluation.\n\nSeen by Dr. S. Adams on 12/11/2023.\n\n", "annotations": [{"start": 7, "end": 22, "cui": "PATIENT", "value": " Sophia Williams"}, {"start": 32, "end": 45, "cui": "HOSPITAL", "value": " 15 Elm Street"}, {"start": 52, "end": 63, "cui": "HOSPITAL", "value": " Springfield"}, {"start": 106, "end": 114, "cui": "PATIENT", "value": " Williams"}, {"start": 689, "end": 690, "cui": "DOCTOR", "value": " S"}, {"start": 692, "end": 697, "cui": "PATIENT", "value": " Adams"}]}, {"name": "doc_51", "text": "\nName: Benjamin Turner\nAddress: 16 Pine Street\nCity: Meadowville\nCC: Chest pain and shortness of breath.\n\nHX: Mr. Turner is a 52-year-old male presenting with complaints of chest pain and shortness of breath for the past week. The chest pain is described as a squeezing sensation and is accompanied by breathlessness during exertion. No associated symptoms of dizziness or palpitations.\n\nFHX: No significant family history of cardiac disorders or cardiovascular conditions.\n\nSHX: Retired. Non-smoker. Rare alcohol consumption.\n\nCardiovascular examination reveals regular heart sounds and no murmurs. No signs of respiratory distress.\n\nSeen by Dr. M. 
Rodriguez on 12/14/2023.\n\n", "annotations": [{"start": 7, "end": 22, "cui": "PATIENT", "value": " Benjamin Turner"}, {"start": 32, "end": 46, "cui": "HOSPITAL", "value": " 16 Pine Street"}, {"start": 53, "end": 64, "cui": "HOSPITAL", "value": " Meadowville"}, {"start": 114, "end": 120, "cui": "PATIENT", "value": " Turner"}, {"start": 647, "end": 648, "cui": "DOCTOR", "value": " M"}, {"start": 650, "end": 659, "cui": "PATIENT", "value": " Rodriguez"}]}, {"name": "doc_52", "text": "\nName: Chloe Parker\nAddress: 13 Cedar Road\nCity: Woodville\nCC: Frequent urination and burning sensation.\n\nHX: Ms. Parker is a 30-year-old female presenting with frequent urination and a burning sensation during urination for the past week. She reports a sense of urgency to urinate and occasional lower abdominal discomfort. No fever or back pain.\n\nFHX: No significant family history of urinary tract infections or urological conditions.\n\nSHX: Office worker. Non-smoker. Rare alcohol consumption.\n\nNo specific findings on physical examination. No costovertebral angle tenderness.\n\nSeen by Dr. K. Mitchell on 12/17/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Chloe Parker"}, {"start": 29, "end": 42, "cui": "HOSPITAL", "value": " 13 Cedar Road"}, {"start": 49, "end": 58, "cui": "HOSPITAL", "value": " Woodville"}, {"start": 114, "end": 120, "cui": "PATIENT", "value": " Parker"}, {"start": 593, "end": 594, "cui": "DOCTOR", "value": " K"}, {"start": 596, "end": 604, "cui": "PATIENT", "value": " Mitchell"}]}, {"name": "doc_53", "text": "\nName: Oliver Lewis\nAddress: 10 Maple Lane\nCity: Willowville\nCC: Vision changes and eye pain.\n\nHX: Mr. Lewis is a 60-year-old male presenting with vision changes and intermittent eye pain in his right eye for the past month. He reports blurred vision and the sensation of pressure in the eye. 
No redness or discharge noted.\n\nFHX: No significant family history of eye disorders or ocular conditions.\n\nSHX: Construction worker. Non-smoker. Occasional alcohol consumption.\n\nVisual acuity testing reveals decreased vision in the right eye. No external abnormalities or conjunctival injection.\n\nSeen by Dr. S. Reynolds on 12/20/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Oliver Lewis"}, {"start": 29, "end": 42, "cui": "HOSPITAL", "value": " 10 Maple Lane"}, {"start": 49, "end": 60, "cui": "HOSPITAL", "value": " Willowville"}, {"start": 103, "end": 108, "cui": "PATIENT", "value": " Lewis"}, {"start": 602, "end": 603, "cui": "DOCTOR", "value": " S"}, {"start": 605, "end": 613, "cui": "PATIENT", "value": " Reynolds"}]}, {"name": "doc_54", "text": "\nName: Emma Peterson\nAddress: 17 Oak Avenue\nCity: Riverside\nCC: Abdominal pain and diarrhea.\n\nHX: Ms. Peterson is a 38-year-old female presenting with abdominal pain and frequent episodes of diarrhea for the past week. The abdominal pain is crampy in nature and is associated with loose, watery stools. No blood or mucus in the stool.\n\nFHX: No significant family history of gastrointestinal disorders.\n\nSHX: Teacher. Non-smoker. No alcohol consumption.\n\nAbdominal examination reveals tenderness in the lower abdomen. No rebound tenderness or palpable masses.\n\nSeen by Dr. L. Carter on 12/23/2023.\n\n", "annotations": [{"start": 7, "end": 20, "cui": "PATIENT", "value": " Emma Peterson"}, {"start": 30, "end": 43, "cui": "HOSPITAL", "value": " 17 Oak Avenue"}, {"start": 50, "end": 59, "cui": "HOSPITAL", "value": " Riverside"}, {"start": 102, "end": 110, "cui": "PATIENT", "value": " Peterson"}, {"start": 572, "end": 573, "cui": "DOCTOR", "value": " L"}, {"start": 575, "end": 581, "cui": "PATIENT", "value": " Carter"}]}, {"name": "doc_55", "text": "\nName: Amelia Adams\nAddress: 11 Willow Street\nCity: Meadowville\nCC: Depression and loss of interest.\n\nHX: Ms. 
Adams is a 35-year-old female presenting with symptoms of depression and loss of interest in activities for the past six months. She reports feeling sad, hopeless, and having a decreased motivation to engage in previously enjoyed hobbies. No suicidal thoughts or changes in appetite.\n\nFHX: No significant family history of mood disorders or psychiatric conditions.\n\nSHX: Office worker. Non-smoker. Rare alcohol consumption.\n\nNo specific findings on physical examination. No signs of distress during the evaluation.\n\nSeen by Dr. R. Martinez on 12/26/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Amelia Adams"}, {"start": 29, "end": 45, "cui": "HOSPITAL", "value": " 11 Willow Street"}, {"start": 52, "end": 63, "cui": "HOSPITAL", "value": " Meadowville"}, {"start": 110, "end": 115, "cui": "PATIENT", "value": " Adams"}, {"start": 638, "end": 639, "cui": "DOCTOR", "value": " R"}, {"start": 641, "end": 649, "cui": "PATIENT", "value": " Martinez"}]}, {"name": "doc_56", "text": "\nName: Henry Turner\nAddress: 12 Pine Street\nCity: Meadowville\nCC: Joint pain and stiffness.\n\nHX: Mr. Turner is a 60-year-old male presenting with joint pain and stiffness in his hands and knees for the past three months. He reports difficulty with movements, especially in the mornings, and occasional swelling in the affected joints. No history of trauma or previous joint disorders.\n\nFHX: No significant family history of arthritis or rheumatic conditions.\n\nSHX: Retired. Non-smoker. Rare alcohol consumption.\n\nOn examination, there is tenderness, warmth, and swelling in the affected joints. Limited range of motion due to pain.\n\nSeen by Dr. M. 
Rodriguez on 12/29/2023.\n\n", "annotations": [{"start": 7, "end": 19, "cui": "PATIENT", "value": " Henry Turner"}, {"start": 29, "end": 43, "cui": "HOSPITAL", "value": " 12 Pine Street"}, {"start": 50, "end": 61, "cui": "HOSPITAL", "value": " Meadowville"}, {"start": 101, "end": 107, "cui": "PATIENT", "value": " Turner"}, {"start": 645, "end": 646, "cui": "DOCTOR", "value": " M"}, {"start": 648, "end": 657, "cui": "PATIENT", "value": " Rodriguez"}]}, {"name": "doc_57", "text": "\nName: Harper Mitchell\nAddress: 14 Cedar Road\nCity: Woodville\nCC: Allergic rhinitis and nasal congestion.\n\nHX: Ms. Mitchell is a 28-year-old female presenting with symptoms of allergic rhinitis, including nasal congestion, sneezing, and itchy eyes, for the past two weeks. She reports these symptoms are worse in the morning and in certain environments. No history of sinus infections or nasal polyps.\n\nFHX: No significant family history of allergies or respiratory conditions.\n\nSHX: Office worker. Non-smoker. Rare alcohol consumption.\n\nNasal examination reveals nasal congestion, clear rhinorrhea, and pale, boggy nasal mucosa. No signs of septal deviation or polyps.\n\nSeen by Dr. K. Mitchell on 1/2/2024.\n\n", "annotations": [{"start": 7, "end": 22, "cui": "PATIENT", "value": " Harper Mitchell"}, {"start": 32, "end": 45, "cui": "HOSPITAL", "value": " 14 Cedar Road"}, {"start": 52, "end": 61, "cui": "HOSPITAL", "value": " Woodville"}, {"start": 115, "end": 123, "cui": "PATIENT", "value": " Mitchell"}, {"start": 683, "end": 684, "cui": "DOCTOR", "value": " K"}, {"start": 686, "end": 694, "cui": "PATIENT", "value": " Mitchell"}]}, {"name": "doc_58", "text": "\nName: Jackson Turner\nAddress: 9 Maple Lane\nCity: Willowville\nCC: Sleep disturbances and daytime sleepiness.\n\nHX: Mr. Turner is a 45-year-old male presenting with complaints of sleep disturbances and excessive daytime sleepiness for the past three months. 
He reports difficulty falling asleep, frequent awakenings during the night, and feeling tired during the day despite sufficient hours of sleep.\n\nFHX: No significant family history of sleep disorders or neurological conditions.\n\nSHX: Construction worker. Non-smoker. Occasional alcohol consumption.\n\nNo specific findings on physical examination. No signs of respiratory disorders.\n\nSeen by Dr. S. Reynolds on 1/5/2024.\n\n", "annotations": [{"start": 7, "end": 21, "cui": "PATIENT", "value": " Jackson Turner"}, {"start": 31, "end": 43, "cui": "HOSPITAL", "value": " 9 Maple Lane"}, {"start": 50, "end": 61, "cui": "HOSPITAL", "value": " Willowville"}, {"start": 118, "end": 124, "cui": "PATIENT", "value": " Turner"}, {"start": 649, "end": 650, "cui": "DOCTOR", "value": " S"}, {"start": 652, "end": 660, "cui": "PATIENT", "value": " Reynolds"}]}, {"name": "doc_59", "text": "\nName: Penelope Walker\nAddress: 13 Oak Avenue\nCity: Riverside\nCC: Nausea and vomiting.\n\nHX: Ms. Walker is a 42-year-old female presenting with symptoms of nausea and vomiting for the past two days. She reports episodes of sudden, uncontrollable vomiting and a persistent feeling of queasiness. No abdominal pain or changes in bowel movements.\n\nFHX: No significant family history of gastrointestinal disorders.\n\nSHX: Teacher. Non-smoker. No alcohol consumption.\n\nAbdominal examination reveals no tenderness or palpable masses. No signs of dehydration.\n\nSeen by Dr. L. 
Carter on 1/8/2024.", "annotations": [{"start": 7, "end": 22, "cui": "PATIENT", "value": " Penelope Walker"}, {"start": 32, "end": 45, "cui": "HOSPITAL", "value": " 13 Oak Avenue"}, {"start": 52, "end": 61, "cui": "HOSPITAL", "value": " Riverside"}, {"start": 96, "end": 102, "cui": "PATIENT", "value": " Walker"}, {"start": 564, "end": 565, "cui": "DOCTOR", "value": " L"}, {"start": 567, "end": 573, "cui": "PATIENT", "value": " Carter"}]}]}]} \ No newline at end of file diff --git a/tests/test_cdb.py b/tests/test_cdb.py index 219e43d33..96425bc8c 100644 --- a/tests/test_cdb.py +++ b/tests/test_cdb.py @@ -64,6 +64,23 @@ def test_empty_count_train(self): stats = self.undertest.make_stats() self.assertFalse(np.isnan(stats["Average training examples per concept"])) self.undertest.cui2count_train = copied + + def test_remove_cui(self): + self.undertest.remove_cui('C0000039') + assert 'C0000039' not in self.undertest.cui2names + assert 'C0000039' not in self.undertest.cui2snames + assert 'C0000039' not in self.undertest.cui2count_train + assert 'C0000039' not in self.undertest.cui2type_ids + assert 'C0000039' not in self.undertest.cui2preferred_name + assert 'C0000039' not in self.undertest.name2cuis['virus~z'] + assert 'C0000039' not in self.undertest.name2cuis2status['virus~z'] + + def test_cui2snames_population(self): + self.undertest.cui2snames.clear() + self.undertest.populate_cui2snames() + for cui in self.undertest.cui2names: + with self.subTest(cui): + self.assertIn(cui, self.undertest.cui2snames) if __name__ == '__main__': unittest.main() diff --git a/tests/utils/ner/__init__.py b/tests/utils/ner/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/utils/ner/test_deid.py b/tests/utils/ner/test_deid.py new file mode 100644 index 000000000..dcc8938b8 --- /dev/null +++ b/tests/utils/ner/test_deid.py @@ -0,0 +1,118 @@ +from medcat.utils.ner import deid +from medcat.utils.ner import make_or_update_cdb + +from medcat.ner import 
transformers_ner + +from spacy.tokens import Doc + +from typing import Any, List, Tuple +import os + +import unittest + +FILE_DIR = os.path.dirname(os.path.realpath(__file__)) + + +# NB! This 'training data' is extremely flawed +# it is only (somewhat) useful for the purpose of this +# test +# DO NOT USE THIS DATA ELSEWHERE - IT WILL NOT BE USEFUL +TRAIN_DATA = os.path.join(FILE_DIR, "..", "..", + "resources", "deid_train_data.json") + + +class DeIDmodelCreationTests(unittest.TestCase): + + def test_can_make_cdb(self): + cdb = make_or_update_cdb(TRAIN_DATA) + self.assertIsNotNone(cdb) + + def test_can_create_model(self): + cdb = make_or_update_cdb(TRAIN_DATA) + config = transformers_ner.ConfigTransformersNER() + config.general['test_size'] = 0.1 # Usually set this to 0.1-0.2 + ner = transformers_ner.TransformersNER(cdb=cdb, config=config) + deid_model = deid.DeIdModel.create(ner) + self.assertIsNotNone(deid_model) + + +def _add_model(cls): + cdb = make_or_update_cdb(TRAIN_DATA) + config = transformers_ner.ConfigTransformersNER() + config.general['test_size'] = 0.1 # Usually set this to 0.1-0.2 + cls.ner = transformers_ner.TransformersNER(cdb=cdb, config=config) + cls.ner.training_arguments.num_train_epochs = 1 # Use 5-10 normally + # As we are NOT training on a GPU that can, we'll set it to 1 + cls.ner.training_arguments.per_device_train_batch_size = 1 + cls.ner.training_arguments.gradient_accumulation_steps = 1 # No need for acc + cls.ner.training_arguments.per_device_eval_batch_size = 1 + # For the metric to be used for best model we pick Recall here, as for deid that is most important + cls.ner.training_arguments.metric_for_best_model = 'eval_recall' + cls.deid_model = deid.DeIdModel.create(cls.ner) + + +def train_model_once(model: deid.DeIdModel, + _trained: List[Tuple[Tuple[Any, Any, Any], + deid.DeIdModel]] = [] + ) -> Tuple[Tuple[Any, Any, Any], deid.DeIdModel]: + if not _trained: + retval = model.train(TRAIN_DATA) + _trained.append((retval, model)) + return 
_trained[0] + + +class DeIDModelTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + _add_model(cls) + + def test_training(self): + df, examples, dataset = train_model_once(self.deid_model)[0] + self.assertIsNotNone(df) + self.assertIsNotNone(examples) + self.assertIsNotNone(dataset) + + +input_text = ''' +James Joyce +7 Eccles Street, +Dublin +CC: Memory difficulty. + +HX: Mr James is a 64 y/o RHM, had difficulty remembering names, phone numbers and events for 12 months prior to presentation, on 2/28/95. He had visited London recently and had had no professional or social faux pas or mishaps due to his memory. J.J. could not tell whether his problem was becoming worse, so he brought himself to the Neurology clinic on his own referral. + +FHX: Both parents (Mary and John) experienced memory problems in their ninth decades, but not earlier. 5 siblings have had no memory trouble. There are no neurological illnesses in his family. + +SHX: Writer and Poet. Tobacco/ETOH/illicit drug use. + +The rest of the neurologic exam was unremarkable and there were no extrapyramidal signs or primitive reflexes noted. +11/1996 in Dublin. + +The findings indicated multiple areas of cerebral dysfunction. With the exception of the patient's report of minimal occupational dysfunction ( which may reflect poor insight), the clinical picture is consistent with a progressive dementia syndrome such as Alzheimer disease. MRI brain, 3/6/95, showed mild generalized atrophy, more severe in the occipital-parietal regions. + +Seen by Dr. M. Sully on 11/11/1996. 
+''' + + +class DeIDModelWorks(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + _add_model(cls) + cls.deid_model = train_model_once(cls.deid_model)[1] + + def test_model_works_deid_text(self): + anon_text = self.deid_model.deid_text(input_text) + self.assertIn("[DOCTOR]", anon_text) + self.assertIn("[HOSPITAL]", anon_text) + + def test_model_works_dunder_call(self): + anon_doc = self.deid_model(input_text) + self.assertIsInstance(anon_doc, Doc) + + def test_model_works_deid_text_redact(self): + anon_text = self.deid_model.deid_text(input_text, redact=True) + self.assertIn("****", anon_text) + self.assertNotIn("[DOCTOR]", anon_text) + self.assertNotIn("[HOSPITAL]", anon_text) diff --git a/tests/utils/regression/test_metadata.py b/tests/utils/regression/test_metadata.py index ad3a83533..7dd6f0911 100644 --- a/tests/utils/regression/test_metadata.py +++ b/tests/utils/regression/test_metadata.py @@ -4,6 +4,7 @@ # from Anthony's 1.4 model +EXAMPLE_VERSION = 1, 3, 0 MODEL_CARD_EXAMPLE = { "Model ID": "acd0dfc2f0df45de", "Last Modified On": "04 October 2022", diff --git a/tests/utils/saving/test_coding.py b/tests/utils/saving/test_coding.py new file mode 100644 index 000000000..c60a3b1f2 --- /dev/null +++ b/tests/utils/saving/test_coding.py @@ -0,0 +1,77 @@ +from medcat.utils.saving import coding + +import json + +import unittest + + +class SetEncodeTests(unittest.TestCase): + string2sets_dict1 = {'s1': set(['v1', 'v2', 'v3']), + 's2': set(['u1', 'u2', 'u3'])} + string2sets_dict2 = {'p1': set([1, 2, 3]), + 'p2': set([3, 4, 5])} + + def serialise(self, d: dict) -> str: + return json.dumps(d, cls=coding.CustomDelegatingEncoder.def_inst) + + def _helper_serialises(self, d: dict): + s = self.serialise(d) + self.assertIsInstance(s, str) + + def test_sets_of_strings_serialise(self): + self._helper_serialises(self.string2sets_dict1) + + def test_sets_of_ints_serialise(self): + self._helper_serialises(self.string2sets_dict2) + + def _helper_keys_in_json(self, 
d: dict): + s = self.serialise(d) + for k in d.keys(): + with self.subTest(k): + self.assertIn(str(k), s) + + def test_sos_keys_in_json(self): + self._helper_keys_in_json(self.string2sets_dict1) + + def test_soi_keys_in_json(self): + self._helper_keys_in_json(self.string2sets_dict2) + + def _helper_values_in_json(self, d: dict): + s = self.serialise(d) + for key, v in d.items(): + for nr, el in enumerate(v): + with self.subTest(f"Key: {key}; Element {nr}"): + self.assertIn(str(el), s) + + def test_sos_values_in_json(self): + self._helper_values_in_json(self.string2sets_dict1) + + def test_soi_values_in_json(self): + self._helper_values_in_json(self.string2sets_dict2) + + +class SetDecodeTests(unittest.TestCase): + + def deserialise(self, s: str) -> dict: + return json.loads(s, object_hook=coding.default_hook) + + def setUp(self) -> None: + self.encoder = SetEncodeTests() + self.encoded1 = self.encoder.serialise(self.encoder.string2sets_dict1) + self.encoded2 = self.encoder.serialise(self.encoder.string2sets_dict2) + + def test_sos_decodes(self): + d = self.deserialise(self.encoded1) + self.assertIsInstance(d, dict) + + def test_soi_decodes(self): + d = self.deserialise(self.encoded2) + self.assertIsInstance(d, dict) + + def test_sos_decodes_to_identical(self): + d = self.deserialise(self.encoded1) + self.assertEqual(d, self.encoder.string2sets_dict1) + + def test_soi_decodes_to_identical(self): + d = self.deserialise(self.encoded2) + self.assertEqual(d, self.encoder.string2sets_dict2) diff --git a/tests/utils/saving/test_serialization.py b/tests/utils/saving/test_serialization.py index 6313906dc..f0cc75de1 100644 --- a/tests/utils/saving/test_serialization.py +++ b/tests/utils/saving/test_serialization.py @@ -9,11 +9,13 @@ from medcat.cat import CAT from medcat.vocab import Vocab -from medcat.utils.saving.serializer import JsonSetSerializer, CDBSerializer, SPECIALITY_NAMES +from medcat.utils.saving.serializer import JsonSetSerializer, CDBSerializer, 
SPECIALITY_NAMES, ONE2MANY +import medcat.utils.saving.coding as _ -class JSONSerialoizationTests(unittest.TestCase): - folder = os.path.join('temp', 'JSONSerialoizationTests') + +class JSONSerializationTests(unittest.TestCase): + folder = os.path.join('temp', 'JSONSerializationTests') def setUp(self) -> None: return super().setUp() @@ -42,6 +44,11 @@ def test_round_trip(self): self.ser.serialize(self.cdb, overwrite=True) cdb = self.ser.deserialize(CDB) for name in SPECIALITY_NAMES: + if name in ONE2MANY: + # ignore cui2many and name2many + # since they don't exist if/when + # optimisation hasn't been done + continue with self.subTest(name): orig = getattr(self.cdb, name) now = getattr(cdb, name) @@ -82,11 +89,19 @@ def test_dill_to_json(self): json_path = os.path.join(model_pack_folder, "*.json") jsons = glob.glob(json_path) # there is also a model_card.json - self.assertGreaterEqual(len(jsons), len(SPECIALITY_NAMES)) + # but nothing for cui2many or name2many + # so can remove the length of ONE2MANY + self.assertGreaterEqual(len(jsons), len( + SPECIALITY_NAMES) - len(ONE2MANY)) for json in jsons: with self.subTest(f'JSON {json}'): if json.endswith('model_card.json'): continue # ignore model card here + if any(name in json for name in ONE2MANY): + # ignore cui2many and name2many + # since they don't exist if/when + # optimisation hasn't been done + continue self.assertTrue( any(special_name in json for special_name in SPECIALITY_NAMES)) return model_pack_folder @@ -128,6 +143,11 @@ def test_round_trip(self): self.assertEqual(cat.vocab.unigram_table, self.undertest.vocab.unigram_table) for name in SPECIALITY_NAMES: + if name in ONE2MANY: + # ignore cui2many and name2many + # since they don't exist if/when + # optimisation hasn't been done + continue with self.subTest(f'CDB Name {name}'): self.assertEqual(cat.cdb.__dict__[ name], self.undertest.cdb.__dict__[name]) diff --git a/tests/utils/test_memory_optimiser.py b/tests/utils/test_memory_optimiser.py new file mode 
100644 index 000000000..5f59f5274 --- /dev/null +++ b/tests/utils/test_memory_optimiser.py @@ -0,0 +1,375 @@ +from medcat.utils import memory_optimiser + +import unittest +import tempfile +import os +import shutil +import json +from medcat.cat import CAT +from medcat.cdb import CDB +from medcat.vocab import Vocab +from medcat.utils.saving import coding + + +class DelegatingDictTests(unittest.TestCase): + _dict = {'c1': [None, 0], 'c2': [1, None]} + + def setUp(self) -> None: + # deep copy so that the origianl remains unchangeds + _dict = dict((k, v.copy() + ) for k, v in self._dict.items()) + self.del_dict1 = memory_optimiser.DelegatingDict(_dict, 0, 2) + self.del_dict2 = memory_optimiser.DelegatingDict(_dict, 1, 2) + self.delegators = [self.del_dict1, self.del_dict2] + self.names = ['delegator 1', 'delegator 2'] + self.expected_lens = [len( + [v[nr] for v in _dict.values() if v[nr] is not None] + ) for nr in range(len(_dict[list(_dict.keys())[0]]))] + + def test_removal(self, key='c2'): + self.assertIn(key, self.del_dict1) + del self.del_dict1[key] + self.assertNotIn(key, self.del_dict1) + + def test_pop_no_def_existing(self, key='c2'): + self.assertIn(key, self.del_dict1) + val = self.del_dict1.pop(key) + self.assertNotIn(key, self.del_dict1) + self.assertIs(val, self._dict[key][0]) + + def test_pop_def_non_existing(self, key='c1', def_val='DEF VAL'): + self.assertNotIn(key, self.del_dict1) + val = self.del_dict1.pop(key, def_val) + self.assertNotIn(key, self.del_dict1) + self.assertIs(val, def_val) + + def test_adding_exiting_key_nonexist_value(self, key: str = 'c1'): + self.assertNotIn(key, self.del_dict1) + self.del_dict1[key] = 'value' + self.assertIn(key, self.del_dict1) + + def test_adding_nonexiting_key(self, key: str = 'nek1'): + self.assertNotIn(key, self.del_dict1) + self.del_dict1[key] = 'value-NEW' + self.assertIn(key, self.del_dict1) + + def test_adding_nonexiting_key_not_affect_other(self, key: str = 'nek2'): + self.assertNotIn(key, self.del_dict2) 
+ self.del_dict1[key] = 'value-NEW-2' + self.assertNotIn(key, self.del_dict2) + + def test_delegating_dict_has_correct_keys(self): + for delegator, exp_len, name in zip(self.delegators, self.expected_lens, self.names): + with self.subTest(name): + self.assertEqual(len(delegator.keys()), exp_len) + + def test_delegating_dict_has_same_number_of_keys_and_values(self): + for delegator, exp_len, name in zip(self.delegators, self.expected_lens, self.names): + with self.subTest(name): + self.assertEqual(len(delegator.keys()), exp_len) + self.assertEqual(len(delegator.values()), exp_len) + + def test_delegating_dict_has_same_number_of_items_and_iter_values(self): + for delegator, exp_len, name in zip(self.delegators, self.expected_lens, self.names): + with self.subTest(name): + self.assertEqual(len(delegator.items()), exp_len) + # __iter__ -> list -> len + self.assertEqual(len(list(delegator)), exp_len) + + def test_delegator_do_not_have_None_values(self): + for delegator, name in zip(self.delegators, self.names): + for key, val in delegator.items(): + with self.subTest(f"{name}: {key}"): + self.assertIsNotNone(val) + + def test_delegator_keys_in_original(self): + for delegator, name in zip(self.delegators, self.names): + for key in delegator.keys(): + with self.subTest(f"{name}: {key}"): + self.assertIn(key, self._dict) + + def test_delegator_keys_in_container(self): + for delegator, name in zip(self.delegators, self.names): + for key in delegator.keys(): + with self.subTest(f"{name}: {key}"): + self.assertIn(key, delegator) + + def test_delegator_get_gets_key(self, def_value='#DEFAULT#'): + for delegator, name in zip(self.delegators, self.names): + for key in delegator.keys(): + with self.subTest(f"{name}: {key}"): + val = delegator.get(key, def_value) + self.assertIsNot(val, def_value) + + def test_delegator_get_defaults_non_existant_key(self, def_value='#DEFAULT#'): + for delegator, name in zip(self.delegators, self.names): + for key in self._dict.keys(): + if key in 
delegator: + continue + with self.subTest(f"{name}: {key}"): + val = delegator.get(key, def_value) + self.assertIs(val, def_value) + + +class DelegatingDictJsonTests(unittest.TestCase): + _dict = {'c5': [None, 10], 'c6': [11, None]} + + def setUp(self) -> None: + self.del_dict1 = memory_optimiser.DelegatingDict(self._dict, 0, 2) + self.del_dict2 = memory_optimiser.DelegatingDict(self._dict, 1, 2) + self.delegators = [self.del_dict1, self.del_dict2] + self.master_dict = {'one2many': self._dict, + 'part1': self.del_dict1, + 'part2': self.del_dict2} + + def serialise_master(self) -> str: + return json.dumps(self.master_dict, + cls=coding.CustomDelegatingEncoder.def_inst) + + def deserialise(self, s: str, one2many_name='one2many') -> dict: + d = json.loads(s, object_hook=coding.default_hook) + one2many = d[one2many_name] + for key, value in d.items(): + if key == one2many_name: + continue + if value.delegate is None: + value.delegate = one2many + return d + + def test_dict_of_delegation_serialises(self): + s = self.serialise_master() + self.assertIsInstance(s, str) + + def test_dod_ser_has_keys(self): + s = self.serialise_master() + for key in self.master_dict: + with self.subTest(key): + self.assertIn(key, s) + + def test_dod_ser_one2many_has_sub_keys(self): + s = self.serialise_master() + for key in self.master_dict['one2many']: + with self.subTest(key): + self.assertIn(key, s) + + def test_round_trip(self): + s = self.serialise_master() + d = self.deserialise(s) + self.assertIsInstance(d, dict) + + def test_round_trip_equal(self): + s = self.serialise_master() + d = self.deserialise(s) + self.assertEqual(d, self.master_dict) + + +class UnOptimisingTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.cdb = CDB.load(os.path.join(os.path.dirname( + os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat")) + + def test_unoptimised_cdb_does_not_have_cui2many(self): + self.assertFalse(hasattr(self.cdb, 'cui2many')) + + def 
test_unoptmised_cdb_does_not_have_delegating_dicts(self): + for key, val in self.cdb.__dict__.items(): + with self.subTest(key): + self.assertNotIsInstance(val, memory_optimiser.DelegatingDict) + + def test_unoptimised_knows_has_no_optimsied_parts(self): + self.assertFalse(self.cdb._memory_optimised_parts, + "Should have empty optimised partss") + + def test_simply_loaded_model_not_dirty(self): + self.assertFalse(self.cdb.is_dirty) + + +class MemoryOptimisingTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.cdb = CDB.load(os.path.join(os.path.dirname( + os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat")) + memory_optimiser.perform_optimisation(cls.cdb, optimise_snames=True) + + def test_is_dirty(self): + self.assertTrue(self.cdb.is_dirty, + "Should be dirty after optimisation") + + def test_knows_optimised(self): + self.assertTrue(self.cdb._memory_optimised_parts, + "Should have non-empty `_memory_optimised_parts`") + + def test_knows_correct_parts_optimsed(self, should_be=['CUIS', 'snames']): + for name in should_be: + with self.subTest(name): + self.assertIn(name, self.cdb._memory_optimised_parts) + + def test_knows_incorrect_parts_NOT_optimised(self, should_not_be=['NAMES']): + for name in should_not_be: + with self.subTest(name): + self.assertNotIn(name, self.cdb._memory_optimised_parts) + + def test_cdb_has_one2many(self, one2many_name='cui2many'): + self.assertTrue(hasattr(self.cdb, one2many_name)) + one2many = getattr(self.cdb, one2many_name) + self.assertIsInstance(one2many, dict) + + def test_cdb_has_delegating_dicts(self): + for dict_name in memory_optimiser.CUI_DICT_NAMES_TO_COMBINE: + with self.subTest(dict_name): + d = getattr(self.cdb, dict_name) + self.assertIsInstance(d, memory_optimiser.DelegatingDict) + + def test_has_delegating_set(self): + self.assertIsInstance( + self.cdb.snames, memory_optimiser.DelegatingValueSet) + + def test_delegating_set_has_values(self): + for values in 
self.cdb.cui2snames.values(): + for val in values: + with self.subTest(f'Checking {val}'): + self.assertIn(val, self.cdb.snames) + + +class MemoryUnoptimisingTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.cdb = CDB.load(os.path.join(os.path.dirname( + os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat")) + + def test_optimisation_round_trip_cuis(self): + cui_dicts_before = [getattr(self.cdb, dict_name) + for dict_name in memory_optimiser.CUI_DICT_NAMES_TO_COMBINE] + memory_optimiser.perform_optimisation(self.cdb) + memory_optimiser.unoptimise_cdb(self.cdb) + cui_dicts_after = [getattr(self.cdb, dict_name) + for dict_name in memory_optimiser.CUI_DICT_NAMES_TO_COMBINE] + for before, after, name in zip(cui_dicts_before, + cui_dicts_after, + memory_optimiser.CUI_DICT_NAMES_TO_COMBINE): + with self.subTest(f'{name}'): + self.assertIsInstance(before, dict) + self.assertIsInstance(after, dict) + self.assertEquals(len(before), len(after)) + self.assertEquals(before, after) + + def test_optimisation_round_trip_snames(self): + snames_before = self.cdb.snames + memory_optimiser.perform_optimisation(self.cdb) + memory_optimiser.unoptimise_cdb(self.cdb) + snames_after = self.cdb.snames + self.assertIsInstance(snames_before, set) + self.assertIsInstance(snames_after, set) + self.assertEquals(len(snames_before), len(snames_after)) + self.assertEquals(snames_before, snames_after) + + def test_optimisation_round_trip_dirty(self): + memory_optimiser.perform_optimisation(self.cdb) + memory_optimiser.unoptimise_cdb(self.cdb) + self.assertTrue(self.cdb.is_dirty) + + def test_optimisation_round_trip_no_optimised_parts(self): + memory_optimiser.perform_optimisation(self.cdb) + memory_optimiser.unoptimise_cdb(self.cdb) + self.assertFalse(self.cdb._memory_optimised_parts, + "Should have no optimised parts") + + +class OperationalTests(unittest.TestCase): + temp_folder = tempfile.TemporaryDirectory() + temp_cdb_path = 
os.path.join(temp_folder.name, 'cat.cdb') + json_path = temp_cdb_path.rsplit(os.path.sep, 1)[0] + # importing here so it's in the local namespace + # otherwise, all of its parts would get run again + from tests.test_cat import CATTests + test_callable_with_single_text = CATTests.test_callable_with_single_text + test_callable_with_single_empty_text = CATTests.test_callable_with_single_empty_text + test_callable_with_single_none_text = CATTests.test_callable_with_single_none_text + test_get_entities = CATTests.test_get_entities + test_get_entities_including_text = CATTests.test_get_entities_including_text + test_get_entities_multi_texts = CATTests.test_get_entities_multi_texts + test_get_entities_multi_texts_including_text = CATTests.test_get_entities_multi_texts_including_text + + @classmethod + def setUpClass(cls) -> None: + cls.cdb = CDB.load(os.path.join(os.path.dirname( + os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat")) + memory_optimiser.perform_optimisation(cls.cdb, optimise_snames=True) + cls.vocab = Vocab.load(os.path.join(os.path.dirname( + os.path.realpath(__file__)), "..", "..", "examples", "vocab.dat")) + cls.cdb.config.general.spacy_model = "en_core_web_md" + cls.cdb.config.ner.min_name_len = 2 + cls.cdb.config.ner.upper_case_limit_len = 3 + cls.cdb.config.general.spell_check = True + cls.cdb.config.linking.train_count_threshold = 10 + cls.cdb.config.linking.similarity_threshold = 0.3 + cls.cdb.config.linking.train = True + cls.cdb.config.linking.disamb_length_limit = 5 + cls.cdb.config.general.full_unlink = True + cls.meta_cat_dir = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "tmp") + cls.undertest = CAT(cdb=cls.cdb, config=cls.cdb.config, + vocab=cls.vocab, meta_cats=[]) + cls._linkng_filters = cls.undertest.config.linking.filters.copy_of() + + # # add tests from CAT tests + + @classmethod + def tearDownClass(cls) -> None: + cls.temp_folder.cleanup() + cls.undertest.destroy_pipe() + if os.path.exists(cls.meta_cat_dir): 
+ shutil.rmtree(cls.meta_cat_dir) + + def tearDown(self) -> None: + self.cdb.config.annotation_output.include_text_in_output = False + # need to make sure linking filters are not retained beyond a test scope + self.undertest.config.linking.filters = self._linkng_filters.copy_of() + + def test_optimised_cdb_has_cui2many(self): + self.assertTrue(hasattr(self.cdb, 'cui2many')) + + def test_can_be_saved_as_json(self): + self.cdb.save(self.temp_cdb_path, json_path=self.json_path) + + def test_can_be_loaded_as_json(self): + self.test_can_be_saved_as_json() + cdb = CDB.load(self.temp_cdb_path, self.json_path) + self.assertEqual(self.cdb.cui2many, cdb.cui2many) + for del_name in memory_optimiser.CUI_DICT_NAMES_TO_COMBINE: + d = getattr(cdb, del_name) + with self.subTest(del_name): + self.assertIsInstance(d, memory_optimiser.DelegatingDict) + self.assertIs(cdb.cui2many, d.delegate) + + +class DelegatingValueSetTests(unittest.TestCase): + + def setUp(self) -> None: + self.delegate = {'a': set('abcd'), + 'b': set('efghij'), + 'c': set('lm'), # skip k + 'd': set('qrst'), # skip a bunch + } + self.original = set([v for s in self.delegate for v in s]) + + def test_DelegatingValueSet_constructs(self): + dvs = memory_optimiser.DelegatingValueSet(self.delegate) + self.assertIsInstance(dvs, memory_optimiser.DelegatingValueSet) + + def test_DelegatingValueSet_contains_values(self): + dvs = memory_optimiser.DelegatingValueSet(self.delegate) + for v in self.original: + with self.subTest(f'Check: {v}'): + self.assertIn(v, dvs) + + def test_DelegatingValueSet_contains_incorrect_values(self, + to_check=set('kopuvwxyz')): + dvs = memory_optimiser.DelegatingValueSet(self.delegate) + for v in to_check: + with self.subTest(f'Check: {v}'): + self.assertNotIn(v, dvs) diff --git a/tests/utils/test_versioning.py b/tests/utils/test_versioning.py new file mode 100644 index 000000000..30a3afdd3 --- /dev/null +++ b/tests/utils/test_versioning.py @@ -0,0 +1,163 @@ +import unittest +import os +import 
tempfile +import shutil + +import dill +import pydantic + +from medcat.utils.versioning import get_version_from_modelcard, get_semantic_version_from_model +from medcat.utils.versioning import get_version_from_cdb_dump, get_version_from_modelpack_zip +from medcat.utils.versioning import ConfigUpgrader +from medcat.cat import CAT +from medcat.cdb import CDB +from medcat.vocab import Vocab + +from .regression.test_metadata import MODEL_CARD_EXAMPLE, EXAMPLE_VERSION + + +CORRECT_SEMANTIC_VERSIONS = [("1.0.1-alpha-1", (1, 0, 1)), ("0.0.1-alpha-1", (0, 0, 1)), + ("1.0.0-alpha.1", (1, 0, 0) + ), ("1.0.0-0.3.7", (1, 0, 0)), + ("1.0.0-x.7.z.92", (1, 0, 0) + ), ("1.0.0-x-y-z.--", (1, 0, 0)), + ("1.0.0-alpha+001", (1, 0, 0) + ), ("1.0.0+20130313144700", (1, 0, 0)), + ("1.0.0-beta+exp.sha.5114f85", (1, 0, 0)), + ("1.0.0+21AF26D3----117B344092BD", (1, 0, 0))] +INCORRECT_SEMANTIC_VERSIONS = ["01.0.0", "0.01.0", "0.0.01", "0.0.0\nSOMETHING", + "1.0.space", "1.0.0- space"] + + +class VersionGettingFromModelCardTests(unittest.TestCase): + FAKE_MODEL_CARD1 = {"Something": "value"} + FAKE_MODEL_CARD2 = {"MedCAT Version": "not semantic"} + FAKE_MODEL_CARD3 = {"MedCAT Version": "almost.semantic"} + FAKE_MODEL_CARD4 = {"MedCAT Version": "closest.to.semantic"} + WRONG_VERSION_FAKE_MODELS = [FAKE_MODEL_CARD2, + FAKE_MODEL_CARD3, FAKE_MODEL_CARD4] + + def test_gets_correct_version(self): + maj, minor, patch = get_version_from_modelcard(MODEL_CARD_EXAMPLE) + self.assertEqual(EXAMPLE_VERSION, (maj, minor, patch)) + + def test_fails_upon_model_card_with_no_version_defined(self): + with self.assertRaises(KeyError): + get_version_from_modelcard(self.FAKE_MODEL_CARD1) + + def test_fails_upon_model_card_with_incorrect_version(self): + cntr = 0 + for fake_model_card in self.WRONG_VERSION_FAKE_MODELS: + with self.assertRaises(ValueError): + get_version_from_modelcard(fake_model_card) + cntr += 1 + self.assertEqual(cntr, len(self.WRONG_VERSION_FAKE_MODELS)) + + def 
test_fails_upon_wrong_version(self): + cntr = 0 + for wrong_version in INCORRECT_SEMANTIC_VERSIONS: + d = {"MedCAT Version": wrong_version} + with self.subTest(f"With version: {wrong_version}"): + with self.assertRaises(ValueError): + get_version_from_modelcard(d) + cntr += 1 + self.assertEqual(cntr, len(INCORRECT_SEMANTIC_VERSIONS)) + + def test_gets_version_from_correct_versions(self): + cntr = 0 + for version, expected in CORRECT_SEMANTIC_VERSIONS: + d = {"MedCAT Version": version} + with self.subTest(f"With version: {version}"): + got_version = get_version_from_modelcard(d) + self.assertEqual(got_version, expected) + cntr += 1 + self.assertEqual(cntr, len(CORRECT_SEMANTIC_VERSIONS)) + + +NEW_CDB_NAME = "cdb_new.dat" +CDB_PATH = os.path.join(os.path.dirname( + os.path.realpath(__file__)), "..", "..", "examples", NEW_CDB_NAME) +EXPECTED_CDB_VERSION = (1, 0, 0) + + +class VersionGettingFromCATTests(unittest.TestCase): + + def setUp(self) -> None: + self.cdb = CDB.load(CDB_PATH) + self.vocab = Vocab.load(os.path.join(os.path.dirname( + os.path.realpath(__file__)), "..", "..", "examples", "vocab.dat")) + self.cdb.config.general.spacy_model = "en_core_web_md" + self.cdb.config.ner.min_name_len = 2 + self.cdb.config.ner.upper_case_limit_len = 3 + self.cdb.config.general.spell_check = True + self.cdb.config.linking.train_count_threshold = 10 + self.cdb.config.linking.similarity_threshold = 0.3 + self.cdb.config.linking.train = True + self.cdb.config.linking.disamb_length_limit = 5 + self.cdb.config.general.full_unlink = True + self.meta_cat_dir = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "tmp") + self.undertest = CAT( + cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[]) + + def test_gets_correct_version(self): + version = get_semantic_version_from_model(self.undertest) + self.assertEqual(EXPECTED_CDB_VERSION, version) + + +class VersionGetterFromCDBTests(unittest.TestCase): + + def test_gets_version_from_cdb(self): + version = 
get_version_from_cdb_dump(CDB_PATH) + self.assertEqual(EXPECTED_CDB_VERSION, version) + + +class VersionGettFromModelPackTests(unittest.TestCase): + + def test_gets_version_from_model_pack(self): + # not strictly speaking a ZIP, but should work currently + # since the folder exists + model_pack_zip = os.path.dirname(CDB_PATH) + version = get_version_from_modelpack_zip( + model_pack_zip, cdb_file_name=NEW_CDB_NAME) + self.assertEqual(EXPECTED_CDB_VERSION, version) + + +class VersioningFixTests(unittest.TestCase): + + def break_cdb(self): + with open(self.broken_cdb_path, 'rb') as rf: + data = dill.load(rf) + data['config']['linking']['filters']['cuis'] = {} + with open(self.broken_cdb_path, 'wb') as wf: + dill.dump(data, wf) + + def setUp(self) -> None: + self.temp_folder = tempfile.TemporaryDirectory() + self.broken_cdb_path = os.path.join(self.temp_folder.name, "cdb.dat") + self.new_temp_folder = tempfile.TemporaryDirectory() + shutil.copyfile(CDB_PATH, self.broken_cdb_path) + self.break_cdb() + + def tearDown(self) -> None: + self.temp_folder.cleanup() + self.new_temp_folder.cleanup() + + def test_new_format_does_not_change_when_upgraded(self): + fixer = ConfigUpgrader(os.path.dirname( + CDB_PATH), cdb_file_name=NEW_CDB_NAME) + fixer.upgrade(self.new_temp_folder.name) + old_cdb = CDB.load(CDB_PATH) + new_cdb = CDB.load(os.path.join( + self.new_temp_folder.name, NEW_CDB_NAME)) + self.assertEqual(old_cdb.config.get_hash(), new_cdb.config.get_hash()) + + def test_old_format_needs_upgrade(self): + fixer = ConfigUpgrader(self.temp_folder.name) + self.assertTrue(fixer.needs_upgrade()) + + def test_fixes_old_format(self): + fixer = ConfigUpgrader(self.temp_folder.name) + fixer.upgrade(self.new_temp_folder.name) + new_cdb = CDB.load(os.path.join(self.new_temp_folder.name, "cdb.dat")) + self.assertIsInstance(new_cdb, CDB) diff --git a/webapp/webapp/requirements.txt b/webapp/webapp/requirements.txt index 9534e025a..a4b7827ad 100644 --- a/webapp/webapp/requirements.txt +++ 
b/webapp/webapp/requirements.txt @@ -1,4 +1,4 @@ -Django==3.2.18 +Django==3.2.20 django-dbbackup==4.0.0b0 django-storages[boto3]==1.12.3 django-cron==0.5.1 From ba1dc4aa535feb79226bd72b34e1fdca31c8d6df Mon Sep 17 00:00:00 2001 From: Mart Ratas Date: Thu, 21 Sep 2023 17:13:05 +0300 Subject: [PATCH 2/9] v1.9.1 Release PR (#346) * remove bad merge

element * CU-8692kpchc Fix for Rosalind link not working (#342) * CU-8692kpchc Add the 403 exception to vocab downloader * CU-8692kpchc Add the new vocab download link * Add missing self argument (#343) To `_refset_df2dict ` method in Snomed preprocessing * CU-8692kn0yv Fix issue with fake dict in identifier based config More specifically the get method which was not able to return default values for non-existant keys (#341) * CU-8692mevx8 Fix issue with filters not taking effect in train_supervised method (#345) * CU-8692mevx8 Fix issue with filters not taking effect in train_supervised method * CU-8692mevx8 Fix filter retention in train_supervised method --------- Co-authored-by: tomolopolis --- medcat/cat.py | 20 +++++++++++++++---- medcat/config.py | 6 +++++- medcat/utils/preprocess_snomed.py | 2 +- tests/helper.py | 15 ++++++++++++-- .../demo/templates/train_annotations.html | 2 -- 5 files changed, 35 insertions(+), 10 deletions(-) diff --git a/medcat/cat.py b/medcat/cat.py index 5218e9d02..2323cd737 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -490,7 +490,8 @@ def _print_stats(self, fp_docs: Set = set() fn_docs: Set = set() - local_filters = self.config.linking.filters.copy_of() + orig_filters = self.config.linking.filters.copy_of() + local_filters = self.config.linking.filters for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False): local_filters.cuis = set() @@ -645,6 +646,8 @@ def _print_stats(self, except Exception: traceback.print_exc() + self.config.linking.filters = orig_filters + return fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples def _set_project_filters(self, local_filters: LinkingFilters, project: dict, @@ -1033,7 +1036,13 @@ def train_supervised_raw(self, """ checkpoint = self._init_ckpts(is_resumed, checkpoint) - local_filters = self.config.linking.filters.copy_of() + # the config.linking.filters stuff is used directly in + # medcat.linking.context_based_linker 
and medcat.linking.vector_context_model + # as such, they need to be kept up to date with per-project filters + # However, the original state needs to be kept track of + # so that it can be restored after training + orig_filters = self.config.linking.filters.copy_of() + local_filters = self.config.linking.filters fp = fn = tp = p = r = f1 = examples = {} @@ -1094,7 +1103,7 @@ def train_supervised_raw(self, if retain_filters and extra_cui_filter and not retain_extra_cui_filter: # adding project filters without extra_cui_filters self._set_project_filters(local_filters, project, set(), use_filters) - self.config.linking.filters.merge_with(local_filters) + orig_filters.merge_with(local_filters) # adding extra_cui_filters, but NOT project filters self._set_project_filters(local_filters, project, extra_cui_filter, False) # refrain from doing it again for subsequent epochs @@ -1140,7 +1149,7 @@ def train_supervised_raw(self, checkpoint.save(self.cdb, latest_trained_step) # if retaining MCT filters AND (if they exist) extra_cui_filters if retain_filters: - self.config.linking.filters.merge_with(local_filters) + orig_filters.merge_with(local_filters) # refrain from doing it again for subsequent epochs retain_filters = False @@ -1162,6 +1171,9 @@ def train_supervised_raw(self, use_groups=use_groups, extra_cui_filter=extra_cui_filter) + # reset the state of filters + self.config.linking.filters = orig_filters + return fp, fn, tp, p, r, f1, cui_counts, examples def get_entities(self, diff --git a/medcat/config.py b/medcat/config.py index 07cf6f7f1..b2e324deb 100644 --- a/medcat/config.py +++ b/medcat/config.py @@ -28,7 +28,11 @@ class FakeDict: """FakeDict that allows the use of the __getitem__ and __setitem__ method for legacy access.""" def __getitem__(self, arg: str) -> Any: - return getattr(self, arg) + try: + return getattr(self, arg) + except AttributeError as e: + raise KeyError from e + def __setitem__(self, arg: str, val) -> None: setattr(self, arg, val) diff --git 
a/medcat/utils/preprocess_snomed.py b/medcat/utils/preprocess_snomed.py index 5e65b3a77..3ba94b977 100644 --- a/medcat/utils/preprocess_snomed.py +++ b/medcat/utils/preprocess_snomed.py @@ -327,7 +327,7 @@ def _check_path_and_release(self): raise FileNotFoundError('Incorrect path to SNOMED CT directory') return paths, snomed_releases - def _refset_df2dict(refset_df: pd.DataFrame) -> dict: + def _refset_df2dict(self, refset_df: pd.DataFrame) -> dict: """ This function takes a SNOMED refset DataFrame as an input and converts it into a dictionary. The DataFrame should contain the columns 'referencedComponentId','mapTarget','mapGroup','mapPriority','mapRule','mapAdvice'. diff --git a/tests/helper.py b/tests/helper.py index 23afdb6b4..9fb66589b 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -24,6 +24,15 @@ async def __call__(self, *args, **kwargs): """ +ERROR_403 = b""" + +403 Forbidden + +

Forbidden

+

You don't have permission to access this resource.

+ +""" + SIMPLE_WORDS = """house 34444 0.3232 0.123213 1.231231 dog 14444 0.76762 0.76767 1.45454""" @@ -45,7 +54,7 @@ def generate_simple_vocab(): class VocabDownloader: - url = 'https://medcat.rosalind.kcl.ac.uk/media/vocab.dat' + url = 'https://cogstack-medcat-example-models.s3.eu-west-2.amazonaws.com/medcat-example-models/vocab.dat' vocab_path = "./tmp_vocab.dat" _has_simple = False @@ -54,6 +63,8 @@ def is_valid(self): content = f.read() if content == ERROR_503: return False + if content == ERROR_403: + return False v = Vocab.load(self.vocab_path) if len(v.vocab) == 2: # simple one self._has_simple = True @@ -64,7 +75,7 @@ def check_or_download(self): if os.path.exists(self.vocab_path) and self.is_valid(): return tmp = requests.get(self.url) - if tmp.content == ERROR_503: + if tmp.content == ERROR_503 or tmp.content == ERROR_403: print('Rosalind server unavailable') if self._has_simple: print('Local simple vocab already present') diff --git a/webapp/webapp/demo/templates/train_annotations.html b/webapp/webapp/demo/templates/train_annotations.html index 19b5882c7..25677cd21 100644 --- a/webapp/webapp/demo/templates/train_annotations.html +++ b/webapp/webapp/demo/templates/train_annotations.html @@ -29,8 +29,6 @@
Disclaimer
WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED.

contact@cogstack.com for more information.

-

contact@cogstack.com for more information.

-

Sample text

From 4e618aa2f918540c49bdb7cded8f46c21474fc46 Mon Sep 17 00:00:00 2001
From: Mart Ratas 
Date: Mon, 8 Jan 2024 16:59:40 +0200
Subject: [PATCH 3/9] v1.10.0 (#388)

* Bump urllib3 from 1.26.5 to 1.26.17 in /webapp/webapp

Bumps [urllib3](https://github.com/urllib3/urllib3) from 1.26.5 to 1.26.17.
- [Release notes](https://github.com/urllib3/urllib3/releases)
- [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst)
- [Commits](https://github.com/urllib3/urllib3/compare/1.26.5...1.26.17)

---
updated-dependencies:
- dependency-name: urllib3
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] 

* Cu 8692wbcq5 docs builds (#359)

* CU-8692wbcq5: Pin max version of numpy

* CU-8692wbcq5: Pin max version of numpy in setup.py

* CU-8692wbcq5: Bump python version for readthedocs workflow

* CU-8692wbcq5: Pin all requirement versions in docs requirements

* CU-8692wbcq5: Move docs requirements before setuptools

* CU-8692wbcq5: Fix typo in docs requirements

* CU-8692wbcq5: Remove some less relevant stuff from docs requirements

* CU-8692wbcq5: Add back sphinx-based requirements to docs requirements

* CU-8692wbcq5: Move back to python 3.9 on docs build workflow

* CU-8692wbcq5: Bump sphinx-autoapi version

* CU-8692wbcq5: Bump sphinx version

* CU-8692wbcq5: Bump python version back to 3.10 for future-proofing

* CU-8692wbcq5: Undo pinning numpy to max version in setup.py

* CU-8692wbcq5: Remove docs-build specific dependencies in setup.py

* Bump urllib3 from 1.26.17 to 1.26.18 in /webapp/webapp

Bumps [urllib3](https://github.com/urllib3/urllib3) from 1.26.17 to 1.26.18.
- [Release notes](https://github.com/urllib3/urllib3/releases)
- [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst)
- [Commits](https://github.com/urllib3/urllib3/compare/1.26.17...1.26.18)

---
updated-dependencies:
- dependency-name: urllib3
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] 

* CU-8692uznvd: Allow empty-dict config.linking.filters.cuis and convert to set in memory (#352)

* CU-8692uznvd: Allow empty-dict config.linking.filters.cuis and convert to set in memory

* CU-8692uznvd: Move the empty-set detection and conversion from validator to init

* CU-8692uznvd: Remove unused import

* CU-8692t3fdf separate config on save (#350)

* CU-8692t3fdf Move saving config outside of the cdb.dat; Add test to make sure the config does not get saved with the CDB; patch a few existing tests

* CU-8692t3fdf Use class methods on class instead of instance in a few tests

* CU-8692t3fdf Fix typing issue

* CU-8692t3fdf Add additional tests for 2 configs and zero configs when loading model pack

* CU-8692t3fdf: Make sure CDB is linked to the correct config; Treat incorrect configs as dirty CDBs and force a recalc of the hash

* CU-2cdpd4t: Unify default addl_info in different methods. (#363)

* Bump django from 3.2.20 to 3.2.23 in /webapp/webapp

Bumps [django](https://github.com/django/django) from 3.2.20 to 3.2.23.
- [Commits](https://github.com/django/django/compare/3.2.20...3.2.23)

---
updated-dependencies:
- dependency-name: django
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] 

* Changing cdb.add_concept to a protected method

* Re-added deprecated method with deprecated flag and additional comments

* Initial commit for merge_cdb method

* Added indentation to make merge_cdb a class method

* fixed syntax issues

* more lint fixes

* more lint fixes

* bug fixes of merge_cdb

* removed print statements

* CU-86931prq4: Update GHA versions (checkout and setup-python) to v4 (#368)

* Cu 1yn0v9e duplicate multiprocessing methods (#364)

* CU-1yn0v9e: Rename and deprecate one of the multiprocessing methods;

Add docstring. Trying to be more explicit regarding usage and differences between different methods

* CU-1yn0v9e: Rename and deprecate the multiprocessing_pipe method;

Add docstring. Trying to be more explicit regarding usage and differences between different methods

* CU-1yn0v9e: Fix typo in docstring; more consistent naming

* 869377m3u: Add comment regarding demo link load times to README (#376)

* intermediate changes of merge_cdb and testing function

* Added README.md documentation for CPU only installations (#365)

* changed README.md to reflect installation options.

* added setup script to demonstrate how wrapper could look for CPU installations

* removed setup.sh as unnecessary for cpu only builds

* Initial commit for merge_cdb method

* Added indentation to make merge_cdb a class method

* fixed syntax issues

* more lint fixes

* more lint fixes

* bug fixes of merge_cdb

* removed print statements

* Added commentary on disk space usage of pytorch-gpu

* removed merge_cdb from branch

---------

Co-authored-by: adam-sutton-1992 

* Cu 8692zguyq no preferred name (#367)

* CU-8692zguyq: Slight simplification of minimum-name-length logic

* CU-8692zguyq: Add some tests for prepare_name preprocessor

* CU-8692zguyq: Add warning if no preferred name was added along a new CUI

* CU-8692zguyq: Add additional warning messages when adding/training a new CUI with no preferred name

* CU-8692zguyq: Make no preferred name warnings only run if name status is preferred

* CU-8692zguyq: Add tests for no-preferred name warnings

* CU-8692zguyq: Add Vocab.make_unigram_table to CAT tests

* CU-8692zguyq: Move to built in asserting for logging instead of patching the method

* CU-8692zguyq: Add workaround for assertNoLogs on python 3.8 and 3.9

* Add trainer callbacks for Transformer NER (#377)

CU-86938vf30 add trainer callbacks for Transformer NER

* changes to merge_cdb and adding unit tests for method

* fixing lint issues

* fixing flake8 linting

* bug fixes, additional tests, and more documentation

* moved set up of cdbs to be merged to tests.helper

* moved merge_cdb to utils and created test_cdb_utils

* removed class wrapper in cdb utils and fixed class set up in tests

* changed test object setup to class setup

* removed erroneous new line

* CU-2e77a31 improve print stats (#366)

* Add base class for CAT

* Add CDB base class

* Some whitespace fixes for base modules

* CU-2e77a31: Move print stats to their own module and class

* CU-2e77a31: Fix issues introduced by moving print stats

* CU-2e77a31: Rename print_stats to get_stats and add option to avoid printed output

* CU-2e77a31: Add test for print_stats

* CU-2e77a31: Remove unused import

* CU-2e77a31: Add new package to setup.py

* CU-2e77a31: Fix a bunch of typing issues

* CU-2e77a31: Revert CAT and CDB abstraction

* Load stopwords in Defaults before spacy model

* CU-8693az82g Remove cdb tests side effects (#380)

* 8693az82g: Add method to CDBMaker to reset the CDB

* 8693az82g: Add test in CDB tests to ensure a new CDB is used for each test

* 8693az82g: Reset CDB in CDB tests before each test to avoid side effects

* Added tests

* CU-8693bpq82 fallback spacy model (#384)

* CU-8693bpq82: Add fallback spacy model along with test

* CU-8693bpq82: Remove debug output

* CU-8693bpq82: Add exception info to warning upon spacy model load failure and fallback

* Remove tests of internals where possible

* Add test for skipping of stopwords

* Avoid supporting only English for stopwords

* Remove debug output

* Make sure stopwords language getter works for file-path spacy models

* CU-8693cv3w0 Fix fallback spacy model existance on pip installs (#386)

* CU-8693cv3w0: Add method to ensure spacy model and use it when falling back to default model

* CU-8693cv3w0: Add logged output when installing/downloading spacy model

* CU-8693b0a61 Add method to get spacy model version (#381)

* CU-8693b0a61: Add method to find spacy folder in model pack along with some tests

* CU-8693b0a61: Add test for spacy folder finding (full path)

* CU-8693b0a61: Add method for finding spacy model in model pack along with tests

* CU-8693b0a61: Add method for finding current spacy version

* CU-8693b0a61: Add method for getting spacy model version installed

* CU-8693b0a61: Fix getting spacy model folder return path

* CU-8693b0a61: Add method to get name and meta of spacy model based on model pack

* CU-8693b0a61: Add missing fake spacy model meta

* CU-8693b0a61: Add missing docstrings

* CU-8693b0a61: Change name of method for clarity

* CU-8693b0a61: Add method to get spacy model name and version from model pack path

* CU-8693b0a61: Fix a few typing issues

* CU-8693b0a61: Add a missing docstring

* CU-8693b0a61: Match folder name of fake spacy model to its name

* CU-8693b0a61: Make the final method return true name of spacy model instead of folder name

* Add additional output to method for getting spacy model version - the compatible spacy versions

* CU-8693b0a61: Add method for querying whether the spacy version is compatible with a range

* CU-8693b0a61: Add better abstraction for spacy version mocking in tests

* CU-8693b0a61: Add some more abstraction for fake model pack in tests

* CU-8693b0a61: Add method for checking whether a model pack has a spacy model compatible with installed spacy version

* CU-8693b0a61: Improve abstraction within tests

* CU-8693b0a61: Add method to check which of two versions is older

* CU-8693b0a61: Fix fake spacy model versioning

* CU-8693b0a61: Add method for determining whether a model pack has semi-compatible spacy model

* CU-8693b0a61: Add missing word in docstring.

* CU-8693b0a61: Change some method to protected ones

---------

Signed-off-by: dependabot[bot] 
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: tomolopolis 
Co-authored-by: adam-sutton-1992 
Co-authored-by: adam-sutton-1992 <60137864+adam-sutton-1992@users.noreply.github.com>
Co-authored-by: Xi Bai <82581439+baixiac@users.noreply.github.com>
Co-authored-by: jenniferajiang 
Co-authored-by: Jennifer Jiang <37081323+jenniferajiang@users.noreply.github.com>
---
 .github/workflows/main.yml                    |   8 +-
 .github/workflows/production.yml              |   4 +-
 .readthedocs.yaml                             |   6 +-
 README.md                                     |  12 +
 docs/requirements.txt                         | 106 +++++-
 medcat/cat.py                                 | 283 ++++-----------
 medcat/cdb.py                                 | 132 ++++++-
 medcat/cdb_maker.py                           |  16 +-
 medcat/config.py                              |  21 +-
 medcat/ner/transformers_ner.py                |  16 +-
 medcat/pipe.py                                |  30 +-
 medcat/preprocessing/cleaners.py              |   3 +-
 medcat/stats/__init__.py                      |   0
 medcat/stats/stats.py                         | 340 ++++++++++++++++++
 medcat/utils/cdb_utils.py                     | 117 ++++++
 medcat/utils/filters.py                       |  43 ++-
 medcat/utils/helpers.py                       |  32 ++
 medcat/utils/regression/targeting.py          |   4 +-
 medcat/utils/saving/serializer.py             |  15 +-
 medcat/utils/spacy_compatibility.py           | 211 +++++++++++
 setup.py                                      |   8 +-
 tests/archive_tests/test_cdb_maker_archive.py |   2 +-
 tests/helper.py                               |  35 ++
 tests/ner/test_transformers_ner.py            |  50 +++
 tests/preprocessing/__init__.py               |   0
 tests/preprocessing/test_cleaners.py          | 104 ++++++
 tests/resources/ff_core_fake_dr/meta.json     |   8 +
 tests/test_cat.py                             | 228 +++++++++++-
 tests/test_cdb.py                             |  18 +
 tests/test_config.py                          |  50 ++-
 tests/test_pipe.py                            |   9 +-
 tests/utils/saving/test_serialization.py      |   2 +-
 tests/utils/test_cdb_utils.py                 |  43 +++
 tests/utils/test_hashing.py                   |  53 ++-
 tests/utils/test_helpers.py                   |  24 ++
 tests/utils/test_spacy_compatibility.py       | 302 ++++++++++++++++
 webapp/webapp/requirements.txt                |   4 +-
 37 files changed, 2064 insertions(+), 275 deletions(-)
 create mode 100644 medcat/stats/__init__.py
 create mode 100644 medcat/stats/stats.py
 create mode 100644 medcat/utils/cdb_utils.py
 create mode 100644 medcat/utils/spacy_compatibility.py
 create mode 100644 tests/ner/test_transformers_ner.py
 create mode 100644 tests/preprocessing/__init__.py
 create mode 100644 tests/preprocessing/test_cleaners.py
 create mode 100644 tests/resources/ff_core_fake_dr/meta.json
 create mode 100644 tests/utils/test_cdb_utils.py
 create mode 100644 tests/utils/test_helpers.py
 create mode 100644 tests/utils/test_spacy_compatibility.py

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index c769dfc2e..a5468fb9b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -16,9 +16,9 @@ jobs:
       max-parallel: 4
 
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
@@ -48,13 +48,13 @@ jobs:
 
     steps:
       - name: Checkout master
-        uses: actions/checkout@v2
+        uses: actions/checkout@v4
         with:
           ref: 'master'
           fetch-depth: 0
 
       - name: Set up Python 3.9
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: 3.9
 
diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml
index 5088c1000..9ad9a5d90 100644
--- a/.github/workflows/production.yml
+++ b/.github/workflows/production.yml
@@ -14,13 +14,13 @@ jobs:
 
     steps:
       - name: Checkout production
-        uses: actions/checkout@v2
+        uses: actions/checkout@v4
         with:
           ref: ${{ github.event.release.target_commitish }}
           fetch-depth: 0
 
       - name: Set up Python 3.9
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: 3.9
 
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 8c4e65615..5cc0d97f0 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -7,13 +7,13 @@ version: 2
 build:
   os: ubuntu-20.04
   tools:
-    python: "3.9"
+    python: "3.10"
 
 sphinx:
   configuration: docs/conf.py
 
 python:
   install:
+    - requirements: docs/requirements.txt
     - method: setuptools
-      path: .
-    - requirements: docs/requirements.txt
\ No newline at end of file
+      path: .
\ No newline at end of file
diff --git a/README.md b/README.md
index 395aecf69..bf34f00c6 100644
--- a/README.md
+++ b/README.md
@@ -38,8 +38,20 @@ To download any of these models, please [follow this link](https://uts.nlm.nih.g
 - **Paper**: [What’s in a Summary? Laying the Groundwork for Advances in Hospital-Course Summarization](https://www.aclweb.org/anthology/2021.naacl-main.382.pdf)
 - ([more...](https://github.com/CogStack/MedCAT/blob/master/media/news.md))
 
+## Installation
+To install the latest version of MedCAT run the following command:
+```
+pip install medcat
+```
+Normal installations of MedCAT will install torch-gpu and all relevant dependencies (such as CUDA). This can require as much as 10 GB more disk space, which isn't required for CPU-only usage.
+
+To install the latest version of MedCAT without torch GPU support run the following command:
+```
+pip install medcat --extra-index-url https://download.pytorch.org/whl/cpu/
+```
 ## Demo
 A demo application is available at [MedCAT](https://medcat.rosalind.kcl.ac.uk). This was trained on MIMIC-III and all of SNOMED-CT.
+PS: This link can take a long time to load the first time around. The machine spins up as needed and spins down when inactive.
 
 ## Tutorials
 A guide on how to use MedCAT is available at [MedCAT Tutorials](https://github.com/CogStack/MedCATtutorials). Read more about MedCAT on [Towards Data Science](https://towardsdatascience.com/medcat-introduction-analyzing-electronic-health-records-e1c420afa13a).
diff --git a/docs/requirements.txt b/docs/requirements.txt
index be517876f..7e7df6e01 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,6 +1,104 @@
-Sphinx~=4.0
+sphinx==6.2.1
 sphinx-rtd-theme~=1.0
 myst-parser~=0.17
-sphinx-autoapi~=1.8
-setuptools>=60.0
-aiohttp==3.8.5
\ No newline at end of file
+sphinx-autoapi~=3.0.0
+MarkupSafe==2.1.3
+accelerate==0.23.0
+aiofiles==23.2.1
+aiohttp==3.8.5
+aiosignal==1.3.1
+asttokens==2.4.0
+async-timeout==4.0.3
+attrs==23.1.0
+backcall==0.2.0
+blis==0.7.11
+catalogue==2.0.10
+certifi==2023.7.22
+charset-normalizer==3.3.0
+click==8.1.7
+comm==0.1.4
+confection==0.1.3
+cymem==2.0.8
+datasets==2.14.5
+decorator==5.1.1
+dill==0.3.7
+exceptiongroup==1.1.3
+executing==2.0.0
+filelock==3.12.4
+flake8==4.0.1
+frozenlist==1.4.0
+fsspec==2023.6.0
+gensim==4.3.2
+huggingface-hub==0.17.3
+idna==3.4
+ipython==8.16.1
+ipywidgets==8.1.1
+jedi==0.19.1
+jinja2==3.1.2
+joblib==1.3.2
+jsonpickle==3.0.2
+jupyterlab-widgets==3.0.9
+langcodes==3.3.0
+matplotlib-inline==0.1.6
+mccabe==0.6.1
+mpmath==1.3.0
+multidict==6.0.4
+multiprocess==0.70.15
+murmurhash==1.0.10
+mypy==1.0.0
+mypy-extensions==0.4.3
+networkx==3.1
+numpy==1.25.2
+packaging==23.2
+pandas==2.1.1
+parso==0.8.3
+pathy==0.10.2
+pexpect==4.8.0
+pickleshare==0.7.5
+preshed==3.0.9
+prompt-toolkit==3.0.39
+psutil==5.9.5
+ptyprocess==0.7.0
+pure-eval==0.2.2
+pyarrow==13.0.0
+pycodestyle==2.8.0
+pydantic==1.10.13
+pyflakes==2.4.0
+pygments==2.16.1
+python-dateutil==2.8.2
+pytz==2023.3.post1
+pyyaml==6.0.1
+regex==2023.10.3
+requests==2.31.0
+safetensors==0.4.0
+scikit-learn==1.3.1
+scipy==1.9.3
+six==1.16.0
+smart-open==6.4.0
+spacy==3.4.4
+spacy-legacy==3.0.12
+spacy-loggers==1.0.5
+srsly==2.4.8
+stack-data==0.6.3
+sympy==1.12
+thinc==8.1.12
+threadpoolctl==3.2.0
+tokenizers==0.14.1
+tomli==2.0.1
+torch==2.1.0
+tqdm==4.66.1
+traitlets==5.11.2
+transformers==4.34.0
+triton==2.1.0
+typer==0.7.0
+types-PyYAML==6.0.3
+types-aiofiles==0.8.3
+types-setuptools==57.4.10
+typing-extensions==4.8.0
+tzdata==2023.3
+urllib3==2.0.6
+wasabi==0.10.1
+wcwidth==0.2.8
+widgetsnbextension==4.0.9
+xxhash==3.4.1
+yarl==1.9.2
\ No newline at end of file
diff --git a/medcat/cat.py b/medcat/cat.py
index 2323cd737..d3003b24b 100644
--- a/medcat/cat.py
+++ b/medcat/cat.py
@@ -2,7 +2,6 @@
 import glob
 import shutil
 import pickle
-import traceback
 import json
 import logging
 import math
@@ -24,7 +23,6 @@
 from medcat.pipe import Pipe
 from medcat.preprocessing.taggers import tag_skip_and_punct
 from medcat.cdb import CDB
-from medcat.utils.matutils import intersect_nonempty_set
 from medcat.utils.data_utils import make_mc_train_test, get_false_positives
 from medcat.utils.normalizers import BasicSpellChecker
 from medcat.utils.checkpoint import Checkpoint, CheckpointConfig, CheckpointManager
@@ -32,15 +30,16 @@
 from medcat.utils.hasher import Hasher
 from medcat.ner.vocab_based_ner import NER
 from medcat.linking.context_based_linker import Linker
-from medcat.utils.filters import get_project_filters
 from medcat.preprocessing.cleaners import prepare_name
 from medcat.meta_cat import MetaCAT
 from medcat.utils.meta_cat.data_utils import json_to_fake_spacy
-from medcat.config import Config, LinkingFilters
+from medcat.config import Config
 from medcat.vocab import Vocab
 from medcat.utils.decorators import deprecated
 from medcat.ner.transformers_ner import TransformersNER
 from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY
+from medcat.stats.stats import get_stats
+from medcat.utils.filters import set_project_filters
 
 
 logger = logging.getLogger(__name__) # separate logger from the package-level one
@@ -271,6 +270,10 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M
         cdb_path = os.path.join(save_dir_path, "cdb.dat")
         self.cdb.save(cdb_path, json_path)
 
+        # Save the config
+        config_path = os.path.join(save_dir_path, "config.json")
+        self.cdb.config.save(config_path)
+
         # Save the Vocab
         vocab_path = os.path.join(save_dir_path, "vocab.dat")
         if self.vocab is not None:
@@ -362,6 +365,10 @@ def load_model_pack(cls,
         logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format')
         cdb = CDB.load(cdb_path, json_path)
 
+        # load config
+        config_path = os.path.join(model_pack_path, "config.json")
+        cdb.load_config(config_path)
+
         # TODO load addl_ner
 
         # Modify the config to contain full path to spacy model
@@ -434,7 +441,8 @@ def _print_stats(self,
                      use_overlaps: bool = False,
                      use_cui_doc_limit: bool = False,
                      use_groups: bool = False,
-                     extra_cui_filter: Optional[Set] = None) -> Tuple:
+                     extra_cui_filter: Optional[Set] = None,
+                     do_print: bool = True) -> Tuple:
         """TODO: Refactor and make nice
         Print metrics on a dataset (F1, P, R), it will also print the concepts that have the most FP,FN,TP.
 
@@ -474,204 +482,12 @@ def _print_stats(self,
                 Number of occurrence for each CUI.
             examples (dict):
                 Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][].
+            do_print (bool):
+                Whether to print stats out. Defaults to True.
         """
-        tp = 0
-        fp = 0
-        fn = 0
-        fps: Dict = {}
-        fns: Dict = {}
-        tps: Dict = {}
-        cui_prec: Dict = {}
-        cui_rec: Dict = {}
-        cui_f1: Dict = {}
-        cui_counts: Dict = {}
-        examples: Dict = {'fp': {}, 'fn': {}, 'tp': {}}
-
-        fp_docs: Set = set()
-        fn_docs: Set = set()
-
-        orig_filters = self.config.linking.filters.copy_of()
-        local_filters = self.config.linking.filters
-        for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False):
-            local_filters.cuis = set()
-
-            # Add extra filter if set
-            self._set_project_filters(local_filters, project, extra_cui_filter, use_project_filters)
-
-            for dind, doc in tqdm(
-                enumerate(project["documents"]),
-                desc="Stats document",
-                total=len(project["documents"]),
-                leave=False,
-            ):
-                anns = self._get_doc_annotations(doc)
-
-                # Apply document level filtering, in this case project_filter is ignored while the extra_cui_filter is respected still
-                if use_cui_doc_limit:
-                    _cuis = set([ann['cui'] for ann in anns])
-                    if _cuis:
-                        local_filters.cuis = intersect_nonempty_set(_cuis, extra_cui_filter)
-                    else:
-                        local_filters.cuis = {'empty'}
-
-                spacy_doc: Doc = self(doc['text'])  # type: ignore
-
-                if use_overlaps:
-                    p_anns = spacy_doc._.ents
-                else:
-                    p_anns = spacy_doc.ents
-
-                anns_norm = []
-                anns_norm_neg = []
-                anns_examples = []
-                anns_norm_cui = []
-                for ann in anns:
-                    cui = ann['cui']
-                    if local_filters.check_filters(cui):
-                        if use_groups:
-                            cui = self.cdb.addl_info['cui2group'].get(cui, cui)
-
-                        if ann.get('validated', True) and (not ann.get('killed', False) and not ann.get('deleted', False)):
-                            anns_norm.append((ann['start'], cui))
-                            anns_examples.append({"text": doc['text'][max(0, ann['start']-60):ann['end']+60],
-                                                  "cui": cui,
-                                                  "start": ann['start'],
-                                                  "end": ann['end'],
-                                                  "source value": ann['value'],
-                                                  "acc": 1,
-                                                  "project name": project.get('name'),
-                                                  "document name": doc.get('name'),
-                                                  "project id": project.get('id'),
-                                                  "document id": doc.get('id')})
-                        elif ann.get('validated', True) and (ann.get('killed', False) or ann.get('deleted', False)):
-                            anns_norm_neg.append((ann['start'], cui))
-
-                        if ann.get("validated", True):
-                            # This is used to test was someone annotating for this CUI in this document
-                            anns_norm_cui.append(cui)
-                            cui_counts[cui] = cui_counts.get(cui, 0) + 1
-
-                p_anns_norm = []
-                p_anns_examples = []
-                for ann in p_anns:
-                    cui = ann._.cui
-                    if use_groups:
-                        cui = self.cdb.addl_info['cui2group'].get(cui, cui)
-
-                    p_anns_norm.append((ann.start_char, cui))
-                    p_anns_examples.append({"text": doc['text'][max(0, ann.start_char-60):ann.end_char+60],
-                                            "cui": cui,
-                                            "start": ann.start_char,
-                                            "end": ann.end_char,
-                                            "source value": ann.text,
-                                            "acc": float(ann._.context_similarity),
-                                            "project name": project.get('name'),
-                                            "document name": doc.get('name'),
-                                            "project id": project.get('id'),
-                                            "document id": doc.get('id')})
-                for iann, ann in enumerate(p_anns_norm):
-                    cui = ann[1]
-                    if ann in anns_norm:
-                        tp += 1
-                        tps[cui] = tps.get(cui, 0) + 1
-
-                        example = p_anns_examples[iann]
-                        examples['tp'][cui] = examples['tp'].get(cui, []) + [example]
-                    else:
-                        fp += 1
-                        fps[cui] = fps.get(cui, 0) + 1
-                        fp_docs.add(doc.get('name', 'unk'))
-
-                        # Add example for this FP prediction
-                        example = p_anns_examples[iann]
-                        if ann in anns_norm_neg:
-                            # Means that it really was annotated as negative
-                            example['real_fp'] = True
-
-                        examples['fp'][cui] = examples['fp'].get(cui, []) + [example]
-
-                for iann, ann in enumerate(anns_norm):
-                    if ann not in p_anns_norm:
-                        cui = ann[1]
-                        fn += 1
-                        fn_docs.add(doc.get('name', 'unk'))
-
-                        fns[cui] = fns.get(cui, 0) + 1
-                        examples['fn'][cui] = examples['fn'].get(cui, []) + [anns_examples[iann]]
-
-        try:
-            prec = tp / (tp + fp)
-            rec = tp / (tp + fn)
-            f1 = 2*(prec*rec) / (prec + rec)
-            print("Epoch: {}, Prec: {}, Rec: {}, F1: {}\n".format(epoch, prec, rec, f1))
-            print("Docs with false positives: {}\n".format("; ".join([str(x) for x in list(fp_docs)[0:10]])))
-            print("Docs with false negatives: {}\n".format("; ".join([str(x) for x in list(fn_docs)[0:10]])))
-
-            # Sort fns & prec
-            fps = {k: v for k, v in sorted(fps.items(), key=lambda item: item[1], reverse=True)}
-            fns = {k: v for k, v in sorted(fns.items(), key=lambda item: item[1], reverse=True)}
-            tps = {k: v for k, v in sorted(tps.items(), key=lambda item: item[1], reverse=True)}
-
-
-            # F1 per concept
-            for cui in tps.keys():
-                prec = tps[cui] / (tps.get(cui, 0) + fps.get(cui, 0))
-                rec = tps[cui] / (tps.get(cui, 0) + fns.get(cui, 0))
-                f1 = 2*(prec*rec) / (prec + rec)
-                cui_prec[cui] = prec
-                cui_rec[cui] = rec
-                cui_f1[cui] = f1
-
-
-            # Get top 10
-            pr_fps = [(self.cdb.cui2preferred_name.get(cui,
-                list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fps[cui]) for cui in list(fps.keys())[0:10]]
-            pr_fns = [(self.cdb.cui2preferred_name.get(cui,
-                list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fns[cui]) for cui in list(fns.keys())[0:10]]
-            pr_tps = [(self.cdb.cui2preferred_name.get(cui,
-                list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, tps[cui]) for cui in list(tps.keys())[0:10]]
-
-
-            print("\n\nFalse Positives\n")
-            for one in pr_fps:
-                print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2]))
-            print("\n\nFalse Negatives\n")
-            for one in pr_fns:
-                print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2]))
-            print("\n\nTrue Positives\n")
-            for one in pr_tps:
-                print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2]))
-            print("*"*110 + "\n")
-
-        except Exception:
-            traceback.print_exc()
-
-        self.config.linking.filters = orig_filters
-
-        return fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples
-
-    def _set_project_filters(self, local_filters: LinkingFilters, project: dict,
-             extra_cui_filter: Optional[Set], use_project_filters: bool):
-        """Set the project filters to a LinkingFilters object based on
-        the specified project.
-
-        Args:
-            local_filters (LinkingFilters): The linking filters instance
-            project (dict): The project
-            extra_cui_filter (Optional[Set]): Extra CUIs (if specified)
-            use_project_filters (bool): Whether to use per-project filters
-        """
-        if isinstance(extra_cui_filter, set):
-            local_filters.cuis = extra_cui_filter
-
-        if use_project_filters:
-            project_filter = get_project_filters(cuis=project.get('cuis', None),
-                                                    type_ids=project.get('tuis', None),
-                                                    cdb=self.cdb,
-                                                    project=project)
-            # Intersect project filter with existing if it has something
-            if project_filter:
-                local_filters.cuis = intersect_nonempty_set(project_filter, local_filters.cuis)
+        return get_stats(self, data=data, epoch=epoch, use_project_filters=use_project_filters,
+                         use_overlaps=use_overlaps, use_cui_doc_limit=use_cui_doc_limit,
+                         use_groups=use_groups, extra_cui_filter=extra_cui_filter, do_print=do_print)
 
     def _init_ckpts(self, is_resumed, checkpoint):
         if self.config.general.checkpoint.steps is not None or checkpoint is not None:
@@ -832,9 +648,13 @@ def add_and_train_concept(self,
                 Refer to medcat.cat.cdb.CDB.add_concept
         """
         names = prepare_name(name, self.pipe.spacy_nlp, {}, self.config)
+        if not names and cui not in self.cdb.cui2preferred_name and name_status == 'P':
+            logger.warning("No names were able to be prepared in CAT.add_and_train_concept "
+                           "method. As such no preferred name will be able to be specified. "
+                           "The CUI: '%s' and raw name: '%s'", cui, name)
         # Only if not negative, otherwise do not add the new name if in fact it should not be detected
         if do_add_concept and not negative:
-            self.cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids, description=description,
+            self.cdb._add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids, description=description,
                                  full_build=full_build)
 
         if spacy_entity is not None and spacy_doc is not None:
@@ -1102,15 +922,15 @@ def train_supervised_raw(self,
                 # then add the extra CUI filters
                 if retain_filters and extra_cui_filter and not retain_extra_cui_filter:
                     # adding project filters without extra_cui_filters
-                    self._set_project_filters(local_filters, project, set(), use_filters)
+                    set_project_filters(self.cdb.addl_info, local_filters, project, set(), use_filters)
                     orig_filters.merge_with(local_filters)
                     # adding extra_cui_filters, but NOT project filters
-                    self._set_project_filters(local_filters, project, extra_cui_filter, False)
+                    set_project_filters(self.cdb.addl_info, local_filters, project, extra_cui_filter, False)
                     # refrain from doing it again for subsequent epochs
                     retain_filters = False
                 else:
                     # Set filters in case we are using the train_from_fp
-                    self._set_project_filters(local_filters, project, extra_cui_filter, use_filters)
+                    set_project_filters(self.cdb.addl_info, local_filters, project, extra_cui_filter, use_filters)
 
                 for idx_doc in trange(current_document, len(project['documents']), initial=current_document, total=len(project['documents']), desc='Document', leave=False):
                     doc = project['documents'][idx_doc]
@@ -1327,19 +1147,42 @@ def _save_docs_to_file(self, docs: Iterable, annotated_ids: List[str], save_dir_
             pickle.dump((annotated_ids, part_counter), open(annotated_ids_path, 'wb'))
         return part_counter
 
+    @deprecated(message="Use `multiprocessing_batch_char_size` instead")
     def multiprocessing(self,
                         data: Union[List[Tuple], Iterable[Tuple]],
                         nproc: int = 2,
                         batch_size_chars: int = 5000 * 1000,
                         only_cui: bool = False,
-                        addl_info: List[str] = [],
+                        addl_info: List[str] = ['cui2icd10', 'cui2ontologies', 'cui2snomed'],
                         separate_nn_components: bool = True,
                         out_split_size_chars: Optional[int] = None,
                         save_dir_path: str = os.path.abspath(os.getcwd()),
                         min_free_memory=0.1) -> Dict:
+        return self.multiprocessing_batch_char_size(data=data, nproc=nproc,
+                                                    batch_size_chars=batch_size_chars,
+                                                    only_cui=only_cui, addl_info=addl_info,
+                                                    separate_nn_components=separate_nn_components,
+                                                    out_split_size_chars=out_split_size_chars,
+                                                    save_dir_path=save_dir_path,
+                                                    min_free_memory=min_free_memory)
+
+    def multiprocessing_batch_char_size(self,
+                                        data: Union[List[Tuple], Iterable[Tuple]],
+                                        nproc: int = 2,
+                                        batch_size_chars: int = 5000 * 1000,
+                                        only_cui: bool = False,
+                                        addl_info: List[str] = [],
+                                        separate_nn_components: bool = True,
+                                        out_split_size_chars: Optional[int] = None,
+                                        save_dir_path: str = os.path.abspath(os.getcwd()),
+                                        min_free_memory=0.1) -> Dict:
         r"""Run multiprocessing for inference, if out_save_path and out_split_size_chars is used this will also continue annotating
         documents if something is saved in that directory.
 
+        This method batches the data based on the number of characters as specified by user.
+
+        PS: This method is unlikely to work on a Windows machine.
+
         Args:
             data:
                 Iterator or array with format: [(id, text), (id, text), ...]
@@ -1523,15 +1366,35 @@ def _multiprocessing_batch(self,
 
         return docs
 
-    def multiprocessing_pipe(self,
-                             in_data: Union[List[Tuple], Iterable[Tuple]],
+    @deprecated(message="Use `multiprocessing_batch_docs_size` instead")
+    def multiprocessing_pipe(self, in_data: Union[List[Tuple], Iterable[Tuple]],
                              nproc: Optional[int] = None,
                              batch_size: Optional[int] = None,
                              only_cui: bool = False,
                              addl_info: List[str] = [],
                              return_dict: bool = True,
                              batch_factor: int = 2) -> Union[List[Tuple], Dict]:
-        """Run multiprocessing NOT FOR TRAINING
+        return self.multiprocessing_batch_docs_size(in_data=in_data, nproc=nproc,
+                                                     batch_size=batch_size,
+                                                     only_cui=only_cui,
+                                                     addl_info=addl_info,
+                                                     return_dict=return_dict,
+                                                     batch_factor=batch_factor)
+
+    def multiprocessing_batch_docs_size(self,
+                             in_data: Union[List[Tuple], Iterable[Tuple]],
+                             nproc: Optional[int] = None,
+                             batch_size: Optional[int] = None,
+                             only_cui: bool = False,
+                             addl_info: List[str] = ['cui2icd10', 'cui2ontologies', 'cui2snomed'],
+                             return_dict: bool = True,
+                             batch_factor: int = 2) -> Union[List[Tuple], Dict]:
+        """Run multiprocessing NOT FOR TRAINING.
+
+        This method batches the data based on the number of documents as specified by the user.
+
+        PS:
+        This method supports Windows.
 
         Args:
             in_data (Union[List[Tuple], Iterable[Tuple]]): List with format: [(id, text), (id, text), ...]
diff --git a/medcat/cdb.py b/medcat/cdb.py
index 44d4fd9dd..76cb7327e 100644
--- a/medcat/cdb.py
+++ b/medcat/cdb.py
@@ -5,13 +5,15 @@
 import logging
 import aiofiles
 import numpy as np
-from typing import Dict, Set, Optional, List, Union
+from typing import Dict, Set, Optional, List, Union, cast
 from functools import partial
+import os
 
 from medcat import __version__
 from medcat.utils.hasher import Hasher
 from medcat.utils.matutils import unitvec
 from medcat.utils.ml_utils import get_lr_linking
+from medcat.utils.decorators import deprecated
 from medcat.config import Config, weighted_average, workers
 from medcat.utils.saving.serializer import CDBSerializer
 
@@ -61,8 +63,10 @@ class CDB(object):
     def __init__(self, config: Union[Config, None] = None) -> None:
         if config is None:
             self.config = Config()
+            self._config_from_file = False
         else:
             self.config = config
+            self._config_from_file = True
         self.name2cuis: Dict = {}
         self.name2cuis2status: Dict = {}
 
@@ -95,6 +99,12 @@ def __init__(self, config: Union[Config, None] = None) -> None:
         self._optim_params = None
         self.is_dirty = False
         self._hash: Optional[str] = None
+        # the config hash is kept track of here so that
+        # the CDB hash can be re-calculated when the config changes
+        # it can also be used to make sure the config loaded with
+        # a CDB matches the config it was saved with
+        # since the config is now saved separately
+        self._config_hash: Optional[str] = None
         self._memory_optimised_parts: Set[str] = set()
 
     def get_name(self, cui: str) -> str:
@@ -213,8 +223,9 @@ def add_names(self, cui: str, names: Dict, name_status: str = 'A', full_build: b
             # Name status must be one of the three
             name_status = 'A'
 
-        self.add_concept(cui=cui, names=names, ontologies=set(), name_status=name_status, type_ids=set(), description='', full_build=full_build)
+        self._add_concept(cui=cui, names=names, ontologies=set(), name_status=name_status, type_ids=set(), description='', full_build=full_build)
 
+    @deprecated("Use `cdb._add_concept` as this will be removed in a future release.")
     def add_concept(self,
                     cui: str,
                     names: Dict,
@@ -223,6 +234,43 @@ def add_concept(self,
                     type_ids: Set[str],
                     description: str,
                     full_build: bool = False) -> None:
+        """
+        Deprecated: Use `cdb._add_concept` as this will be removed in a future release.
+
+        Add a concept to internal Concept Database (CDB). Depending on what you are providing
+        this will add a large number of properties for each concept.
+
+        Args:
+            cui (str):
+                Concept ID or unique identifier in this database, all concepts that have
+                the same CUI will be merged internally.
+            names (Dict[str, Dict]):
+                Names for this concept, or the value that if found in free text can be linked to this concept.
+                Names is a dict like: `{name: {'tokens': tokens, 'snames': snames, 'raw_name': raw_name}, ...}`
+                Names should be generated by helper function 'medcat.preprocessing.cleaners.prepare_name'
+            ontologies (Set[str]):
+                ontologies in which the concept exists (e.g. SNOMEDCT, HPO)
+            name_status (str):
+                One of `P`, `N`, `A`
+            type_ids (Set[str]):
+                Semantic type identifier (have a look at TUIs in UMLS or SNOMED-CT)
+            description (str):
+                Description of this concept.
+            full_build (bool):
+                If True the dictionary self.addl_info will also be populated, contains a lot of extra information
+                about concepts, but can be very memory consuming. This is not necessary
+                for normal functioning of MedCAT (Default Value `False`).
+        """
+        self._add_concept(cui, names, ontologies, name_status, type_ids, description, full_build)
+
+    def _add_concept(self,
+                    cui: str,
+                    names: Dict,
+                    ontologies: set,
+                    name_status: str,
+                    type_ids: Set[str],
+                    description: str,
+                    full_build: bool = False) -> None:
         """Add a concept to internal Concept Database (CDB). Depending on what you are providing
         this will add a large number of properties for each concept.
 
@@ -232,7 +280,8 @@ def add_concept(self,
                 the same CUI will be merged internally.
             names (Dict[str, Dict]):
                 Names for this concept, or the value that if found in free text can be linked to this concept.
-                Names is an dict like: `{name: {'tokens': tokens, 'snames': snames, 'raw_name': raw_name}, ...}`
+                Names is a dict like: `{name: {'tokens': tokens, 'snames': snames, 'raw_name': raw_name}, ...}`
+                Names should be generated by helper function 'medcat.preprocessing.cleaners.prepare_name'
             ontologies (Set[str]):
                 ontologies in which the concept exists (e.g. SNOMEDCT, HPO)
             name_status (str):
@@ -309,6 +358,21 @@ def add_concept(self,
             if name_status == 'P' and cui not in self.cui2preferred_name:
                 # Do not overwrite old preferred names
                 self.cui2preferred_name[cui] = name_info['raw_name']
+        elif names:
+            # if no name_info and names is NOT an empty dict
+            # this shouldn't really happen in the current setup
+            raise ValueError("Unknown state where there is no name_info, "
+                             f"yet the `names` dict is not empty ({names})")
+        elif name_status == 'P' and cui not in self.cui2preferred_name:
+            # this means names is an empty `names` dict
+            logger.warning("Did not manage to add a preferred name in `add_concept`. "
+                           "Was trying to do so for cui: '%s'. "
+                           "This means that the `names` dict passed was empty. "
+                           "This is _usually_ caused by either no name or too short "
+                           "a name passed to the `prepare_name` method. "
+                           "The minimum length is defined in: "
+                           "'config.cdb_maker.min_letters_required' and "
+                           "is currently set at %s", cui, self.config.cdb_maker['min_letters_required'])
 
         # Add other fields if full_build
         if full_build:
@@ -458,6 +522,35 @@ async def save_async(self, path: str) -> None:
             }
             await f.write(dill.dumps(to_save))
 
+    def load_config(self, config_path: str) -> None:
+        if not os.path.exists(config_path):
+            if not self._config_from_file:
+                # if there's no config defined anywhere
+                raise ValueError("Could not find a config in the CDB nor "
+                                 "in the config.json for this model "
+                                 f"({os.path.dirname(config_path)})",
+                                 )
+            # if there is a config, but it's defined in the cdb.dat file
+            logger.warning("Could not find config.json in model pack folder "
+                           f"({os.path.dirname(config_path)}). "
+                           "This is probably an older model. Please save the model "
+                           "again in the new format to avoid potential issues.")
+        else:
+            if self._config_from_file:
+                # if there's a config.json and one defined in the cdb.dat file
+                raise ValueError("Found a config in the CDB and in the config.json "
+                                 f"for model ({os.path.dirname(config_path)}) - "
+                                 "this is ambiguous. Please either remove the "
+                                 "config.json or load the CDB without the config.json "
+                                 "in the folder and re-save in the newer format "
+                                 "(the default save in this version)")
+            # if the only config is in the separate config.json file
+            # this should be the behaviour for all newer models
+            self.config = cast(Config, Config.load(config_path))
+            logger.debug("Loaded config from CDB from %s", config_path)
+        # mark config read from file
+        self._config_from_file = True
+
     @classmethod
     def load(cls, path: str, json_path: Optional[str] = None, config_dict: Optional[Dict] = None) -> "CDB":
         """Load and return a CDB. This allows partial loads in probably not the right way at all.
@@ -777,8 +870,34 @@ def _check_medcat_version(cls, config_data: Dict) -> None:
 or download the compatible model."""
             )
 
+    def _should_recalc_hash(self, force_recalc: bool) -> bool:
+        if force_recalc:
+            return True
+        if self.config.hash is None:
+            # TODO - perhaps this is not the best?
+            # as this is a side effect
+            # get and save result in config
+            self.config.get_hash()
+        if not self._hash or self.is_dirty:
+            # if no hash saved or is dirty
+            # need to calculate
+            logger.debug("Recalculating hash due to %s",
+                         "no hash saved" if not self._hash else "CDB is dirty")
+            return True
+        # recalc config hash in case it changed
+        self.config.get_hash()
+        if self._config_hash is None or self._config_hash != self.config.hash:
+            # if no config hash saved
+            # or if the config hash is different from one saved in here
+            logger.debug("Recalculating hash due to %s",
+                         "no config hash saved" if not self._config_hash
+                         else "config hash has changed")
+            return True
+        return False
+
     def get_hash(self, force_recalc: bool = False):
-        if not force_recalc and self._hash and not self.is_dirty:
+        should_recalc = self._should_recalc_hash(force_recalc)
+        if not should_recalc:
             logger.info("Reusing old hash of CDB since the CDB has not changed: %s", self._hash)
             return self._hash
         self.is_dirty = False
@@ -791,7 +910,7 @@ def calculate_hash(self):
         for k,v in self.__dict__.items():
             if k in ['cui2countext_vectors', 'name2cuis']:
                 hasher.update(v, length=False)
-            elif k in ['_hash', 'is_dirty']:
+            elif k in ['_hash', 'is_dirty', '_config_hash']:
                 # ignore _hash since if it previously didn't exist, the
                 # new hash would be different when the value does exist
                 # and ignore is_dirty so that we get the same hash as previously
@@ -799,6 +918,9 @@ def calculate_hash(self):
             elif k != 'config':
                 hasher.update(v, length=True)
 
+        # set cached config hash
+        self._config_hash = self.config.hash
+
         self._hash = hasher.hexdigest()
         logger.info("Found new CDB hash: %s", self._hash)
         return self._hash
diff --git a/medcat/cdb_maker.py b/medcat/cdb_maker.py
index e9c72d12e..a4dd7dd27 100644
--- a/medcat/cdb_maker.py
+++ b/medcat/cdb_maker.py
@@ -49,6 +49,14 @@ def __init__(self, config: Config, cdb: Optional[CDB] = None) -> None:
                              name='skip_and_punct',
                              additional_fields=['is_punct'])
 
+    def reset_cdb(self) -> None:
+        """This will re-create a new internal CDB based on the same config.
+
+        This will be necessary if/when you're wishing to call `prepare_csvs`
+        multiple times on the same object `CDBMaker` instance.
+        """
+        self.cdb = CDB(config=self.config)
+
     def prepare_csvs(self,
                      csv_paths: Union[pd.DataFrame, List[str]],
                      sep: str = ',',
@@ -59,6 +67,12 @@ def prepare_csvs(self,
                      only_existing_cuis: bool = False, **kwargs) -> CDB:
         r"""Compile one or multiple CSVs into a CDB.
 
+        Note: This class/method generally uses the same instance of the CDB.
+              So if you're using the same CDBMaker and calling `prepare_csvs`
+              multiple times, you are likely to get leakage from prior calls
+              into new ones.
+              To reset the CDB, call `reset_cdb`.
+
         Args:
             csv_paths (Union[pd.DataFrame, List[str]]):
                 An array of paths to the csv files that should be processed. Can also be an array of pd.DataFrames
@@ -173,7 +187,7 @@ def prepare_csvs(self,
                             if len(raw_name) >= self.config.cdb_maker['remove_parenthesis']:
                                 prepare_name(raw_name, self.pipe.spacy_nlp, names, self.config)
 
-                    self.cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids,
+                    self.cdb._add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids,
                                          description=description, full_build=full_build)
                     # DEBUG
                     logger.debug("\n\n**** Added\n CUI: %s\n Names: %s\n Ontologies: %s\n Name status: %s\n Type IDs: %s\n Description: %s\n Is full build: %s",
diff --git a/medcat/config.py b/medcat/config.py
index b2e324deb..e60c2eafc 100644
--- a/medcat/config.py
+++ b/medcat/config.py
@@ -433,6 +433,19 @@ class LinkingFilters(MixingConfig, BaseModel):
     cuis: Set[str] = set()
     cuis_exclude: Set[str] = set()
 
+    def __init__(self, **data):
+        if 'cuis' in data:
+            cuis = data['cuis']
+            if isinstance(cuis, dict) and len(cuis) == 0:
+                logger.warning("Loading an old model where "
+                               "config.linking.filters.cuis has been "
+                               "dict to an empty dict instead of an empty "
+                               "set. Converting the dict to a set in memory "
+                               "as that is what is expected. Please consider "
+                               "saving the model again.")
+                data['cuis'] = set(cuis.keys())
+        super().__init__(**data)
+
     def check_filters(self, cui: str) -> bool:
         """Checks is a CUI in the filters
 
@@ -535,6 +548,7 @@ class Config(MixingConfig, BaseModel):
     linking: Linking = Linking()
     word_skipper: re.Pattern = re.compile('') # empty pattern gets replaced upon init
     punct_checker: re.Pattern = re.compile('') # empty pattern gets replaced upon init
+    hash: Optional[str] = None
 
     class Config:
         # this if for word_skipper and punct_checker which would otherwise
@@ -559,6 +573,9 @@ def rebuild_re(self) -> None:
     def get_hash(self):
         hasher = Hasher()
         for k, v in self.dict().items():
+            if k in ['hash', ]:
+                # ignore hash
+                continue
             if k not in ['version', 'general', 'linking']:
                 hasher.update(v, length=True)
             elif k == 'general':
@@ -574,5 +591,5 @@ def get_hash(self):
                         hasher.update(v2, length=False)
                     else:
                         hasher.update(v2, length=True)
-
-        return hasher.hexdigest()
+        self.hash = hasher.hexdigest()
+        return self.hash
diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py
index 9623b1b93..227ccc083 100644
--- a/medcat/ner/transformers_ner.py
+++ b/medcat/ner/transformers_ner.py
@@ -1,6 +1,7 @@
 import os
 import json
 import logging
+import datasets
 from spacy.tokens import Doc
 from datetime import datetime
 from typing import Iterable, Iterator, Optional, Dict, List, cast, Union
@@ -18,7 +19,7 @@
 
 from transformers import Trainer, AutoModelForTokenClassification, AutoTokenizer
 from transformers import pipeline, TrainingArguments
-import datasets
+from transformers.trainer_callback import TrainerCallback
 
 # It should be safe to do this always, as all other multiprocessing
 #will be finished before data comes to meta_cat
@@ -137,7 +138,12 @@ def merge_data_loaded(base, other):
 
         return out_path
 
-    def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=False, dataset=None, meta_requirements=None):
+    def train(self,
+              json_path: Union[str, list, None]=None,
+              ignore_extra_labels=False,
+              dataset=None,
+              meta_requirements=None,
+              trainer_callbacks: Optional[List[TrainerCallback]]=None):
         """Train or continue training a model give a json_path containing a MedCATtrainer export. It will
         continue training if an existing model is loaded or start new training if the model is blank/new.
 
@@ -149,6 +155,9 @@ def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=Fals
             ignore_extra_labels:
                 Makes only sense when an existing deid model was loaded and from the new data we want to ignore
                 labels that did not exist in the old model.
+            trainer_callbacks (List[TrainerCallback]):
+                A list of trainer callbacks for collecting metrics during the training at the client side. The
+                transformers Trainer object will be passed in when each callback is called.
         """
 
         if dataset is None and json_path is not None:
@@ -193,6 +202,9 @@ def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=Fals
                 compute_metrics=lambda p: metrics(p, tokenizer=self.tokenizer, dataset=encoded_dataset['test'], verbose=self.config.general['verbose_metrics']),
                 data_collator=data_collator, # type: ignore
                 tokenizer=None)
+        if trainer_callbacks:
+            for callback in trainer_callbacks:
+                trainer.add_callback(callback(trainer))
 
         trainer.train() # type: ignore
 
diff --git a/medcat/pipe.py b/medcat/pipe.py
index 3861267df..7bf06364b 100644
--- a/medcat/pipe.py
+++ b/medcat/pipe.py
@@ -1,4 +1,5 @@
 import types
+import os
 import spacy
 import gc
 import logging
@@ -17,11 +18,15 @@
 from medcat.pipeline.pipe_runner import PipeRunner
 from medcat.preprocessing.taggers import tag_skip_and_punct
 from medcat.ner.transformers_ner import TransformersNER
+from medcat.utils.helpers import ensure_spacy_model
 
 
 logger = logging.getLogger(__name__) # different logger from the package-level one
 
 
+DEFAULT_SPACY_MODEL = 'en_core_web_md'
+
+
 class Pipe(object):
     """A wrapper around the standard spacy pipeline.
 
@@ -38,9 +43,27 @@ class Pipe(object):
     """
 
     def __init__(self, tokenizer: Tokenizer, config: Config) -> None:
-        self._nlp = spacy.load(config.general.spacy_model, disable=config.general.spacy_disabled_components)
         if config.preprocessing.stopwords is not None:
-            self._nlp.Defaults.stop_words = set(config.preprocessing.stopwords)
+            lang = os.path.basename(config.general.spacy_model).split('_', 1)[0]
+            cls = spacy.util.get_lang_class(lang)
+            cls.Defaults.stop_words = set(config.preprocessing.stopwords)
+        try:
+            self._nlp = self._init_nlp(config)
+        except Exception as e:
+            if config.general.spacy_model == DEFAULT_SPACY_MODEL:
+                raise e
+            logger.warning("Could not load spacy model from '%s'. "
+                           "Falling back to installed en_core_web_md. "
+                           "For best compatibility, we'd recommend "
+                           "packaging and using your model pack with "
+                           "the spacy model it was designed for",
+                           config.general.spacy_model, exc_info=e)
+            # we're changing the config value so that this propagates
+            # to other places that try to load the model. E.g:
+            # medcat.utils.normalizers.TokenNormalizer.__init__
+            ensure_spacy_model(DEFAULT_SPACY_MODEL)
+            config.general.spacy_model = DEFAULT_SPACY_MODEL
+            self._nlp = self._init_nlp(config)
         self._nlp.tokenizer = tokenizer(self._nlp, config)
         # Set max document length
         self._nlp.max_length = config.preprocessing.max_document_length
@@ -48,6 +71,9 @@ def __init__(self, tokenizer: Tokenizer, config: Config) -> None:
         # Set log level
         logger.setLevel(self.config.general.log_level)
 
+    def _init_nlp(self, config: Config) -> Language:
+        return spacy.load(config.general.spacy_model, disable=config.general.spacy_disabled_components)
+
     def add_tagger(self, tagger: Callable, name: Optional[str] = None, additional_fields: List[str] = []) -> None:
         """Add any kind of a tagger for tokens.
 
diff --git a/medcat/preprocessing/cleaners.py b/medcat/preprocessing/cleaners.py
index 18314d562..43e8098e2 100644
--- a/medcat/preprocessing/cleaners.py
+++ b/medcat/preprocessing/cleaners.py
@@ -48,7 +48,8 @@ def prepare_name(raw_name: str, nlp: Language, names: Dict, config: Config) -> D
             snames = set()
             name = config.general['separator'].join(tokens)
 
-            if not config.cdb_maker.get('min_letters_required', 0) or len(re.sub("[^A-Za-z]*", '', name)) >= config.cdb_maker.get('min_letters_required', 0):
+            min_letters = config.cdb_maker.get('min_letters_required', 0)
+            if not min_letters or len(re.sub("[^A-Za-z]*", '', name)) >= min_letters:
                 if name not in names:
                     sname = ""
                     for token in tokens:
diff --git a/medcat/stats/__init__.py b/medcat/stats/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/medcat/stats/stats.py b/medcat/stats/stats.py
new file mode 100644
index 000000000..06b712158
--- /dev/null
+++ b/medcat/stats/stats.py
@@ -0,0 +1,340 @@
+from typing import Dict, Optional, Set, Tuple, Callable, List, cast
+
+from tqdm import tqdm
+import traceback
+
+from spacy.tokens import Doc
+
+from medcat.utils.filters import set_project_filters
+from medcat.utils.matutils import intersect_nonempty_set
+from medcat.config import LinkingFilters
+
+
+class StatsBuilder:
+
+    def __init__(self,
+                 filters: LinkingFilters,
+                 addl_info: dict,
+                 doc_getter: Callable[[Optional[str], bool], Optional[Doc]],
+                 doc_annotation_getter: Callable[[dict], list],
+                 cui2group: Dict[str, str],
+                 cui2preferred_name: Dict[str, str],
+                 cui2names: Dict[str, Set[str]],
+                 use_project_filters: bool = False,
+                 use_overlaps: bool = False,
+                 use_cui_doc_limit: bool = False,
+                 use_groups: bool = False,
+                 extra_cui_filter: Optional[Set] = None) -> None:
+        self.filters = filters
+        self.addl_info = addl_info
+        self.doc_getter = doc_getter
+        self._get_doc_annotations = doc_annotation_getter
+        self.cui2group = cui2group
+        self.cui2preferred_name = cui2preferred_name
+        self.cui2names = cui2names
+        self.use_project_filters = use_project_filters
+        self.use_overlaps = use_overlaps
+        self.use_cui_doc_limit = use_cui_doc_limit
+        self.use_groups = use_groups
+        self.extra_cui_filter = extra_cui_filter
+        self._reset_stats()
+
+    def _reset_stats(self):
+        self.tp = 0
+        self.fp = 0
+        self.fn = 0
+        self.fps: Dict = {}
+        self.fns: Dict = {}
+        self.tps: Dict = {}
+        self.cui_prec: Dict = {}
+        self.cui_rec: Dict = {}
+        self.cui_f1: Dict = {}
+        self.cui_counts: Dict = {}
+        self.examples: Dict = {'fp': {}, 'fn': {}, 'tp': {}}
+        self.fp_docs: Set = set()
+        self.fn_docs: Set = set()
+
+    def process_project(self, project: dict) -> None:
+        self.filters.cuis = set()
+
+        # Add extra filter if set
+        set_project_filters(self.addl_info, self.filters, project, self.extra_cui_filter, self.use_project_filters)
+
+        documents = project["documents"]
+        for dind, doc in tqdm(
+            enumerate(documents),
+            desc="Stats document",
+            total=len(documents),
+            leave=False,
+        ):
+            self.process_document(cast(str, project.get('name')),
+                                  cast(str, project.get('id')), doc)
+
+    def process_document(self, project_name: str, project_id: str, doc: dict) -> None:
+        anns = self._get_doc_annotations(doc)
+
+        # Apply document level filtering, in this case project_filter is ignored while the extra_cui_filter is respected still
+        if self.use_cui_doc_limit:
+            _cuis = set([ann['cui'] for ann in anns])
+            if _cuis:
+                self.filters.cuis = intersect_nonempty_set(_cuis, self.extra_cui_filter)
+            else:
+                self.filters.cuis = {'empty'}
+
+        spacy_doc: Doc = self.doc_getter(doc['text'])  # type: ignore
+
+        if self.use_overlaps:
+            p_anns = spacy_doc._.ents
+        else:
+            p_anns = spacy_doc.ents
+
+        (anns_norm, anns_norm_neg,
+         anns_examples, _) = self._preprocess_annotations(project_name, project_id, doc, anns)
+
+        p_anns_norm, p_anns_examples = self._process_p_anns(project_name, project_id,
+                                                            doc, p_anns)
+        self._count_p_anns_norm(doc, anns_norm, anns_norm_neg,
+                                p_anns_norm, p_anns_examples)
+        self._process_anns_norm(doc, anns_norm, p_anns_norm, anns_examples)
+
+    def _process_anns_norm(self, doc: dict, anns_norm: list, p_anns_norm: list,
+                           anns_examples: list) -> None:
+        for iann, ann in enumerate(anns_norm):
+            if ann not in p_anns_norm:
+                cui = ann[1]
+                self.fn += 1
+                self.fn_docs.add(doc.get('name', 'unk'))
+
+                self.fns[cui] = self.fns.get(cui, 0) + 1
+                self.examples['fn'][cui] = self.examples['fn'].get(cui, []) + [anns_examples[iann]]
+
+    def _process_p_anns(self, project_name: str, project_id: str, doc: dict, p_anns: list) -> Tuple[list, list]:
+        p_anns_norm = []
+        p_anns_examples = []
+        for ann in p_anns:
+            cui = ann._.cui
+            if self.use_groups:
+                cui = self.cui2group.get(cui, cui)
+
+            p_anns_norm.append((ann.start_char, cui))
+            p_anns_examples.append(self._create_annoation_2(project_name, project_id, cui, doc, ann))
+        return p_anns_norm, p_anns_examples
+
+    def _count_p_anns_norm(self, doc: dict, anns_norm: list, anns_norm_neg: list,
+                           p_anns_norm: list, p_anns_examples: list) -> None:
+        for iann, ann in enumerate(p_anns_norm):
+            cui = ann[1]
+            if ann in anns_norm:
+                self.tp += 1
+                self.tps[cui] = self.tps.get(cui, 0) + 1
+
+                example = p_anns_examples[iann]
+                self.examples['tp'][cui] = self.examples['tp'].get(cui, []) + [example]
+            else:
+                self.fp += 1
+                self.fps[cui] = self.fps.get(cui, 0) + 1
+                self.fp_docs.add(doc.get('name', 'unk'))
+
+                # Add example for this FP prediction
+                example = p_anns_examples[iann]
+                if ann in anns_norm_neg:
+                    # Means that it really was annotated as negative
+                    example['real_fp'] = True
+
+                self.examples['fp'][cui] = self.examples['fp'].get(cui, []) + [example]
+
+    def _create_annoation(self, project_name: str, project_id: str, cui: str, doc: dict, ann: Dict) -> Dict:
+        return {"text": doc['text'][max(0, ann['start']-60):ann['end']+60],
+                "cui": cui,
+                "start": ann['start'],
+                "end": ann['end'],
+                "source value": ann['value'],
+                "acc": 1,
+                "project name": project_name,
+                "document name": doc.get('name'),
+                "project id": project_id,
+                "document id": doc.get('id')}
+
+    def _create_annoation_2(self, project_name: str, project_id: str, cui: str, doc: dict, ann) -> Dict:
+        return {"text": doc['text'][max(0, ann.start_char-60):ann.end_char+60],
+                "cui": cui,
+                "start": ann.start_char,
+                "end": ann.end_char,
+                "source value": ann.text,
+                "acc": float(ann._.context_similarity),
+                "project name": project_name,
+                "document name": doc.get('name'),
+                "project id": project_id,
+                "document id": doc.get('id')}
+
+    def _preprocess_annotations(self, project_name: str, project_id: str,
+                                doc: dict, anns: List[Dict]) -> Tuple[list, list, list, list]:
+        anns_norm = []
+        anns_norm_neg = []
+        anns_examples = []
+        anns_norm_cui = []
+        for ann in anns:
+            cui = ann['cui']
+            if self.filters.check_filters(cui):
+                if self.use_groups:
+                    cui = self.cui2group.get(cui, cui)
+
+                if ann.get('validated', True) and (not ann.get('killed', False) and not ann.get('deleted', False)):
+                    anns_norm.append((ann['start'], cui))
+                    anns_examples.append(self._create_annoation(project_name, project_id, cui, doc, ann))
+                elif ann.get('validated', True) and (ann.get('killed', False) or ann.get('deleted', False)):
+                    anns_norm_neg.append((ann['start'], cui))
+
+                if ann.get("validated", True):
+                    # This is used to test was someone annotating for this CUI in this document
+                    anns_norm_cui.append(cui)
+                    self.cui_counts[cui] = self.cui_counts.get(cui, 0) + 1
+        return anns_norm, anns_norm_neg, anns_examples, anns_norm_cui
+
+    def finalise_report(self, epoch: int, do_print: bool = True):
+        try:
+            prec = self.tp / (self.tp + self.fp)
+            rec = self.tp / (self.tp + self.fn)
+            f1 = 2*(prec*rec) / (prec + rec)
+            if do_print:
+                print("Epoch: {}, Prec: {}, Rec: {}, F1: {}\n".format(epoch, prec, rec, f1))
+                print("Docs with false positives: {}\n".format("; ".join([str(x) for x in list(self.fp_docs)[0:10]])))
+                print("Docs with false negatives: {}\n".format("; ".join([str(x) for x in list(self.fn_docs)[0:10]])))
+
+            # Sort fns & prec
+            fps = {k: v for k, v in sorted(self.fps.items(), key=lambda item: item[1], reverse=True)}
+            fns = {k: v for k, v in sorted(self.fns.items(), key=lambda item: item[1], reverse=True)}
+            tps = {k: v for k, v in sorted(self.tps.items(), key=lambda item: item[1], reverse=True)}
+
+
+            # F1 per concept
+            for cui in tps.keys():
+                prec = tps[cui] / (tps.get(cui, 0) + fps.get(cui, 0))
+                rec = tps[cui] / (tps.get(cui, 0) + fns.get(cui, 0))
+                f1 = 2*(prec*rec) / (prec + rec)
+                self.cui_prec[cui] = prec
+                self.cui_rec[cui] = rec
+                self.cui_f1[cui] = f1
+
+
+            # Get top 10
+            pr_fps = [(self.cui2preferred_name.get(cui,
+                list(self.cui2names.get(cui, [cui]))[0]), cui, fps[cui]) for cui in list(fps.keys())[0:10]]
+            pr_fns = [(self.cui2preferred_name.get(cui,
+                list(self.cui2names.get(cui, [cui]))[0]), cui, fns[cui]) for cui in list(fns.keys())[0:10]]
+            pr_tps = [(self.cui2preferred_name.get(cui,
+                list(self.cui2names.get(cui, [cui]))[0]), cui, tps[cui]) for cui in list(tps.keys())[0:10]]
+
+            if do_print:
+                print("\n\nFalse Positives\n")
+                for one in pr_fps:
+                    print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2]))
+                print("\n\nFalse Negatives\n")
+                for one in pr_fns:
+                    print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2]))
+                print("\n\nTrue Positives\n")
+                for one in pr_tps:
+                    print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2]))
+                print("*"*110 + "\n")
+
+        except Exception:
+            traceback.print_exc()
+
+    def unwrap(self) -> Tuple:
+        return (self.fps, self.fns, self.tps,
+                self.cui_prec, self.cui_rec, self.cui_f1,
+                self.cui_counts, self.examples)
+
+    @classmethod
+    def from_cat(cls, cat,
+                 local_filters: LinkingFilters,
+                 use_project_filters: bool = False,
+                 use_overlaps: bool = False,
+                 use_cui_doc_limit: bool = False,
+                 use_groups: bool = False,
+                 extra_cui_filter: Optional[Set] = None) -> 'StatsBuilder':
+        return StatsBuilder(filters=local_filters,
+                            addl_info=cat.cdb.addl_info,
+                            doc_getter=cat.__call__,
+                            doc_annotation_getter=cat._get_doc_annotations,
+                            cui2group=cat.cdb.addl_info['cui2group'],
+                            cui2preferred_name=cat.cdb.cui2preferred_name,
+                            cui2names=cat.cdb.cui2names,
+                            use_project_filters=use_project_filters,
+                            use_overlaps=use_overlaps,
+                            use_cui_doc_limit=use_cui_doc_limit,
+                            use_groups=use_groups,
+                            extra_cui_filter=extra_cui_filter)
+
+
+def get_stats(cat,
+              data: Dict,
+              epoch: int = 0,
+              use_project_filters: bool = False,
+              use_overlaps: bool = False,
+              use_cui_doc_limit: bool = False,
+              use_groups: bool = False,
+              extra_cui_filter: Optional[Set] = None,
+              do_print: bool = True) -> Tuple:
+    """TODO: Refactor and make nice
+    Print metrics on a dataset (F1, P, R), it will also print the concepts that have the most FP,FN,TP.
+
+    Args:
+        cat: (CAT):
+            The model pack.
+        data (list of dict):
+            The json object that we get from MedCATtrainer on export.
+        epoch (int):
+            Used during training, so we know what epoch is it.
+        use_project_filters (boolean):
+            Each project in MedCATtrainer can have filters, do we want to respect those filters
+            when calculating metrics.
+        use_overlaps (boolean):
+            Allow overlapping entities, nearly always False as it is very difficult to annotate overlapping entites.
+        use_cui_doc_limit (boolean):
+            If True the metrics for a CUI will be only calculated if that CUI appears in a document, in other words
+            if the document was annotated for that CUI. Useful in very specific situations when during the annotation
+            process the set of CUIs changed.
+        use_groups (boolean):
+            If True concepts that have groups will be combined and stats will be reported on groups.
+        extra_cui_filter(Optional[Set]):
+            This filter will be intersected with all other filters, or if all others are not set then only this one will be used.
+
+    Returns:
+        fps (dict):
+            False positives for each CUI.
+        fns (dict):
+            False negatives for each CUI.
+        tps (dict):
+            True positives for each CUI.
+        cui_prec (dict):
+            Precision for each CUI.
+        cui_rec (dict):
+            Recall for each CUI.
+        cui_f1 (dict):
+            F1 for each CUI.
+        cui_counts (dict):
+            Number of occurrence for each CUI.
+        examples (dict):
+            Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][].
+        do_print (bool):
+            Whether to print stats out. Defaults to True.
+    """
+    orig_filters = cat.config.linking.filters.copy_of()
+    local_filters = cat.config.linking.filters
+    builder = StatsBuilder.from_cat(cat,
+                                    local_filters=local_filters,
+                                    use_project_filters=use_project_filters,
+                                    use_overlaps=use_overlaps,
+                                    use_cui_doc_limit=use_cui_doc_limit,
+                                    use_groups=use_groups,
+                                    extra_cui_filter=extra_cui_filter)
+    for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False):
+        builder.process_project(project)
+
+    # this is the part that prints out the stats
+    builder.finalise_report(epoch, do_print=do_print)
+
+    cat.config.linking.filters = orig_filters
+
+    return builder.unwrap()
diff --git a/medcat/utils/cdb_utils.py b/medcat/utils/cdb_utils.py
new file mode 100644
index 000000000..445fb7d6f
--- /dev/null
+++ b/medcat/utils/cdb_utils.py
@@ -0,0 +1,117 @@
+import logging
+import numpy as np
+
+from copy import deepcopy
+from medcat.cdb import CDB
+
+logger = logging.getLogger(__name__) # separate logger from the package-level one
+
+
+def merge_cdb(cdb1: "CDB", 
+            cdb2: "CDB", 
+            overwrite_training: int = 0,
+            full_build: bool = False):
+    """Merge two CDB's together to produce a new, single CDB. The contents of the input CDBs will not be changed.
+    `addl_info` can not be perfectly merged, and will prioritise cdb1. see `full_build`
+
+        Args:
+            cdb1 (medcat.cdb.CDB):
+                The first medcat cdb to merge. In cases where merging isn't ideal (such as
+                cui2preferred_name), this cdb's values will be prioritised over cdb2.
+            cdb2 (medcat.cdb.CDB):
+                The second medcat cdb to merge.
+            overwrite_training (int):
+                Choose to prioritise a CDB's context vectors values over merging gracefully. 0 - no prio, 1 - CDB1, 2 - CDB2
+            full_build (bool):
+                Add additional information from "addl_info" dicts "cui2ontologies" and "cui2description"
+    """
+    config = deepcopy(cdb1.config)
+    cdb = CDB(config)
+
+    # Copy CDB 1 - as all settings from CDB 1 will be carried over
+    cdb.cui2names = deepcopy(cdb1.cui2names)
+    cdb.cui2snames = deepcopy(cdb1.cui2snames)
+    cdb.cui2count_train = deepcopy(cdb1.cui2count_train)
+    cdb.cui2info = deepcopy(cdb1.cui2info)
+    cdb.cui2context_vectors = deepcopy(cdb1.cui2context_vectors)
+    cdb.cui2tags = deepcopy(cdb1.cui2tags)
+    cdb.cui2type_ids = deepcopy(cdb1.cui2type_ids)
+    cdb.cui2preferred_name = deepcopy(cdb1.cui2preferred_name)
+    cdb.name2cuis = deepcopy(cdb1.name2cuis)
+    cdb.name2cuis2status = deepcopy(cdb1.name2cuis2status)
+    cdb.name2count_train = deepcopy(cdb1.name2count_train)
+    cdb.name_isupper = deepcopy(cdb1.name_isupper)
+    if full_build:
+        cdb.addl_info = deepcopy(cdb1.addl_info)
+
+    # handles cui2names, cui2snames, name_isupper, name2cuis, name2cuis2status, cui2preferred_name
+    for cui in cdb2.cui2names:
+        names = dict()
+        for name in cdb2.cui2names[cui]:
+            names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}, 'raw_name': cdb2.get_name(cui)}
+            name_status = cdb2.name2cuis2status.get(name, 'A').get(cui, 'A') # get the name status if it exists, default to 'A'
+        # For addl_info check cui2original_names as they MUST be added
+        ontologies = set()
+        description = ''
+        to_build = False
+        if full_build and (cui in cdb2.addl_info['cui2original_names'] or cui in cdb2.addl_info['cui2description']):
+            to_build = True
+            if 'cui2ontologies' in cdb2.addl_info:
+                ontologies.update(cdb2.addl_info['cui2ontologies'][cui])
+            if 'cui2description' in cdb2.addl_info:
+                description = cdb2.addl_info['cui2description'][cui]
+        cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status,
+                        type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build)
+        if cui in cdb1.cui2names:
+            if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train): 
+                if overwrite_training == 2 and cui in cdb2.cui2count_train:
+                    cdb.cui2count_train[cui] = cdb2.cui2count_train[cui]
+                else:
+                    cdb.cui2count_train[cui] = cdb1.cui2count_train.get(cui, 0) + cdb2.cui2count_train.get(cui, 0)
+            if cui in cdb1.cui2context_vectors and not (overwrite_training == 1 and cui in cdb1.cui2context_vectors[cui]):
+                if overwrite_training == 2 and cui in cdb2.cui2context_vectors:
+                    weights = [0, 1]
+                else:
+                    norm = cdb.cui2count_train[cui]
+                    weights = [np.divide(cdb1.cui2count_train.get(cui, 0), norm), np.divide(cdb2.cui2count_train.get(cui, 0), norm)]
+                contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short
+                for s in contexts: 
+                    cdb.cui2context_vectors[cui][s] = (weights[0] * cdb1.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) + (weights[1] * cdb2.cui2context_vectors[cui].get(s, np.zeros(shape=(300))))
+            if cui in cdb1.cui2tags: 
+                cdb.cui2tags[cui].append(cdb2.cui2tags[cui])
+            if cui in cdb1.cui2type_ids: 
+                cdb.cui2type_ids[cui] = cdb1.cui2type_ids[cui].union(cdb2.cui2type_ids[cui])
+        else:
+            if cui in cdb2.cui2count_train: 
+                cdb.cui2count_train[cui] = cdb2.cui2names[cui]
+            if cui in cdb2.cui2info: 
+                cdb.cui2info[cui] = cdb2.cui2info[cui]
+            if cui in cdb2.cui2context_vectors: 
+                cdb.cui2context_vectors[cui] = cdb2.cui2context_vectors[cui]
+            if cui in cdb2.cui2tags: 
+                cdb.cui2tags[cui] = cdb2.cui2tags[cui]
+            if cui in cdb2.cui2type_ids: 
+                cdb.cui2type_ids[cui] = cdb2.cui2type_ids[cui]
+
+    if overwrite_training != 1:
+        for name in cdb2.name2cuis:
+            if name in cdb1.name2cuis and overwrite_training == 0: # if they exist in both cdbs
+                if name in cdb1.name2count_train and name in cdb2.name2count_train:
+                    cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason
+            else:
+                if name in cdb2.name2count_train: 
+                    cdb.name2count_train[name] = cdb2.name2count_train[name]
+
+    # snames
+    cdb.snames = cdb1.snames.union(cdb2.snames)
+
+    # vocab, adding counts if they occur in both
+    cdb.vocab = deepcopy(cdb1.vocab)
+    if overwrite_training != 1:
+        for word in cdb2.vocab:
+            if word in cdb.vocab and overwrite_training == 0:
+                cdb.vocab[word] += cdb2.vocab[word]
+            else:
+                cdb.vocab[word] = cdb2.vocab[word]
+
+    return cdb
diff --git a/medcat/utils/filters.py b/medcat/utils/filters.py
index c4803027a..cb85e0e26 100644
--- a/medcat/utils/filters.py
+++ b/medcat/utils/filters.py
@@ -1,3 +1,9 @@
+from typing import Optional, Set, Dict
+
+from medcat.config import LinkingFilters
+from medcat.utils.matutils import intersect_nonempty_set
+
+
 def check_filters(cui, filters):
     """Checks is a CUI in the filters
 
@@ -15,7 +21,7 @@ def check_filters(cui, filters):
         return False
 
 
-def get_all_irrelevant_cuis(project, cdb):
+def get_all_irrelevant_cuis(project):
     i_cuis = set()
     for d in project['documents']:
         for a in d['annotations']:
@@ -24,7 +30,7 @@ def get_all_irrelevant_cuis(project, cdb):
     return i_cuis
 
 
-def get_project_filters(cuis, type_ids, cdb, project=None):
+def get_project_filters(cuis, type_ids, addl_info: Dict, project=None):
     cui_filter = set()
     if isinstance(cuis, str):
         if cuis is not None and cuis:
@@ -33,10 +39,10 @@ def get_project_filters(cuis, type_ids, cdb, project=None):
             type_ids = [x.strip().upper() for x in type_ids.split(",")]
 
             # Convert type_ids to cuis
-            if 'type_id2cuis' in cdb.addl_info:
+            if 'type_id2cuis' in addl_info:
                 for type_id in type_ids:
-                    if type_id in cdb.addl_info['type_id2cuis']:
-                        cui_filter.update(cdb.addl_info['type_id2cuis'][type_id])
+                    if type_id in addl_info['type_id2cuis']:
+                        cui_filter.update(addl_info['type_id2cuis'][type_id])
                     else:
                         raise Exception("Impossible to create filters, disable them.")
             else:
@@ -45,8 +51,33 @@ def get_project_filters(cuis, type_ids, cdb, project=None):
         cui_filter = set(cuis)
 
     if project is not None:
-        i_cuis = get_all_irrelevant_cuis(project, cdb)
+        i_cuis = get_all_irrelevant_cuis(project)
         for i_cui in i_cuis:
             cui_filter.remove(i_cui)
 
     return cui_filter
+
+
+def set_project_filters(addl_info: Dict, local_filters: LinkingFilters, project: dict,
+            extra_cui_filter: Optional[Set], use_project_filters: bool):
+    """Set the project filters to a LinkingFilters object based on
+    the specified project.
+
+    Args:
+        addl_info (Dict): The CDB additional information
+        local_filters (LinkingFilters): The linking filters instance
+        project (dict): The project
+        extra_cui_filter (Optional[Set]): Extra CUIs (if specified)
+        use_project_filters (bool): Whether to use per-project filters
+    """
+    if isinstance(extra_cui_filter, set):
+        local_filters.cuis = extra_cui_filter
+
+    if use_project_filters:
+        project_filter = get_project_filters(cuis=project.get('cuis', None),
+                                             type_ids=project.get('tuis', None),
+                                             addl_info=addl_info,
+                                             project=project)
+        # Intersect project filter with existing if it has something
+        if project_filter:
+            local_filters.cuis = intersect_nonempty_set(project_filter, local_filters.cuis)
diff --git a/medcat/utils/helpers.py b/medcat/utils/helpers.py
index f783a9b06..816b316ce 100644
--- a/medcat/utils/helpers.py
+++ b/medcat/utils/helpers.py
@@ -537,3 +537,35 @@ def has_new_spacy() -> bool:
     return (major > 3 or
             (major == 3 and minor > 3) or
             (major == 3 and minor == 3 and patch >= 1))
+
+
+def has_spacy_model(model_name: str) -> bool:
+    """Checks if the spacy model is available.
+
+    Args:
+        model_name (str): The model name.
+
+    Returns:
+        bool: True if the model is available, False otherwise.
+    """
+    import spacy.util
+    return model_name in spacy.util.get_installed_models()
+
+
+def ensure_spacy_model(model_name: str) -> None:
+    """Ensure the specified spacy model exists.
+
+    If the model does not currently exist, it will attempt downloading it.
+
+    Args:
+        model_name (str): The spacy model name.
+    """
+    import subprocess
+    if has_spacy_model(model_name):
+        return
+    # running in subprocess so that we can catch the exception
+    # if the model name is unknown. Otherwise we'd just be bumped
+    # out of python (sys.exit).
+    logger.info("Installing the spacy model %s using the CLI command "
+                "'python -m spacy download %s'", model_name, model_name)
+    subprocess.run(["python", "-m", "spacy", "download", model_name], check=True)
diff --git a/medcat/utils/regression/targeting.py b/medcat/utils/regression/targeting.py
index 19f19bb3f..7a13b2bcc 100644
--- a/medcat/utils/regression/targeting.py
+++ b/medcat/utils/regression/targeting.py
@@ -25,12 +25,12 @@ class TranslationLayer:
 
     Args:
         cui2names (Dict[str, Set[str]]): The map from CUI to names
-        name2cuis (Dict[str, Set[str]]): The map from name to CUIs
+        name2cuis (Dict[str, List[str]]): The map from name to CUIs
         cui2type_ids (Dict[str, Set[str]]): The map from CUI to type_ids
         cui2children (Dict[str, Set[str]]): The map from CUI to child CUIs
     """
 
-    def __init__(self, cui2names: Dict[str, Set[str]], name2cuis: Dict[str, Set[str]],
+    def __init__(self, cui2names: Dict[str, Set[str]], name2cuis: Dict[str, List[str]],
                  cui2type_ids: Dict[str, Set[str]], cui2children: Dict[str, Set[str]]) -> None:
         self.cui2names = cui2names
         self.name2cuis = name2cuis
diff --git a/medcat/utils/saving/serializer.py b/medcat/utils/saving/serializer.py
index d82df751c..25529c778 100644
--- a/medcat/utils/saving/serializer.py
+++ b/medcat/utils/saving/serializer.py
@@ -135,13 +135,12 @@ def serialize(self, cdb, overwrite: bool = False) -> None:
             raise ValueError(f'Unable to overwrite shelf path "{self.json_path}"'
                              ' - specify overrwrite=True if you wish to overwrite')
         to_save = {}
-        to_save['config'] = cdb.config.asdict()
         # This uses different names so as to not be ambiguous
         # when looking at files whether the json parts should
         # exist separately or not
         to_save['cdb_main' if self.jsons is not None else 'cdb'] = dict(
             ((key, val) for key, val in cdb.__dict__.items() if
-             key != 'config' and
+             key not in ('config', '_config_from_file') and
              (self.jsons is None or key not in SPECIALITY_NAMES)))
         logger.info('Dumping CDB to %s', self.main_path)
         with open(self.main_path, 'wb') as f:
@@ -165,7 +164,17 @@ def deserialize(self, cdb_cls):
         logger.info('Reading CDB data from %s', self.main_path)
         with open(self.main_path, 'rb') as f:
             data = dill.load(f)
-        config = cast(Config, Config.from_dict(data['config']))
+        if 'config' in data:
+            logger.warning("Found config in CDB for model (%s). "
+                           "This is an old format. Please re-save the "
+                           "model in the new format to avoid potential issues",
+                           os.path.dirname(self.main_path))
+            config = cast(Config, Config.from_dict(data['config']))
+        else:
+            # by passing None as config to constructor
+            # the CDB should identify that there has been
+            # no config loaded
+            config = None
         cdb = cdb_cls(config=config)
         if self.jsons is None:
             cdb_main = data['cdb']
diff --git a/medcat/utils/spacy_compatibility.py b/medcat/utils/spacy_compatibility.py
new file mode 100644
index 000000000..a64737f21
--- /dev/null
+++ b/medcat/utils/spacy_compatibility.py
@@ -0,0 +1,211 @@
+"""This module attempts to read the spacy compatibility of
+a model pack and (if necessary) compare it to the installed
+spacy version.
+"""
+from typing import Tuple, List, cast
+import os
+import re
+from packaging import version
+from packaging.specifiers import SpecifierSet
+
+import spacy
+
+
+SPACY_MODEL_REGEX = re.compile(r"(\w{2}_core_(\w{3,4})_(sm|md|lg|trf|xxl|\w+))|(spacy_model)")
+
+
+def _is_spacy_model_folder(folder_name: str) -> bool:
+    """Check if a folder within a model pack contains a spacy model.
+
+    The idea is to do this without loading the model. That is because
+    the version of the model may be incompatible with what we've got.
+    And as such, loading may not be possible.
+
+    Args:
+        folder_name (str): The folder to check.
+
+    Returns:
+        bool: Whether the folder contains a spacy model.
+    """
+    # since we're trying to identify this solely from the
+    # folder name, we only care about the base name.
+    folder_name = os.path.basename(folder_name)
+    if folder_name.startswith("meta_"):
+        # these are MetaCat stuff (or should be)
+        return False
+    return bool(SPACY_MODEL_REGEX.match(folder_name))
+
+
+def _find_spacy_model_folder(model_pack_folder: str) -> str:
+    """Find the spacy model folder in a model pack folder.
+
+    Args:
+        model_pack_folder (str): The model pack folder
+
+    Raises:
+        ValueError: If it's ambiguous or there's no model folder.
+
+    Returns:
+        str: The full path to the model folder.
+    """
+    options: List[str] = []
+    for folder_name in os.listdir(model_pack_folder):
+        full_folder_path = os.path.join(model_pack_folder, folder_name)
+        if not os.path.isdir(full_folder_path):
+            continue
+        if _is_spacy_model_folder(folder_name):
+            options.append(full_folder_path)
+    if len(options) != 1:
+        raise ValueError("Unable to determine spacy folder name from "
+                         f"{len(options)} ambiguous folders: {options}")
+    return options[0]
+
+
+def get_installed_spacy_version() -> str:
+    """Get the spacy version installed currently.
+
+    Returns:
+        str: The currently installed spacy version.
+    """
+    return spacy.__version__
+
+
+def get_installed_model_version(model_name: str) -> str:
+    """Get the version of a model installed in spacy.
+
+    Args:
+        model_name (str): The model name.
+
+    Returns:
+        str: The version of the installed model.
+    """
+    if model_name not in spacy.util.get_installed_models():
+        return 'N/A'
+    # NOTE: I don't really know when spacy.info
+    # might return a str instead
+    return cast(dict, spacy.info(model_name))['version']
+
+
+def _get_name_and_meta_of_spacy_model_in_medcat_modelpack(model_pack_path: str) -> Tuple[str, dict]:
+    """Gets the name and meta information about a spacy model within a medcat model pack.
+
+    PS: This gets the raw (folder) name of the spacy model.
+        While this is usually (in models created after v1.2.4)
+        identical to the spacy model version, that may not always
+        be the case.
+
+    Args:
+        model_pack_path (str): The model pack path.
+
+    Returns:
+        Tuple[str, dict]: The name of the spacy model, and the meta information.
+    """
+    spacy_model_folder = _find_spacy_model_folder(model_pack_path)
+    # NOTE: I don't really know when spacy.info
+    # might return a str instead
+    info = cast(dict, spacy.info(spacy_model_folder))
+    return os.path.basename(spacy_model_folder), info
+
+
+def get_name_and_version_of_spacy_model_in_medcat_modelpack(model_pack_path: str) -> Tuple[str, str, str]:
+    """Get the name, version, and compatible spacy versions of a spacy model within a medcat model pack.
+
+    PS: This gets the real name of the spacy model.
+        While this is usually (in models created after v1.2.4)
+        identical to the folder name, that may not always
+        be the case.
+
+    Args:
+        model_pack_path (str): The model pack path.
+
+    Returns:
+        Tuple[str, str, str]: The name of the spacy model, its version, and supported spacy version.
+    """
+    _, info = _get_name_and_meta_of_spacy_model_in_medcat_modelpack(model_pack_path)
+    true_name = info["lang"] + "_" + info['name']
+    return true_name, info['version'], info["spacy_version"]
+
+
+def _is_spacy_version_within_range(spacy_version_range: str) -> bool:
+    """Checks whether the spacy version is within the specified range.
+
+    The expected format of the version range is similar to that used
+    in requirements and/or pip installs. E.g:
+        - >=3.1.0,<3.2.0
+        - ==3.1.0
+        - >=3.1.0
+        - <3.2.0
+
+    Args:
+        spacy_version_range (str): The required spacy version range.
+
+    Returns:
+        bool: Whether the specified range is compatible.
+    """
+    spacy_version = version.parse(get_installed_spacy_version())
+    range = SpecifierSet(spacy_version_range)
+    return range.contains(spacy_version)
+
+
+def medcat_model_pack_has_compatible_spacy_model(model_pack_path: str) -> bool:
+    """Checks whether a medcat model pack has a spacy model compatible with installed spacy version.
+
+    Args:
+        model_pack_path (str): The model pack path.
+
+    Returns:
+        bool: Whether the spacy model in the model pack is compatible.
+    """
+    _, _, spacy_range = get_name_and_version_of_spacy_model_in_medcat_modelpack(model_pack_path)
+    return _is_spacy_version_within_range(spacy_range)
+
+
+def is_older_spacy_version(model_version: str) -> bool:
+    """Checks if the specified version is older than the installed version.
+
+    Args:
+        model_version (str): The specified spacy version.
+
+    Returns:
+        bool: Whether the specified version is older.
+    """
+    installed_version = version.parse(get_installed_spacy_version())
+    model_version = version.parse(model_version)
+    return model_version <= installed_version
+
+
+def medcat_model_pack_has_semi_compatible_spacy_model(model_pack_path: str) -> bool:
+    """Checks whether the spacy model within a medcat model pack is
+        compatible or older than the installed spacy version.
+
+    This method returns `True` if the spacy model is compatible or
+    released with a lower version number compared to the spacy
+    version currently installed.
+
+    We've found that most of the time older models will work with
+    a newer version of spacy. Though there is a warning on spacy's
+    side and they do not guarantee 100% compatibility, we've not
+    seen issues so far.
+
+    E.g. for installed spacy 3.4.4 all the following will be suitable:
+        - en_core_web_md-3.1.0
+        - en_core_web_md-3.2.0
+        - en_core_web_md-3.3.0
+        - en_core_web_md-3.4.1
+    However, for the same version, the following would not be suitable:
+        - en_core_web_md-3.5.0
+        - en_core_web_md-3.6.0
+        - en_core_web_md-3.7.1
+
+    Args:
+        model_pack_path (str): The model pack path.
+
+    Returns:
+        bool: Whether the spacy model in the model pack is compatible.
+    """
+    (_,
+     model_version,
+     spacy_range) = get_name_and_version_of_spacy_model_in_medcat_modelpack(model_pack_path)
+    if _is_spacy_version_within_range(spacy_range):
+        return True
+    return is_older_spacy_version(model_version)
diff --git a/setup.py b/setup.py
index 8b152cb77..34963943a 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@
     url="https://github.com/CogStack/MedCAT",
     packages=['medcat', 'medcat.utils', 'medcat.preprocessing', 'medcat.ner', 'medcat.linking', 'medcat.datasets',
               'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner',
-              'medcat.utils.saving', 'medcat.utils.regression'],
+              'medcat.utils.saving', 'medcat.utils.regression', 'medcat.stats'],
     install_requires=[
         'numpy>=1.22.0', # first to support 3.11
         'pandas>=1.4.2', # first to support 3.11
@@ -40,12 +40,6 @@
         'blis>=0.7.5', # allow later versions, tested with 0.7.9
         'click>=8.0.4', # allow later versions, tested with 8.1.3
         'pydantic>=1.10.0,<2.0', # for spacy compatibility; avoid 2.0 due to breaking changes
-        # the following are not direct dependencies of MedCAT but needed for docs/building
-        # hopefully will no longer need the transitive dependencies
-        'aiohttp==3.8.5', # 3.8.3 is needed for compatibility with fsspec <- datasets <- medcat
-        'blis<0.8.0,>=0.7.8', # as required by thinc <- spacy <- medcat
-        # 'smart-open==5.2.1', # 5.2.1 is needed for compatibility with pathy
-        # 'joblib~=1.2',
         ],
     classifiers=[
         "Programming Language :: Python :: 3",
diff --git a/tests/archive_tests/test_cdb_maker_archive.py b/tests/archive_tests/test_cdb_maker_archive.py
index 329408999..9e2fc2d72 100644
--- a/tests/archive_tests/test_cdb_maker_archive.py
+++ b/tests/archive_tests/test_cdb_maker_archive.py
@@ -108,7 +108,7 @@ def test_concept_similarity(self):
         for i in range(500):
             cui = "C" + str(i)
             type_ids = {'T-' + str(i%10)}
-            cdb.add_concept(cui=cui, names=prepare_name('Name: ' + str(i), self.maker.pipe.get_spacy_nlp(), {}, self.config), ontologies=set(),
+            cdb._add_concept(cui=cui, names=prepare_name('Name: ' + str(i), self.maker.pipe.get_spacy_nlp(), {}, self.config), ontologies=set(),
                             name_status='P', type_ids=type_ids, description='', full_build=True)
 
             vectors = {}
diff --git a/tests/helper.py b/tests/helper.py
index 9fb66589b..52943c3cd 100644
--- a/tests/helper.py
+++ b/tests/helper.py
@@ -6,6 +6,8 @@
 import numpy as np
 
 from medcat.vocab import Vocab
+from medcat.cdb_maker import CDBMaker
+from medcat.config import Config
 
 
 class AsyncMock(unittest.mock.MagicMock):
@@ -86,3 +88,36 @@ def check_or_download(self):
             return
         with open(self.vocab_path, 'wb') as f:
             f.write(tmp.content)
+
+
+class ForCDBMerging:
+
+    def __init__(self) -> None:
+        # generating cdbs - two makers are required as they would otherwise point to the same created CDB.
+        config = Config()
+        config.general["spacy_model"] = "en_core_web_md"
+        maker1 = CDBMaker(config)
+        maker2 = CDBMaker(config) # second maker is required as it will otherwise point to same object
+        path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "model_creator", "umls_sample.csv")
+        self.cdb1 = maker1.prepare_csvs(csv_paths=[path])
+        self.cdb2 = maker2.prepare_csvs(csv_paths=[path])
+
+        # generating context vectors here for testing the weighted average function (based on cui2count_train)
+        zeroes = np.zeros(shape=(1,300))
+        ones = np.ones(shape=(1,300))
+        for i, cui in enumerate(self.cdb1.cui2names):
+            self.cdb1.cui2context_vectors[cui] = {"short": ones}
+            self.cdb2.cui2context_vectors[cui] = {"short": zeroes}
+            self.cdb1.cui2count_train[cui] = 1
+            self.cdb2.cui2count_train[cui] = i + 1
+        # adding new names and cuis to each cdb to test after merging
+        test_add = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}}
+        self.cdb1.add_names("C0006826", test_add)
+        unique_test = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}}
+        self.cdb2.add_names("UniqueTest", unique_test)
+        self.cdb2.cui2context_vectors["UniqueTest"] = {"short": zeroes}
+        self.cdb2.addl_info["cui2ontologies"] = {}
+        self.cdb2.addl_info["cui2description"] = {}
+        for cui in self.cdb2.cui2names:
+            self.cdb2.addl_info["cui2ontologies"][cui] = {"test_ontology"}
+            self.cdb2.addl_info["cui2description"][cui] = "test_description"
diff --git a/tests/ner/test_transformers_ner.py b/tests/ner/test_transformers_ner.py
new file mode 100644
index 000000000..de9eae32c
--- /dev/null
+++ b/tests/ner/test_transformers_ner.py
@@ -0,0 +1,50 @@
+import os
+import unittest
+from spacy.lang.en import English
+from spacy.tokens import Doc, Span
+from transformers import TrainerCallback
+from medcat.ner.transformers_ner import TransformersNER
+from medcat.config import Config
+from medcat.cdb_maker import CDBMaker
+
+
+class TransformerNERTest(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        config = Config()
+        config.general["spacy_model"] = "en_core_web_md"
+        cdb_maker = CDBMaker(config)
+        cdb_csv = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "cdb.csv")
+        cdb = cdb_maker.prepare_csvs([cdb_csv], full_build=True)
+        Doc.set_extension("ents", default=[], force=True)
+        Span.set_extension("confidence", default=-1, force=True)
+        Span.set_extension("id", default=0, force=True)
+        Span.set_extension("detected_name", default=None, force=True)
+        Span.set_extension("link_candidates", default=None, force=True)
+        Span.set_extension("cui", default=-1, force=True)
+        Span.set_extension("context_similarity", default=-1, force=True)
+        cls.undertest = TransformersNER(cdb)
+        cls.undertest.create_eval_pipeline()
+
+    def test_pipe(self):
+        doc = English().make_doc("\nPatient Name: John Smith\nAddress: 15 Maple Avenue\nCity: New York\nCC: Chronic back pain\n\nHX: Mr. Smith")
+        doc = next(self.undertest.pipe([doc]))
+        assert len(doc.ents) > 0, "No entities were recognised"
+
+    def test_train(self):
+        tracker = unittest.mock.Mock()
+        class _DummyCallback(TrainerCallback):
+            def __init__(self, trainer) -> None:
+                self._trainer = trainer
+            def on_epoch_end(self, *args, **kwargs) -> None:
+                tracker.call()
+
+        train_data = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "resources", "deid_train_data.json")
+        self.undertest.training_arguments.num_train_epochs = 1
+        df, examples, dataset = self.undertest.train(train_data, trainer_callbacks=[_DummyCallback, _DummyCallback])
+        assert "fp" in examples
+        assert "fn" in examples
+        assert dataset["train"].num_rows == 48
+        assert dataset["test"].num_rows == 12
+        self.assertEqual(tracker.call.call_count, 2)
diff --git a/tests/preprocessing/__init__.py b/tests/preprocessing/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/preprocessing/test_cleaners.py b/tests/preprocessing/test_cleaners.py
new file mode 100644
index 000000000..b879d9ee6
--- /dev/null
+++ b/tests/preprocessing/test_cleaners.py
@@ -0,0 +1,104 @@
+from medcat.preprocessing.cleaners import prepare_name
+from medcat.config import Config
+from medcat.cdb_maker import CDBMaker
+
+import logging, os
+
+import unittest
+
+
+class BaseCDBMakerTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        config = Config()
+        config.general['log_level'] = logging.DEBUG
+        config.general["spacy_model"] = "en_core_web_md"
+        cls.maker = CDBMaker(config)
+        csvs = [
+            os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', 'examples', 'cdb.csv'),
+            os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', 'examples', 'cdb_2.csv')
+        ]
+        cls.cdb = cls.maker.prepare_csvs(csvs, full_build=True)
+
+
+class BasePrepareNameTest(BaseCDBMakerTests):
+    raw_name = 'raw'
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        cls.do_prepare_name()
+
+    # method called after setup, when raw_name has been specified
+    @classmethod
+    def do_prepare_name(cls) -> None:
+        cls.name = cls.cdb.config.general.separator.join(cls.raw_name.split())
+        cls.names = prepare_name(cls.raw_name, cls.maker.pipe.spacy_nlp, {}, cls.cdb.config)
+
+    def _dict_has_key_val_type(self, d: dict, key, val_type):
+        self.assertIn(key, d)
+        self.assertIsInstance(d[key], val_type)
+
+    def _names_has_key_val_type(self, key, val_type):
+        self._dict_has_key_val_type(self.names, key, val_type)
+
+    def test_result_has_name(self):
+        self._names_has_key_val_type(self.name, dict)
+
+    def test_name_info_has_tokens(self):
+        self._dict_has_key_val_type(self.names[self.name], 'tokens', list)
+
+    def test_name_info_has_words_as_tokens(self):
+        name_info = self.names[self.name]
+        tokens = name_info['tokens']
+        for word in self.raw_name.split():
+            with self.subTest(word):
+                self.assertIn(word, tokens)
+    
+
+class NamePreparationTests_OneLetter(BasePrepareNameTest):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        cls.raw_name = "a"
+        # the minimum name length is defined by the following config option
+        # if I don't set this to 1 here, I would see the tests fail
+        # that would be because the result from `prepare_names` would be empty
+        cls.cdb.config.cdb_maker.min_letters_required = 1
+        super().do_prepare_name()
+    
+
+class NamePreparationTests_TwoLetters(BasePrepareNameTest):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        cls.raw_name = "an"
+        super().do_prepare_name()
+    
+
+class NamePreparationTests_MultiToken(BasePrepareNameTest):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        cls.raw_name = "this raw name"
+        super().do_prepare_name()
+    
+
+class NamePreparationTests_Empty(BaseCDBMakerTests):
+    """In case of an empty name, I would expect the names dict
+    returned by `prepare_name` to be empty.
+    """
+    empty_raw_name = ''
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        cls.names = prepare_name(cls.empty_raw_name, cls.maker.pipe.spacy_nlp, {}, cls.cdb.config)
+
+    def test_names_dict_is_empty(self):
+        self.assertEqual(len(self.names), 0)
+        self.assertEqual(self.names, {})
diff --git a/tests/resources/ff_core_fake_dr/meta.json b/tests/resources/ff_core_fake_dr/meta.json
new file mode 100644
index 000000000..fe9825db7
--- /dev/null
+++ b/tests/resources/ff_core_fake_dr/meta.json
@@ -0,0 +1,8 @@
+{
+    "lang":"ff",
+    "name":"core_fake_dr",
+    "version":"3.1.0",
+    "description":"This is a FAKE model",
+    "author":"Fakio Martimus",
+    "spacy_version":">=3.1.0,<3.2.0"
+  }
\ No newline at end of file
diff --git a/tests/test_cat.py b/tests/test_cat.py
index 0baa0d35d..bc49a2808 100644
--- a/tests/test_cat.py
+++ b/tests/test_cat.py
@@ -4,10 +4,14 @@
 import unittest
 import tempfile
 import shutil
+import logging
+import contextlib
 from transformers import AutoTokenizer
 from medcat.vocab import Vocab
-from medcat.cdb import CDB
-from medcat.cat import CAT
+from medcat.cdb import CDB, logger as cdb_logger
+from medcat.cat import CAT, logger as cat_logger
+from medcat.config import Config
+from medcat.pipe import logger as pipe_logger
 from medcat.utils.checkpoint import Checkpoint
 from medcat.meta_cat import MetaCAT
 from medcat.config_meta_cat import ConfigMetaCAT
@@ -15,11 +19,13 @@
 
 
 class CATTests(unittest.TestCase):
+    SUPERVISED_TRAINING_JSON = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export.json")
 
     @classmethod
     def setUpClass(cls) -> None:
         cls.cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat"))
         cls.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat"))
+        cls.vocab.make_unigram_table()
         cls.cdb.config.general.spacy_model = "en_core_web_md"
         cls.cdb.config.ner.min_name_len = 2
         cls.cdb.config.ner.upper_case_limit_len = 3
@@ -36,7 +42,8 @@ def setUpClass(cls) -> None:
     @classmethod
     def tearDownClass(cls) -> None:
         cls.undertest.destroy_pipe()
-        shutil.rmtree(cls.meta_cat_dir)
+        if os.path.exists(cls.meta_cat_dir):
+            shutil.rmtree(cls.meta_cat_dir)
 
     def tearDown(self) -> None:
         self.cdb.config.annotation_output.include_text_in_output = False
@@ -60,7 +67,7 @@ def test_multiprocessing(self):
             (2, ""),
             (3, None)
         ]
-        out = self.undertest.multiprocessing(in_data, nproc=1)
+        out = self.undertest.multiprocessing_batch_char_size(in_data, nproc=1)
 
         self.assertEqual(3, len(out))
         self.assertEqual(1, len(out[1]['entities']))
@@ -73,7 +80,7 @@ def test_multiprocessing_pipe(self):
             (2, "The dog is sitting outside the house."),
             (3, "The dog is sitting outside the house."),
         ]
-        out = self.undertest.multiprocessing_pipe(in_data, nproc=2, return_dict=False)
+        out = self.undertest.multiprocessing_batch_docs_size(in_data, nproc=2, return_dict=False)
         self.assertTrue(type(out) == list)
         self.assertEqual(3, len(out))
         self.assertEqual(1, out[0][0])
@@ -89,7 +96,7 @@ def test_multiprocessing_pipe_with_malformed_texts(self):
             (2, ""),
             (3, None),
         ]
-        out = self.undertest.multiprocessing_pipe(in_data, nproc=1, batch_size=1, return_dict=False)
+        out = self.undertest.multiprocessing_batch_docs_size(in_data, nproc=1, batch_size=1, return_dict=False)
         self.assertTrue(type(out) == list)
         self.assertEqual(3, len(out))
         self.assertEqual(1, out[0][0])
@@ -105,7 +112,7 @@ def test_multiprocessing_pipe_return_dict(self):
             (2, "The dog is sitting outside the house."),
             (3, "The dog is sitting outside the house.")
         ]
-        out = self.undertest.multiprocessing_pipe(in_data, nproc=2, return_dict=True)
+        out = self.undertest.multiprocessing_batch_docs_size(in_data, nproc=2, return_dict=True)
         self.assertTrue(type(out) == dict)
         self.assertEqual(3, len(out))
         self.assertEqual({'entities': {}, 'tokens': []}, out[1])
@@ -211,7 +218,7 @@ def test_get_entities_multi_texts_including_text(self):
     def test_train_supervised(self):
         nepochs = 3
         num_of_documents = 27
-        data_path = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export.json")
+        data_path = self.SUPERVISED_TRAINING_JSON
         ckpt_dir_path = tempfile.TemporaryDirectory().name
         checkpoint = Checkpoint(dir_path=ckpt_dir_path, steps=1, max_to_keep=sys.maxsize)
         fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised(data_path,
@@ -367,7 +374,7 @@ def test_load_model_pack(self):
         meta_cat = _get_meta_cat(self.meta_cat_dir)
         cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat])
         full_model_pack_name = cat.create_model_pack(save_dir_path.name, model_pack_name="mp_name")
-        cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"))
+        cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"))
         self.assertTrue(isinstance(cat, CAT))
         self.assertIsNotNone(cat.config.version.medcat_version)
         self.assertEqual(repr(cat._meta_cats), repr([meta_cat]))
@@ -377,7 +384,7 @@ def test_load_model_pack_without_meta_cat(self):
         meta_cat = _get_meta_cat(self.meta_cat_dir)
         cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat])
         full_model_pack_name = cat.create_model_pack(save_dir_path.name, model_pack_name="mp_name")
-        cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"), load_meta_models=False)
+        cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"), load_meta_models=False)
         self.assertTrue(isinstance(cat, CAT))
         self.assertIsNotNone(cat.config.version.medcat_version)
         self.assertEqual(cat._meta_cats, [])
@@ -385,9 +392,208 @@ def test_load_model_pack_without_meta_cat(self):
     def test_hashing(self):
         save_dir_path = tempfile.TemporaryDirectory()
         full_model_pack_name = self.undertest.create_model_pack(save_dir_path.name, model_pack_name="mp_name")
-        cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"))
+        cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"))
         self.assertEqual(cat.get_hash(), cat.config.version.id)
 
+    def test_print_stats(self):
+        # based on current JSON
+        EXP_FALSE_NEGATIVES = {'C0017168': 2, 'C0020538': 43, 'C0038454': 4, 'C0007787': 1, 'C0155626': 4, 'C0011860': 12,
+                               'C0042029': 6, 'C0010068': 2, 'C0007222': 1, 'C0027051': 6, 'C0878544': 1, 'C0020473': 12,
+                               'C0037284': 21, 'C0003864': 4, 'C0011849': 12, 'C0005686': 1, 'C0085762': 3, 'C0030920': 2,
+                               'C0854135': 3, 'C0004096': 4, 'C0010054': 10, 'C0497156': 10, 'C0011334': 2, 'C0018939': 1,
+                               'C1561826': 2, 'C0276289': 2, 'C0041834': 9, 'C0000833': 2, 'C0238792': 1, 'C0040034': 3,
+                               'C0035078': 5, 'C0018799': 5, 'C0042109': 1, 'C0035439': 1, 'C0035435': 1, 'C0018099': 1,
+                               'C1277187': 1, 'C0024117': 7, 'C0004238': 4, 'C0032227': 6, 'C0008679': 1, 'C0013146': 6,
+                               'C0032285': 1, 'C0002871': 7, 'C0149871': 4, 'C0442886': 1, 'C0022104': 1, 'C0034065': 5,
+                               'C0011854': 6, 'C1398668': 1, 'C0020676': 2, 'C1301700': 1, 'C0021167': 1, 'C0029456': 2,
+                               'C0011570': 10, 'C0009324': 1, 'C0011882': 1, 'C0020615': 1, 'C0242510': 2, 'C0033581': 2,
+                               'C0011168': 3, 'C0039082': 2, 'C0009241': 2, 'C1404970': 1, 'C0018524': 3, 'C0150063': 1,
+                               'C0917799': 1, 'C0178417': 1, 'C0033975': 1, 'C0011253': 1, 'C0018802': 8, 'C0022661': 4,
+                               'C0017658': 1, 'C0023895': 2, 'C0003123': 1, 'C0041582': 4, 'C0085096': 4, 'C0403447': 2,
+                               'C2363741': 2, 'C0457949': 1, 'C0040336': 1, 'C0037315': 2, 'C0024236': 3, 'C0442874': 1,
+                               'C0028754': 4, 'C0520679': 5, 'C0028756': 2, 'C0029408': 5, 'C0409959': 2, 'C0018801': 1, 
+                               'C3844825': 1, 'C0022660': 2, 'C0005779': 4, 'C0011175': 1, 'C0018965': 4, 'C0018889': 1,
+                               'C0022354': 2, 'C0033377': 1, 'C0042769': 1, 'C0035222': 1, 'C1456868': 2, 'C1145670': 1,
+                               'C0018790': 1, 'C0263746': 1, 'C0206172': 1, 'C0021400': 1, 'C0243026': 1, 'C0020443': 1,
+                               'C0001883': 1, 'C0031350': 1, 'C0010709': 4, 'C1565489': 7, 'C3489393': 1, 'C0005586': 2,
+                               'C0158288': 5, 'C0700594': 4, 'C0158266': 3, 'C0006444': 2, 'C0024003': 1}
+        with open(self.SUPERVISED_TRAINING_JSON) as f:
+            data = json.load(f)
+        (fps, fns, tps,
+         cui_prec, cui_rec, cui_f1,
+         cui_counts, examples) = self.undertest._print_stats(data)
+        self.assertEqual(fps, {})
+        self.assertEqual(fns, EXP_FALSE_NEGATIVES)
+        self.assertEqual(tps, {})
+        self.assertEqual(cui_prec, {})
+        self.assertEqual(cui_rec, {})
+        self.assertEqual(cui_f1, {})
+        self.assertEqual(len(cui_counts), 136)
+        self.assertEqual(len(examples), 3)
+
+    def _assertNoLogs(self, logger: logging.Logger, level: int):
+        if hasattr(self, 'assertNoLogs'):
+            return self.assertNoLogs(logger=logger, level=level)
+        else:
+            return self.__assertNoLogs(logger=logger, level=level)
+    
+    @contextlib.contextmanager
+    def __assertNoLogs(self, logger: logging.Logger, level: int):
+        try:
+            with self.assertLogs(logger, level) as captured_logs:
+                yield
+        except AssertionError:
+            return
+        if captured_logs:
+            raise AssertionError("Logs were found: {}".format(captured_logs))
+
+    def assertLogsDuringAddAndTrainConcept(self, logger: logging.Logger, log_level,
+                                           name: str, name_status: str, nr_of_calls: int):
+        cui = 'CUI-%d'%(hash(name) + id(name))
+        with (self.assertLogs(logger=logger, level=log_level)
+              if nr_of_calls == 1
+              else self._assertNoLogs(logger=logger, level=log_level)):
+            self.undertest.add_and_train_concept(cui, name, name_status=name_status)
+
+    def test_add_and_train_concept_cat_nowarn_long_name(self):
+        long_name = 'a very long name'
+        self.assertLogsDuringAddAndTrainConcept(cat_logger, logging.WARNING, name=long_name, name_status='', nr_of_calls=0)
+
+    def test_add_and_train_concept_cdb_nowarn_long_name(self):
+        long_name = 'a very long name'
+        self.assertLogsDuringAddAndTrainConcept(cdb_logger, logging.WARNING, name=long_name, name_status='', nr_of_calls=0)
+
+    def test_add_and_train_concept_cat_nowarn_short_name_not_pref(self):
+        short_name = 'a'
+        self.assertLogsDuringAddAndTrainConcept(cat_logger, logging.WARNING, name=short_name, name_status='', nr_of_calls=0)
+
+    def test_add_and_train_concept_cdb_nowarn_short_name_not_pref(self):
+        short_name = 'a'
+        self.assertLogsDuringAddAndTrainConcept(cdb_logger, logging.WARNING, name=short_name, name_status='', nr_of_calls=0)
+
+    def test_add_and_train_concept_cat_warns_short_name(self):
+        short_name = 'a'
+        self.assertLogsDuringAddAndTrainConcept(cat_logger, logging.WARNING, name=short_name, name_status='P', nr_of_calls=1)
+
+    def test_add_and_train_concept_cdb_warns_short_name(self):
+        short_name = 'a'
+        self.assertLogsDuringAddAndTrainConcept(cdb_logger, logging.WARNING, name=short_name, name_status='P', nr_of_calls=1)
+
+
+class GetEntitiesWithStopWords(unittest.TestCase):
+    # NB! The order in which the different CDBs are created
+    # is important here since the way that the stop words are
+    # set is class-based, it creates the side effect of having
+    # the same stop words the next time around
+    # regardless of whether or not they should've been set
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.cdb1 = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat"))
+        cls.cdb2 = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat"))
+        cls.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat"))
+        cls.vocab.make_unigram_table()
+        cls.cdb1.config.general.spacy_model = "en_core_web_md"
+        cls.cdb1.config.ner.min_name_len = 2
+        cls.cdb1.config.ner.upper_case_limit_len = 3
+        cls.cdb1.config.general.spell_check = True
+        cls.cdb1.config.linking.train_count_threshold = 10
+        cls.cdb1.config.linking.similarity_threshold = 0.3
+        cls.cdb1.config.linking.train = True
+        cls.cdb1.config.linking.disamb_length_limit = 5
+        cls.cdb1.config.general.full_unlink = True
+        cls.cdb2.config = Config.from_dict(cls.cdb1.config.asdict())
+        # the regular CAT without stopwords
+        cls.no_stopwords = CAT(cdb=cls.cdb1, config=cls.cdb1.config, vocab=cls.vocab, meta_cats=[])
+        # this (the following two lines)
+        # needs to be done before initialising the CAT
+        # since that initialises the pipe
+        cls.cdb2.config.preprocessing.stopwords = {"stop", "words"}
+        cls.cdb2.config.preprocessing.skip_stopwords = True
+        # the CAT that skips the stopwords
+        cls.w_stopwords = CAT(cdb=cls.cdb2, config=cls.cdb2.config, vocab=cls.vocab, meta_cats=[])
+
+    def test_stopwords_are_skipped(self, text: str = "second words csv"):
+        # without stopwords no entities are captured
+        # with stopwords, the `second words csv` entity is captured
+        doc_no_stopwords = self.no_stopwords(text)
+        doc_w_stopwords = self.w_stopwords(text)
+        self.assertGreater(len(doc_w_stopwords._.ents), len(doc_no_stopwords._.ents))
+
+
+class ModelWithTwoConfigsLoadTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples")
+        cdb = CDB.load(os.path.join(cls.model_path, "cdb.dat"))
+        # save config next to the CDB
+        cls.config_path = os.path.join(cls.model_path, 'config.json')
+        cdb.config.save(cls.config_path)
+
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        # REMOVE config next to the CDB
+        os.remove(cls.config_path)
+
+    def test_loading_model_pack_with_cdb_config_and_config_json_raises_exception(self):
+        with self.assertRaises(ValueError):
+            CAT.load_model_pack(self.model_path)
+
+
+class ModelLoadsUnreadableSpacy(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.temp_dir = tempfile.TemporaryDirectory()
+        model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples")
+        cdb = CDB.load(os.path.join(model_path, 'cdb.dat'))
+        cdb.config.general.spacy_model = os.path.join(cls.temp_dir.name, "en_core_web_md")
+        # save CDB in new location
+        cdb.save(os.path.join(cls.temp_dir.name, 'cdb.dat'))
+        # save config in new location
+        cdb.config.save(os.path.join(cls.temp_dir.name, 'config.json'))
+        # copy vocab into new location
+        vocab_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat")
+        cls.vocab_path = os.path.join(cls.temp_dir.name, 'vocab.dat')
+        shutil.copyfile(vocab_path, cls.vocab_path)
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        # REMOVE temp dir
+        cls.temp_dir.cleanup()
+
+    def test_loads_without_specified_spacy_model(self):
+        with self.assertLogs(logger=pipe_logger, level=logging.WARNING):
+            cat = CAT.load_model_pack(self.temp_dir.name)
+        self.assertTrue(isinstance(cat, CAT))
+
+
+class ModelWithZeroConfigsLoadTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cdb_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat")
+        cdb = CDB.load(cdb_path)
+        vocab_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat")
+        # copy the CDB and vocab to a temp dir
+        cls.temp_dir = tempfile.TemporaryDirectory()
+        cls.cdb_path = os.path.join(cls.temp_dir.name, 'cdb.dat')
+        cdb.save(cls.cdb_path) # save without internal config
+        cls.vocab_path = os.path.join(cls.temp_dir.name, 'vocab.dat')
+        shutil.copyfile(vocab_path, cls.vocab_path)
+
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        # REMOVE temp dir
+        cls.temp_dir.cleanup()
+
+    def test_loading_model_pack_without_any_config_raises_exception(self):
+        with self.assertRaises(ValueError):
+            CAT.load_model_pack(self.temp_dir.name)
+
 
 def _get_meta_cat(meta_cat_dir):
     config = ConfigMetaCAT()
diff --git a/tests/test_cdb.py b/tests/test_cdb.py
index 96425bc8c..1be74edfe 100644
--- a/tests/test_cdb.py
+++ b/tests/test_cdb.py
@@ -6,6 +6,7 @@
 import numpy as np
 from medcat.config import Config
 from medcat.cdb_maker import CDBMaker
+from medcat.cdb import CDB
 
 
 class CDBTests(unittest.TestCase):
@@ -21,11 +22,21 @@ def setUp(self) -> None:
         cdb_2_csv = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb_2.csv")
         self.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
         os.makedirs(self.tmp_dir, exist_ok=True)
+        # resetting the CDB because otherwise the CDBMaker
+        # will refer to and modify the same instance of the CDB
+        # and this can (and does!) create side effects
+        CDBTests.cdb_maker.reset_cdb()
         self.undertest = CDBTests.cdb_maker.prepare_csvs([cdb_csv, cdb_2_csv], full_build=True)
 
     def tearDown(self) -> None:
         shutil.rmtree(self.tmp_dir)
 
+    def test_setup_changes_cdb(self):
+        id1 = id(self.undertest)
+        self.setUp()
+        id2 = id(self.undertest)
+        self.assertNotEqual(id1, id2)
+
     def test_name2cuis(self):
         self.assertEqual({
             'second~csv': ['C0000239'],
@@ -53,6 +64,13 @@ def test_save_and_load(self):
             self.undertest.save(f.name)
             self.undertest.load(f.name)
 
+    def test_load_has_no_config(self):
+        with tempfile.NamedTemporaryFile() as f:
+            self.undertest.save(f.name)
+            cdb = CDB.load(f.name)
+            self.assertFalse(cdb._config_from_file)
+
+
     def test_save_async_and_load(self):
         with tempfile.NamedTemporaryFile() as f:
             asyncio.run(self.undertest.save_async(f.name))
diff --git a/tests/test_config.py b/tests/test_config.py
index 2f9cd5a84..ce6ed76eb 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,7 +1,7 @@
 import unittest
 import pickle
 import tempfile
-from medcat.config import Config, MixingConfig, VersionInfo, General
+from medcat.config import Config, MixingConfig, VersionInfo, General, LinkingFilters
 from pydantic import ValidationError
 import os
 
@@ -179,6 +179,54 @@ def test_from_dict(self):
         config = Config.from_dict({"key": "value"})
         self.assertEqual("value", config.key)
 
+    def test_config_no_hash_before_get(self):
+        config = Config()
+        self.assertIsNone(config.hash)
+
+    def test_config_has_hash_after_get(self):
+        config = Config()
+        config.get_hash()
+        self.assertIsNotNone(config.hash)
+
+    def test_config_hash_recalc_same_def(self):
+        config = Config()
+        h1 = config.get_hash()
+        h2 = config.get_hash()
+        self.assertEqual(h1, h2)
+
+    def test_config_hash_changes_after_change(self):
+        config = Config()
+        h1 = config.get_hash()
+        config.linking.filters.cuis = {"a", "b"}
+        h2 = config.get_hash()
+        self.assertNotEqual(h1, h2)
+
+    def test_config_hash_recalc_same_changed(self):
+        config = Config()
+        config.linking.filters.cuis = {"a", "b"}
+        h1 = config.get_hash()
+        h2 = config.get_hash()
+        self.assertEqual(h1, h2)
+
+
+class ConfigLinkingFiltersTests(unittest.TestCase):
+
+    def test_allows_empty_dict_for_cuis(self):
+        lf = LinkingFilters(cuis={})
+        self.assertIsNotNone(lf)
+
+    def test_empty_dict_converted_to_empty_set(self):
+        lf = LinkingFilters(cuis={})
+        self.assertEqual(lf.cuis, set())
+
+    def test_not_allow_nonempty_dict_for_cuis(self):
+        with self.assertRaises(ValidationError):
+            LinkingFilters(cuis={"KEY": "VALUE"})
+
+    def test_not_allow_empty_dict_for_cuis_exclude(self):
+        with self.assertRaises(ValidationError):
+            LinkingFilters(cuis_exclude={})
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/test_pipe.py b/tests/test_pipe.py
index e6da42898..8ce47cfb5 100644
--- a/tests/test_pipe.py
+++ b/tests/test_pipe.py
@@ -28,6 +28,7 @@ def setUpClass(cls) -> None:
         cls.config.ner['max_skip_tokens'] = 1
         cls.config.ner['upper_case_limit_len'] = 4
         cls.config.linking['disamb_length_limit'] = 2
+        cls.config.preprocessing.stopwords = {'stop', 'words'}
         cls.cdb = CDB(config=cls.config)
 
         downloader = VocabDownloader()
@@ -42,7 +43,7 @@ def setUpClass(cls) -> None:
         _tokenizer = TokenizerWrapperBERT(hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased"))
         cls.meta_cat = MetaCAT(tokenizer=_tokenizer)
 
-        cls.text = "CDB - I was running and then Movar Virus attacked and CDb"
+        cls.text = "stop of CDB - I was running and then Movar Virus attacked and CDb"
         cls.undertest = Pipe(tokenizer=spacy_split_all, config=cls.config)
 
     @classmethod
@@ -81,6 +82,12 @@ def test_add_meta_cat(self):
         PipeTests.undertest.add_meta_cat(PipeTests.meta_cat)
 
         self.assertEqual(PipeTests.meta_cat.name, Language.get_factory_meta(PipeTests.meta_cat.name).factory)
+    
+    def test_stopwords_loading(self):
+        self.assertEqual(PipeTests.undertest._nlp.Defaults.stop_words, PipeTests.config.preprocessing.stopwords)
+        doc = PipeTests.undertest(PipeTests.text)
+        self.assertEqual(doc[0].is_stop, True)
+        self.assertEqual(doc[1].is_stop, False)
 
     def test_batch_multi_process(self):
         PipeTests.undertest.add_tagger(tagger=tag_skip_and_punct, additional_fields=["is_punct"])
diff --git a/tests/utils/saving/test_serialization.py b/tests/utils/saving/test_serialization.py
index f0cc75de1..c2c44da16 100644
--- a/tests/utils/saving/test_serialization.py
+++ b/tests/utils/saving/test_serialization.py
@@ -87,7 +87,7 @@ def test_dill_to_json(self):
         model_pack_folder = os.path.join(
             self.json_model_pack.name, model_pack_path)
         json_path = os.path.join(model_pack_folder, "*.json")
-        jsons = glob.glob(json_path)
+        jsons = [fn for fn in glob.glob(json_path) if not fn.endswith("config.json")]
         # there is also a model_card.json
         # but nothing for cui2many or name2many
         # so can remove the length of ONE2MANY
diff --git a/tests/utils/test_cdb_utils.py b/tests/utils/test_cdb_utils.py
new file mode 100644
index 000000000..777a2506b
--- /dev/null
+++ b/tests/utils/test_cdb_utils.py
@@ -0,0 +1,43 @@
+import unittest
+import numpy as np
+from tests.helper import ForCDBMerging
+from medcat.utils.cdb_utils import merge_cdb
+
+
+class CDBMergeTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        to_merge = ForCDBMerging()
+        cls.cdb1 = to_merge.cdb1
+        cls.cdb2 = to_merge.cdb2
+        cls.merged_cdb = merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2)
+        cls.overwrite_cdb = merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2, overwrite_training=2, full_build=True)
+        cls.zeroes = np.zeros(shape=(1,300))
+        cls.ones = np.ones(shape=(1,300))
+
+    def test_merge_inserts(self):
+        self.assertIn("test", self.merged_cdb.cui2names["C0006826"])
+        self.assertIn("test_name", self.merged_cdb.cui2snames["C0006826"])
+        self.assertEqual("Cancer", self.merged_cdb.cui2preferred_name["C0006826"])
+
+    def test_no_full_build(self):
+        self.assertEqual(self.merged_cdb.addl_info["cui2ontologies"], dict())
+        self.assertEqual(self.merged_cdb.addl_info["cui2ontologies"], dict())
+
+    def test_full_build(self):
+        for cui in self.cdb2.cui2names:
+            self.assertEqual(self.overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"})
+            self.assertEqual(self.overwrite_cdb.addl_info["cui2description"][cui], "test_description")
+
+    def test_vector_merge(self):
+        self.assertTrue(np.array_equal(self.zeroes, self.merged_cdb.cui2context_vectors["UniqueTest"]["short"]))
+        for i, cui in enumerate(self.cdb1.cui2names):
+            self.assertTrue(np.array_equal(self.merged_cdb.cui2context_vectors[cui]["short"], np.divide(self.ones, i+2)))
+
+
+    def test_overwrite_parameter(self):
+        for cui in self.cdb2.cui2names:
+            self.assertTrue(np.array_equal(self.overwrite_cdb.cui2context_vectors[cui]["short"], self.zeroes))
+            self.assertEqual(self.overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"})
+            self.assertEqual(self.overwrite_cdb.addl_info["cui2description"][cui], "test_description")
diff --git a/tests/utils/test_hashing.py b/tests/utils/test_hashing.py
index 99c10b153..0fd6b5891 100644
--- a/tests/utils/test_hashing.py
+++ b/tests/utils/test_hashing.py
@@ -1,4 +1,5 @@
 import os
+from typing import Optional
 import tempfile
 import unittest
 import unittest.mock
@@ -6,6 +7,7 @@
 from medcat.cat import CAT
 from medcat.cdb import CDB
 from medcat.vocab import Vocab
+from medcat.config import Config
 
 
 class CDBHashingTests(unittest.TestCase):
@@ -30,6 +32,43 @@ def test_CDB_hash_saves_on_disk(self):
         self.assertEqual(h, cdb._hash)
 
 
+class CDBHashingWithConfigTests(unittest.TestCase):
+    temp_dir = tempfile.TemporaryDirectory()
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.cdb = CDB.load(os.path.join(os.path.dirname(
+            os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat"))
+        # ensure config has hash
+        h = cls.cdb.get_hash()
+        cls.config = cls.config_copy(cls.cdb.config)
+        cls._config_hash = cls.cdb.config.hash
+
+    @classmethod
+    def config_copy(cls, config: Optional[Config] = None) -> Config:
+        if config is None:
+            config = cls.config
+        return Config(**config.asdict())
+
+    def setUp(self) -> None:
+        # reset config
+        self.cdb.config = self.config_copy()
+        # reset config hash
+        self.cdb._config_hash = self._config_hash
+        self.cdb.config.hash = self._config_hash
+
+    def test_CDB_same_hash_no_need_recalc(self):
+        self.assertFalse(self.cdb._should_recalc_hash(force_recalc=False))
+
+    def test_CDB_hash_recalc_if_no_config_hash(self):
+        self.cdb._config_hash = None
+        self.assertTrue(self.cdb._should_recalc_hash(force_recalc=False))
+
+    def test_CDB_hash_recalc_after_config_change(self):
+        self.cdb.config.linking.filters.cuis = {"a", "b", "c"}
+        self.assertTrue(self.cdb._should_recalc_hash(force_recalc=False))
+
+
 class BaseCATHashingTests(unittest.TestCase):
 
     @classmethod
@@ -75,8 +114,14 @@ def test_no_changes_recalc_same(self):
 
 class CATHashingTestsWithoutChange(CATHashingTestsWithFakeHash):
 
-    def test_no_changes_no_calc(self):
+    def setUp(self) -> None:
+        self._calculate_hash = self.undertest.cdb.calculate_hash
+        # make sure the hash exists
+        self.undertest.cdb._config_hash = self.undertest.cdb.config.get_hash()
+        self.undertest.cdb.get_hash()
         self.undertest.cdb.calculate_hash = unittest.mock.Mock()
+
+    def test_no_changes_no_calc(self):
         hash = self.undertest.get_hash()
         self.assertIsInstance(hash, str)
         self.undertest.cdb.calculate_hash.assert_not_called()
@@ -90,7 +135,7 @@ class CATHashingTestsWithChange(CATHashingTestsWithFakeHash):
 
     def test_when_changes_do_calc(self):
         with unittest.mock.patch.object(CDB, 'calculate_hash', return_value='abcd1234') as patch_method:
-            self.undertest.cdb.add_concept(**self.concept_kwargs)
+            self.undertest.cdb._add_concept(**self.concept_kwargs)
             hash = self.undertest.get_hash()
         self.assertIsInstance(hash, str)
         patch_method.assert_called()
@@ -106,10 +151,10 @@ def test_default_cdb_not_dirty(self):
         self.assertFalse(self.undertest.cdb.is_dirty)
 
     def test_after_add_concept_is_dirty(self):
-        self.undertest.cdb.add_concept(**self.concept_kwargs)
+        self.undertest.cdb._add_concept(**self.concept_kwargs)
         self.assertTrue(self.undertest.cdb.is_dirty)
 
     def test_after_recalc_not_dirty(self):
-        self.undertest.cdb.add_concept(**self.concept_kwargs)
+        self.undertest.cdb._add_concept(**self.concept_kwargs)
         self.undertest.get_hash()
         self.assertFalse(self.undertest.cdb.is_dirty)
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
new file mode 100644
index 000000000..6703ce91a
--- /dev/null
+++ b/tests/utils/test_helpers.py
@@ -0,0 +1,24 @@
+from medcat.utils.helpers import has_spacy_model, ensure_spacy_model
+from medcat.pipe import DEFAULT_SPACY_MODEL
+
+import unittest
+import subprocess
+
+
+class HasSpacyModelTests(unittest.TestCase):
+
+    def test_no_rubbish_model(self, model_name='rubbish_model'):
+        self.assertFalse(has_spacy_model(model_name))
+
+    def test_has_def_model(self, model_name=DEFAULT_SPACY_MODEL):
+        self.assertTrue(has_spacy_model(model_name))
+
+
+class EnsureSpacyModelTests(unittest.TestCase):
+
+    def test_fails_rubbish_model(self, model_name='rubbish_model'):
+        with self.assertRaises(subprocess.CalledProcessError):
+            ensure_spacy_model(model_name)
+
+    def test_success_def_model(self, model_name=DEFAULT_SPACY_MODEL):
+        ensure_spacy_model(model_name)
diff --git a/tests/utils/test_spacy_compatibility.py b/tests/utils/test_spacy_compatibility.py
new file mode 100644
index 000000000..5cf0dd03e
--- /dev/null
+++ b/tests/utils/test_spacy_compatibility.py
@@ -0,0 +1,302 @@
+import medcat.utils.spacy_compatibility as module_under_test
+from medcat.utils.spacy_compatibility import _is_spacy_model_folder, _find_spacy_model_folder
+from medcat.utils.spacy_compatibility import get_installed_spacy_version, get_installed_model_version
+from medcat.utils.spacy_compatibility import _get_name_and_meta_of_spacy_model_in_medcat_modelpack
+from medcat.utils.spacy_compatibility import get_name_and_version_of_spacy_model_in_medcat_modelpack
+from medcat.utils.spacy_compatibility import _is_spacy_version_within_range
+from medcat.utils.spacy_compatibility import medcat_model_pack_has_compatible_spacy_model
+from medcat.utils.spacy_compatibility import is_older_spacy_version
+from medcat.utils.spacy_compatibility import medcat_model_pack_has_semi_compatible_spacy_model
+
+import unittest
+
+from typing import Callable
+import random
+import string
+import tempfile
+import os
+from contextlib import contextmanager
+
+
+FAKE_SPACY_MODEL_NAME = "ff_core_fake_dr"
+FAKE_SPACY_MODEL_DIR = os.path.join("tests", "resources", FAKE_SPACY_MODEL_NAME)
+FAKE_MODELPACK_MODEL_DIR = os.path.join(FAKE_SPACY_MODEL_DIR, '..')
+
+
+class SpacyModelFolderIdentifierTests(unittest.TestCase):
+    expected_working_spacy_models = [
+        "en_core_sci_sm",
+        "en_core_web_sm",
+        "en_core_web_md",
+        "en_core_web_lg",
+        "en_core_web_trf",
+        "nl_core_news_sm",
+        "nl_core_news_md",
+        "nl_core_news_lg",
+    ]
+    # the following were used in medcat models created prior
+    # to v1.2.4
+    expected_working_legacy_names = [
+        "spacy_model"
+    ]
+
+    def test_works_expected_models(self):
+        for model_name in self.expected_working_spacy_models:
+            with self.subTest(model_name):
+                self.assertTrue(_is_spacy_model_folder(model_name))
+
+    def test_works_legacy_models(self):
+        for model_name in self.expected_working_legacy_names:
+            with self.subTest(model_name):
+                self.assertTrue(_is_spacy_model_folder(model_name))
+
+    def test_works_fill_path(self):
+        for model_name in self.expected_working_legacy_names:
+            full_folder_path = os.path.join("some", "folder", "structure", model_name)
+            with self.subTest(full_folder_path):
+                self.assertTrue(_is_spacy_model_folder(model_name))
+
+    def get_all_garbage(self) -> list:
+        """Generate garbage "spacy names".
+
+        Returns:
+            List[str]: Some random strings that shouldn't be spacy models.
+        """
+        my_examples = ["garbage_in_and_out", "meta_Presence", "something"]
+        true_randoms_N10 = [''.join(random.choices(string.ascii_uppercase + string.digits, k=10)) for _ in range(10)]
+        true_randoms_N20 = [''.join(random.choices(string.ascii_uppercase + string.digits, k=20)) for _ in range(10)]
+        return my_examples + true_randoms_N10 + true_randoms_N20
+
+    def test_does_not_work_grabage(self):
+        for garbage in self.get_all_garbage():
+            with self.subTest(garbage):
+                self.assertFalse(_is_spacy_model_folder(garbage))
+
+
+class FindSpacyFolderJustOneFolderEmptyFilesTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls, spacy_folder_name='en_core_web_md') -> None:
+        # setup temp folder
+        cls.temp_folder = tempfile.TemporaryDirectory()
+        cls.fake_modelpack_folder_name = cls.temp_folder.name
+        # create spacy folder
+        cls.spacy_folder = os.path.join(cls.fake_modelpack_folder_name, spacy_folder_name)
+        os.makedirs(cls.spacy_folder)
+        # create 2 empty files
+        filenames = ["file1.dat", "file2.json"]
+        filenames = [os.path.join(cls.fake_modelpack_folder_name, fn) for fn in filenames]
+        for fn in filenames:
+            with open(fn, 'w'):
+                pass # open and write empty file
+    
+    @classmethod
+    def tearDownClass(cls) -> None:
+        cls.temp_folder.cleanup()
+
+    def test_finds(self):
+        found_folder_path = _find_spacy_model_folder(self.fake_modelpack_folder_name)
+        self.assertEqual(found_folder_path, self.spacy_folder)
+
+
+class FindSpacyFolderMoreFoldersEmptyFilesTests(FindSpacyFolderJustOneFolderEmptyFilesTests):
+
+    @classmethod
+    def setUpClass(cls, spacy_folder_name='en_core_web_md') -> None:
+        super().setUpClass(spacy_folder_name)
+        # add a few folders
+        folder_names = ["meta_Presence", "garbage_in_garbage_out"]
+        folder_names = [os.path.join(cls.fake_modelpack_folder_name, fn) for fn in folder_names]
+        for folder in folder_names:
+            os.makedirs(folder)
+
+
+class SpacyVersionTests(unittest.TestCase):
+
+    def test_version_received(self):
+        installed = get_installed_spacy_version()
+        import spacy
+        expected = spacy.__version__
+        self.assertEqual(installed, expected)
+
+
+class InstalledVersionChecker(unittest.TestCase):
+
+    def test_existing(self, model_name: str = 'en_core_web_md'):
+        version = get_installed_model_version(model_name)
+        self.assertIsInstance(version, str)
+        self.assertNotEqual(version, "N/A")
+
+    def test_non_existing(self, model_name: str = 'en_core_web_lg'):
+        version = get_installed_model_version(model_name)
+        self.assertIsInstance(version, str)
+        self.assertEqual(version, "N/A")
+
+
+class GetSpacyModelInfoTests(unittest.TestCase):
+    expected_version = "3.1.0"
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.name, cls.info = _get_name_and_meta_of_spacy_model_in_medcat_modelpack(FAKE_MODELPACK_MODEL_DIR)
+
+    def test_reads_name(self):
+        self.assertEqual(self.name, FAKE_SPACY_MODEL_NAME)
+
+    def test_reads_info(self):
+        self.assertIsInstance(self.info, dict)
+        self.assertTrue(self.info)  # not empty
+
+
+class GetSpacyModelVersionTests(GetSpacyModelInfoTests):
+    expected_spacy_version = ">=3.1.0,<3.2.0"
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        (cls.name,
+         cls.version,
+         cls.spacy_version) = get_name_and_version_of_spacy_model_in_medcat_modelpack(FAKE_MODELPACK_MODEL_DIR)
+
+    def test_name_correct(self):
+        self.assertEqual(self.name, FAKE_SPACY_MODEL_NAME)
+
+    def test_version_correct(self):
+        self.assertEqual(self.version, self.expected_version)
+
+    def test_spacy_version_correct(self):
+        self.assertEqual(self.spacy_version, self.expected_spacy_version)
+
+
+@contextmanager
+def custom_spacy_version(mock_version: str):
+    """Changes the apparently installed spacy version.
+    """
+    print(f"Mocking spacy version to: {mock_version}")
+    _old_method = module_under_test.get_installed_spacy_version
+    module_under_test.get_installed_spacy_version = lambda: mock_version
+    yield mock_version
+    print("Returning regular spacy version getter")
+    module_under_test.get_installed_spacy_version = _old_method
+
+
+class VersionMockBaseTests(unittest.TestCase):
+
+    def base_subtest_for(self, target_fun: Callable[[str], bool],
+                     spacy_model_range: str, spacy_version: str, should_work: bool) -> None:
+        with self.subTest(spacy_version):
+            if should_work:
+                self.assertTrue(target_fun(spacy_model_range))
+            else:
+                self.assertFalse(target_fun(spacy_model_range))
+
+    def base_check_version(self, target_fun: Callable[[str], bool],
+                       spacy_model_range: str, spacy_version: str, should_work: bool = True) -> None:
+        with custom_spacy_version(spacy_version):
+            self.base_subtest_for(target_fun, spacy_model_range, spacy_version, should_work)
+
+class SpacyVersionMockBaseTests(VersionMockBaseTests):
+
+    def _subtest_for(self, spacy_model_range: str, spacy_version: str, should_work: bool) -> None:
+        return self.base_subtest_for(_is_spacy_version_within_range,
+                                    spacy_model_range, spacy_version, should_work)
+
+    def _check_version(self, spacy_model_range: str, spacy_version: str, should_work: bool = True) -> None:
+        return self.base_check_version(_is_spacy_version_within_range,
+                                      spacy_model_range, spacy_version, should_work)
+
+
+class SpacyVersionInRangeOldRangeTests(SpacyVersionMockBaseTests):
+    """This is for versions before 1.7.0.
+    Those versions used to have spacy constraints of 'spacy<3.1.4,>=3.1.0'
+    and as such, they used v3.1.0 of en_core_web_md.
+    """
+    spacy_model_range = ">=3.1.0,<3.2.0"  # model range for en_core_web_md-3.1.0
+    useful_spacy_versions = ["3.1.0", "3.1.2", "3.1.3"]
+    unsupported_spacy_versions = ["3.2.0", "3.5.3", "3.6.0"]
+
+    def test_works_in_range(self):
+        for spacy_version in self.useful_spacy_versions:
+            self._check_version(self.spacy_model_range, spacy_version, should_work=True)
+
+    def test_not_suitable_outside_range(self):
+        for spacy_version in self.unsupported_spacy_versions:
+            self._check_version(self.spacy_model_range, spacy_version, should_work=False)
+
+
+class SpacyVersionInRangeNewRangeTests(SpacyVersionInRangeOldRangeTests):
+    """This is for versions AFTER (and including) 1.7.0.
+    Those versions used to have spacy constraints of 'spacy>=3.1.0'
+    and as such, we use v3.4.0 of en_core_web_md.
+
+    In this setup, generally (in GHA at 14.12.2023)
+    the spacy version for python version:
+        3.8  -> spacy-3.7.2
+        3.9  -> spacy-3.7.2
+        3.10 -> spacy-3.7.2
+        3.11 -> spacy-3.7.2
+    Alongside the `en_core_web_md-3.4.0` is installed.
+    It technically has the compatibility of >=3.4.0,<3.5.0.
+    But practically, I've seen no issues with spacy==3.7.2.
+    """
+    spacy_model_range = ">=3.1.0"  # model range for medcat>=1.7.0
+    useful_spacy_versions = ["3.1.0", "3.1.2", "3.1.3",
+                             "3.7.2", "3.6.3"]
+    unsupported_spacy_versions = ["3.0.0"]
+
+
+class ModelPackHasCompatibleSpacyRangeTests(unittest.TestCase):
+    test_spacy_version = "3.1.0"
+
+    def test_is_in_range(self):
+        with custom_spacy_version(self.test_spacy_version):
+            b = medcat_model_pack_has_compatible_spacy_model(FAKE_MODELPACK_MODEL_DIR)
+            self.assertTrue(b)
+
+class ModelPackHasInCompatibleSpacyRangeTests(unittest.TestCase):
+    test_spacy_version = "3.2.0"
+
+    def test_is_in_range(self):
+        with custom_spacy_version(self.test_spacy_version):
+            b = medcat_model_pack_has_compatible_spacy_model(FAKE_MODELPACK_MODEL_DIR)
+            self.assertFalse(b)
+
+
+class IsOlderSpacyVersionTests(VersionMockBaseTests):
+    test_spacy_version = "3.4.4"
+    expected_older = ["3.1.0", "3.2.0", "3.3.0", "3.4.0"]
+    expected_newer = ["3.5.0", "3.6.0", "3.7.1"]
+
+    def _check_version(self, model_version: str, should_work: bool = True) -> None:
+        self.base_check_version(is_older_spacy_version, model_version, self.test_spacy_version, should_work)
+
+    def test_older_works(self):
+        for model_version in self.expected_older:
+            self._check_version(model_version, should_work=True)
+
+    def test_newer_fails(self):
+        for model_version in self.expected_newer:
+            self._check_version(model_version, should_work=False)
+
+
+class HasSemiCompatibleSpacyModelTests(unittest.TestCase):
+    # model version on file is 3.1.0,
+    # and spacy_version range >=3.1.0,<3.2.0"
+    good_spacy_version = "3.1.3"
+    semi_good_spacy_version = "3.4.4"  # newer than the model
+    bad_spacy_version = "3.0.0"  # older than the model
+
+    def run_subtest(self, spacy_version: str, should_work: bool) -> None:
+        with custom_spacy_version(spacy_version):
+            if should_work:
+                self.assertTrue(medcat_model_pack_has_semi_compatible_spacy_model(FAKE_MODELPACK_MODEL_DIR))
+            else:
+                self.assertFalse(medcat_model_pack_has_semi_compatible_spacy_model(FAKE_MODELPACK_MODEL_DIR))
+
+    def test_works_compatible_spacy_version(self):
+        self.run_subtest(self.good_spacy_version, should_work=True)
+
+    def test_works_semi_compatible_spacy_version(self):
+        self.run_subtest(self.semi_good_spacy_version, should_work=True)
+
+    def test_fails_incompatible_spacy_version(self):
+        self.run_subtest(self.bad_spacy_version, should_work=False)
diff --git a/webapp/webapp/requirements.txt b/webapp/webapp/requirements.txt
index a4b7827ad..ce68f853d 100644
--- a/webapp/webapp/requirements.txt
+++ b/webapp/webapp/requirements.txt
@@ -1,6 +1,6 @@
-Django==3.2.20
+Django==3.2.23
 django-dbbackup==4.0.0b0
 django-storages[boto3]==1.12.3
 django-cron==0.5.1
 medcat==1.2.7
-urllib3==1.26.5
+urllib3==1.26.18

From a429735b4b35faa25c9009adcb348a4fd514ad7a Mon Sep 17 00:00:00 2001
From: Mart Ratas 
Date: Tue, 13 Feb 2024 21:12:27 +0000
Subject: [PATCH 4/9] Cu 8693u6b4u tests continue on fail (#400) (#401)

* CU-8693u6b4u: Make sure failed/errored tests fail the main workflow

* CU-8693u6b4u: Attempt to fix deid multiprocessing, at least for GHA

* CU-8693u6b4u: Fix small docstring issue
---
 .github/workflows/main.yml   | 1 -
 medcat/config_meta_cat.py    | 2 +-
 tests/utils/ner/test_deid.py | 7 +++++++
 3 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 687160ed9..7c7a2b742 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -34,7 +34,6 @@ jobs:
       - name: Test
         run: |
           timeout 17m python -m unittest discover
-        continue-on-error: true
 
   publish-to-test-pypi:
 
diff --git a/medcat/config_meta_cat.py b/medcat/config_meta_cat.py
index 47f42dc28..6ddd71d56 100644
--- a/medcat/config_meta_cat.py
+++ b/medcat/config_meta_cat.py
@@ -38,7 +38,7 @@ class General(MixingConfig, BaseModel):
     pipe_batch_size_in_chars: int = 20000000
     """How many characters are piped at once into the meta_cat class"""
     span_group: Optional[str] = None
-    """If set, the spacy span group that the metacat model will assign annotations. 
+    """If set, the spacy span group that the metacat model will assign annotations.
     Otherwise defaults to doc._.ents or doc.ents per the annotate_overlapping settings"""
 
     class Config:
diff --git a/tests/utils/ner/test_deid.py b/tests/utils/ner/test_deid.py
index 97ca8334b..01c9c1af3 100644
--- a/tests/utils/ner/test_deid.py
+++ b/tests/utils/ner/test_deid.py
@@ -154,6 +154,13 @@ def setUpClass(cls) -> None:
         for project in raw_data['projects']:
             for doc in project['documents']:
                 cls.data.append((f"{project['name']}_{doc['name']}", doc['text']))
+        # NOTE: Comment and subsequent code
+        #       copied from CAT.multiprocessing_batch_char_size
+        #       (lines 1234 - 1237)
+        # Hack for torch using multithreading, which is not good if not
+        #separate_nn_components, need for CPU runs only
+        import torch
+        torch.set_num_threads(1)
 
     def assertTextHasBeenDeIded(self, text: str, redacted: bool):
         if not redacted:

From e311d3620221b2e75874c65a09af4ea7b0db0d31 Mon Sep 17 00:00:00 2001
From: Mart Ratas 
Date: Wed, 28 Feb 2024 18:44:59 +0200
Subject: [PATCH 5/9] v1.10.2 PR (#407)

* Cu 8693u6b4u tests continue on fail (#400)

* CU-8693u6b4u: Make sure failed/errored tests fail the main workflow

* CU-8693u6b4u: Attempt to fix deid multiprocessing, at least for GHA

* CU-8693u6b4u: Fix small docstring issue

* CU-8693v3tt6 SOMED opcs refset selection (#402)

* CU-8693v3tt6: Update refset ID for OPCS4 mappings in newer SNOMED releases

* CU-8693v3tt6: Add method to get direct refset mappings

* CU-8693v3tt6: Add tests to direct refset mappings method

* CU-8693v3tt6: Fix OPCS4 refset ID selection logic

* CU-8693v3tt6: Add test for OPCS4 refset ID selection

* CU-8693v6epd: Move typing imports away from pydantic (#403)

* CU-8693qx9yp Deid chunking - hugging face pipeline approach (#405)

* Pushing chunking update

* Update transformers_ner.py

* Pushing update to config

Added NER config in cat load function

* Update cat.py

* Updating chunking overlap

* CU-8693qx9yp: Add warning for deid multiprocessing with (potentially) non-functioning chunking window

* CU-8693qx9yp: Fix linting issue

---------

Co-authored-by: mart-r 

---------

Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com>
---
 medcat/cat.py                         | 11 +++--
 medcat/config.py                      |  3 +-
 medcat/config_transformers_ner.py     |  2 +
 medcat/ner/transformers_ner.py        |  6 ++-
 medcat/utils/ner/deid.py              | 30 +++++++++++--
 medcat/utils/ner/model.py             |  7 +--
 medcat/utils/preprocess_snomed.py     | 37 +++++++++++++++-
 tests/utils/ner/test_deid.py          |  2 +-
 tests/utils/test_preprocess_snomed.py | 64 +++++++++++++++++++++++++++
 9 files changed, 147 insertions(+), 15 deletions(-)
 create mode 100644 tests/utils/test_preprocess_snomed.py

diff --git a/medcat/cat.py b/medcat/cat.py
index 9159eddd8..d36e30611 100644
--- a/medcat/cat.py
+++ b/medcat/cat.py
@@ -334,6 +334,7 @@ def attempt_unpack(cls, zip_path: str) -> str:
     def load_model_pack(cls,
                         zip_path: str,
                         meta_cat_config_dict: Optional[Dict] = None,
+                        ner_config_dict: Optional[Dict] = None,
                         load_meta_models: bool = True,
                         load_addl_ner: bool = True) -> "CAT":
         """Load everything within the 'model pack', i.e. the CDB, config, vocab and any MetaCAT models
@@ -346,6 +347,10 @@ def load_model_pack(cls,
                 A config dict that will overwrite existing configs in meta_cat.
                 e.g. meta_cat_config_dict = {'general': {'device': 'cpu'}}.
                 Defaults to None.
+            ner_config_dict (Optional[Dict]):
+                A config dict that will overwrite existing configs in transformers ner.
+                e.g. ner_config_dict = {'general': {'chunking_overlap_window': 6}.
+                Defaults to None.
             load_meta_models (bool):
                 Whether to load MetaCAT models if present (Default value True).
             load_addl_ner (bool):
@@ -381,15 +386,15 @@ def load_model_pack(cls,
         else:
             vocab = None
 
-        # Find meta models in the model_pack
+        # Find ner models in the model_pack
         trf_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('trf_')] if load_addl_ner else []
         addl_ner = []
         for trf_path in trf_paths:
-            trf = TransformersNER.load(save_dir_path=trf_path)
+            trf = TransformersNER.load(save_dir_path=trf_path,config_dict=ner_config_dict)
             trf.cdb = cdb # Set the cat.cdb to be the CDB of the TRF model
             addl_ner.append(trf)
 
-        # Find meta models in the model_pack
+        # Find metacat models in the model_pack
         meta_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
         meta_cats = []
         for meta_path in meta_paths:
diff --git a/medcat/config.py b/medcat/config.py
index e60c2eafc..98aee18df 100644
--- a/medcat/config.py
+++ b/medcat/config.py
@@ -1,8 +1,7 @@
 from datetime import datetime
 from pydantic import BaseModel, Extra, ValidationError
-from pydantic.dataclasses import Any, Callable, Dict, Optional, Union
 from pydantic.fields import ModelField
-from typing import List, Set, Tuple, cast
+from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union
 from multiprocessing import cpu_count
 import logging
 import jsonpickle
diff --git a/medcat/config_transformers_ner.py b/medcat/config_transformers_ner.py
index 64435e9cb..9f3102acb 100644
--- a/medcat/config_transformers_ner.py
+++ b/medcat/config_transformers_ner.py
@@ -13,6 +13,8 @@ class General(MixingConfig, BaseModel):
     """How many characters are piped at once into the meta_cat class"""
     ner_aggregation_strategy: str = 'simple'
     """Agg strategy for HF pipeline for NER"""
+    chunking_overlap_window: Optional[int] = 5
+    """Size of the overlap window used for chunking"""
     test_size: float = 0.2
     last_train_on: Optional[int] = None
     verbose_metrics: bool = False
diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py
index 729be4625..78b410230 100644
--- a/medcat/ner/transformers_ner.py
+++ b/medcat/ner/transformers_ner.py
@@ -76,9 +76,11 @@ def __init__(self, cdb, config: Optional[ConfigTransformersNER] = None,
         else:
             self.training_arguments = training_arguments
 
-
     def create_eval_pipeline(self):
-        self.ner_pipe = pipeline(model=self.model, task="ner", tokenizer=self.tokenizer.hf_tokenizer)
+
+        if self.config.general['chunking_overlap_window'] is None:
+            logger.warning("Chunking overlap window attribute in the config is set to None, hence chunking is disabled. Be cautious, PII data MAY BE REVEALED. To enable chunking, set the value to 0 or above.")
+        self.ner_pipe = pipeline(model=self.model, task="ner", tokenizer=self.tokenizer.hf_tokenizer,stride=self.config.general['chunking_overlap_window'])
         if not hasattr(self.ner_pipe.tokenizer, '_in_target_context_manager'):
             # NOTE: this will fix the DeID model(s) created before medcat 1.9.3
             #       though this fix may very well be unstable
diff --git a/medcat/utils/ner/deid.py b/medcat/utils/ner/deid.py
index 13ee5e04c..343e89ef0 100644
--- a/medcat/utils/ner/deid.py
+++ b/medcat/utils/ner/deid.py
@@ -34,7 +34,8 @@
 - config
 - cdb
 """
-from typing import Union, Tuple, Any, List, Iterable, Optional
+from typing import Union, Tuple, Any, List, Iterable, Optional, Dict
+import logging
 
 from medcat.cat import CAT
 from medcat.utils.ner.model import NerModel
@@ -42,6 +43,9 @@
 from medcat.utils.ner.helpers import _deid_text as deid_text, replace_entities_in_text
 
 
+logger = logging.getLogger(__name__)
+
+
 class DeIdModel(NerModel):
     """The DeID model.
 
@@ -93,6 +97,25 @@ def deid_multi_texts(self,
         Returns:
             List[str]: List of deidentified documents.
         """
+        # NOTE: we assume we're using the 1st (and generally only)
+        #       additional NER model.
+        #       the same assumption is made in the `train` method
+        chunking_overlap_window = self.cat._addl_ner[0].config.general.chunking_overlap_window
+        if chunking_overlap_window is not None:
+            logger.warning("Chunking overlap window has been set to %s. "
+                           "This may cause multiprocessing to stall in certain "
+                           "environments and/or situations and has not been "
+                           "fully tested.",
+                           chunking_overlap_window)
+            logger.warning("If the following hangs forever (i.e doesn't finish) "
+                           "but you still wish to run on multiple processes you can set "
+                           "`cat._addl_ner[0].config.general.chunking_overlap_window = None` "
+                           "and then either a) save the model on disk and load it back up, or "
+                           " b) call `cat._addl_ner[0].create_eval_pipeline()` to recreate the pipe. "
+                           "However, this will remove chunking from the input text, which means "
+                           "only the first 512 tokens will be recognised and thus only the "
+                           "first part of longer documents (those with more than 512 tokens) "
+                           "will be deidentified. ")
         entities = self.cat.get_entities_multi_texts(texts, addl_info=addl_info,
                                                      n_process=n_process, batch_size=batch_size)
         out = []
@@ -110,7 +133,7 @@ def deid_multi_texts(self,
         return out
 
     @classmethod
-    def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
+    def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> 'DeIdModel':
         """Load DeId model from model pack.
 
         The method first loads the CAT instance.
@@ -119,6 +142,7 @@ def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
         valid DeId model.
 
         Args:
+            config: Config for DeId model pack (primarily for stride of overlap window)
             model_pack_path (str): The model pack path.
 
         Raises:
@@ -127,7 +151,7 @@ def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
         Returns:
             DeIdModel: The resulting DeI model.
         """
-        ner_model = NerModel.load_model_pack(model_pack_path)
+        ner_model = NerModel.load_model_pack(model_pack_path,config=config)
         cat = ner_model.cat
         if not cls._is_deid_model(cat):
             raise ValueError(
diff --git a/medcat/utils/ner/model.py b/medcat/utils/ner/model.py
index 553fb4c65..d3ff2eb3b 100644
--- a/medcat/utils/ner/model.py
+++ b/medcat/utils/ner/model.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Tuple, Union, Optional
+from typing import Any, List, Tuple, Union, Optional, Dict
 
 from spacy.tokens import Doc
 
@@ -94,16 +94,17 @@ def create(cls, ner: Union[TransformersNER, List[TransformersNER]]) -> 'NerModel
         return cls(cat)
 
     @classmethod
-    def load_model_pack(cls, model_pack_path: str) -> 'NerModel':
+    def load_model_pack(cls, model_pack_path: str,config: Optional[Dict] = None) -> 'NerModel':
         """Load NER model from model pack.
 
         The method first wraps the loaded CAT instance.
 
         Args:
+            config: Config for DeId model pack (primarily for stride of overlap window)
             model_pack_path (str): The model pack path.
 
         Returns:
             NerModel: The resulting DeI model.
         """
-        cat = CAT.load_model_pack(model_pack_path)
+        cat = CAT.load_model_pack(model_pack_path,ner_config_dict=config)
         return cls(cat)
diff --git a/medcat/utils/preprocess_snomed.py b/medcat/utils/preprocess_snomed.py
index 3ba94b977..1e6efcb79 100644
--- a/medcat/utils/preprocess_snomed.py
+++ b/medcat/utils/preprocess_snomed.py
@@ -35,6 +35,32 @@ def get_all_children(sctid, pt2ch):
     return result
 
 
+def get_direct_refset_mapping(in_dict: dict) -> dict:
+    """This method uses the output from Snomed.map_snomed2icd10 or
+    Snomed.map_snomed2opcs4 and removes the metadata and maps each
+    SNOMED CUI to the prioritised list of the target ontology CUIs.
+
+    The input dict is expected to be in the following format:
+    - Keys are SnomedCT CUIs
+    - The values are lists of dictionaries, each list item (at least)
+      - Has a key 'code' that specifies the target ontology CUI
+      - Has a key 'mapPriority' that specifies the priority
+
+    Args:
+        in_dict (dict): The input dict.
+
+    Returns:
+        dict: The map from Snomed CUI to prioritised list of target ontology CUIs.
+    """
+    ret_dict = dict()
+    for k, vals in in_dict.items():
+        # sort such that highest priority values are first
+        svals = sorted(vals, key=lambda el: el['mapPriority'], reverse=True)
+        # only keep the code / CUI
+        ret_dict[k] = [v['code'] for v in svals]
+    return ret_dict
+
+
 class Snomed:
     """
     Pre-process SNOMED CT release files.
@@ -53,6 +79,15 @@ def __init__(self, data_path, uk_ext=False, uk_drug_ext=False):
         self.release = data_path[-16:-8]
         self.uk_ext = uk_ext
         self.uk_drug_ext = uk_drug_ext
+        self.opcs_refset_id = "1126441000000105"
+        if ((self.uk_ext or self.uk_drug_ext) and
+                # using lexicographical comparison below
+                # e.g "20240101" > "20231122" results in True
+                # yet "20231121" > "20231122" results in False
+                len(self.release) == len("20231122") and self.release >= "20231122"):
+            # NOTE for UK extensions starting from 20231122 the
+            #      OPCS4 refset ID seems to be different
+            self.opcs_refset_id = '1382401000000109'
 
     def to_concept_df(self):
         """
@@ -398,7 +433,7 @@ def _map_snomed2refset(self):
         mapping_df = pd.concat(dfs2merge)
         del dfs2merge
         if self.uk_ext or self.uk_drug_ext:
-            opcs_df = mapping_df[mapping_df['refsetId'] == '1126441000000105']
+            opcs_df = mapping_df[mapping_df['refsetId'] == self.opcs_refset_id]
             icd10_df = mapping_df[mapping_df['refsetId']
                                   == '999002271000000101']
             return icd10_df, opcs_df
diff --git a/tests/utils/ner/test_deid.py b/tests/utils/ner/test_deid.py
index 01c9c1af3..0eed7b6da 100644
--- a/tests/utils/ner/test_deid.py
+++ b/tests/utils/ner/test_deid.py
@@ -41,11 +41,11 @@ def test_can_create_model(self):
         deid_model = deid.DeIdModel.create(ner)
         self.assertIsNotNone(deid_model)
 
-
 def _add_model(cls):
     cdb = make_or_update_cdb(TRAIN_DATA)
     config = transformers_ner.ConfigTransformersNER()
     config.general['test_size'] = 0.1  # Usually set this to 0.1-0.2
+    config.general['chunking_overlap_window'] = None
     cls.ner = transformers_ner.TransformersNER(cdb=cdb, config=config)
     cls.ner.training_arguments.num_train_epochs = 1  # Use 5-10 normally
     # As we are NOT training on a GPU that can, we'll set it to 1
diff --git a/tests/utils/test_preprocess_snomed.py b/tests/utils/test_preprocess_snomed.py
new file mode 100644
index 000000000..59a00f6fc
--- /dev/null
+++ b/tests/utils/test_preprocess_snomed.py
@@ -0,0 +1,64 @@
+from typing import Dict
+from medcat.utils import preprocess_snomed
+
+import unittest
+
+
+EXAMPLE_REFSET_DICT: Dict = {
+   'SCUI1': [
+       {'code': 'TCUI1', 'mapPriority': '1'},
+       {'code': 'TCUI2', 'mapPriority': '2'},
+       {'code': 'TCUI3', 'mapPriority': '3'},
+       ]
+}
+
+# in order from highest priority to lowest
+EXPECTED_DIRECT_MAPPINGS = {"SCUI1": ['TCUI3', 'TCUI2', 'TCUI1']}
+
+EXAMPLE_REFSET_DICT_WITH_EXTRAS = dict(
+    (k, [dict(v, otherKey=f"val-{k}") for v in vals]) for k, vals in EXAMPLE_REFSET_DICT.items())
+
+EXAMPLE_REFSET_DICT_NO_PRIORITY = dict(
+    (k, [{ik: iv for ik, iv in v.items() if ik != 'mapPriority'} for v in vals]) for k, vals in EXAMPLE_REFSET_DICT.items()
+)
+
+EXAMPLE_REFSET_DICT_NO_CODE = dict(
+    (k, [{ik: iv for ik, iv in v.items() if ik != 'code'} for v in vals]) for k, vals in EXAMPLE_REFSET_DICT.items()
+)
+
+
+class DirectMappingTest(unittest.TestCase):
+
+    def test_example_gets_direct_mappings(self):
+        res = preprocess_snomed.get_direct_refset_mapping(EXAMPLE_REFSET_DICT)
+        self.assertEqual(res, EXPECTED_DIRECT_MAPPINGS)
+
+    def test_example_w_extras_gets_direct_mappings(self):
+        res = preprocess_snomed.get_direct_refset_mapping(EXAMPLE_REFSET_DICT_WITH_EXTRAS)
+        self.assertEqual(res, EXPECTED_DIRECT_MAPPINGS)
+
+    def test_example_no_priority_fails(self):
+        with self.assertRaises(KeyError):
+            preprocess_snomed.get_direct_refset_mapping(EXAMPLE_REFSET_DICT_NO_PRIORITY)
+
+    def test_example_no_code_fails(self):
+        with self.assertRaises(KeyError):
+            preprocess_snomed.get_direct_refset_mapping(EXAMPLE_REFSET_DICT_NO_CODE)
+
+EXAMPLE_SNOMED_PATH_OLD = "SnomedCT_InternationalRF2_PRODUCTION_20220831T120000Z"
+EXAMPLE_SNOMED_PATH_NEW = "SnomedCT_UKClinicalRF2_PRODUCTION_20231122T000001Z"
+
+
+class TestSnomedVersionsOPCS4(unittest.TestCase):
+
+    def test_old_gets_old_OPCS4_mapping_nonuk_ext(self):
+        snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_OLD, uk_ext=False)
+        self.assertEqual(snomed.opcs_refset_id, "1126441000000105")
+
+    def test_old_gets_old_OPCS4_mapping_uk_ext(self):
+        snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_OLD, uk_ext=True)
+        self.assertEqual(snomed.opcs_refset_id, "1126441000000105")
+
+    def test_new_gets_new_OPCS4_mapping_uk_ext(self):
+        snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_NEW, uk_ext=True)
+        self.assertEqual(snomed.opcs_refset_id, "1382401000000109")

From ffd22a620281b649653e7e481c51e662b495c563 Mon Sep 17 00:00:00 2001
From: Mart Ratas 
Date: Wed, 19 Jun 2024 17:01:16 +0100
Subject: [PATCH 6/9] v1.12.0 release PR (#455)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Pushing changes for bert-style models for MetaCAT

* Pushing fix for LSTM

* Pushing changes for flake8 and type fixes

* Pushing type fixes

* Fixing type issue

* Pushing changes

1) Added model.zero_grad to clear accumulated gradients
2) Fixed config save issue
3) Re-structured data preparation for oversampled data

* Pushing change and type fixes

Pushing ml_utils file which was missed in the last commit

* Fixing flake8 issues

* Pushing flake8 fixes

* Pushing fixes for flake8

* Pushing flake8 fix

* Adding peft to list of libraries

* Pushing changes with load and train workflow and type fixes

The workflow for inference is: load() and inference
For training: init() and train()
Train will never load the model dict, except when phase_number is set to 2 for the second phase of 2-phase learning

* Pushing changes with type hints and new documentation

* Pushing type fix

* Fixing type issue

* Adding test case for BERT and reverting config changes

BERT test cases: Testing for BERT model along with 2 phase learning

* Merging changes from master to metacat_bert branch (#431)

* Small addition to contribution guidelines (#420)

* CU-8694cbcpu: Allow specifying an AU Snomed when preprocessing (#421)

* CU-8694dpy1c: Return empty generator upon empty stream (#423)

* CU-8694dpy1c: Return empty generator upon empty stream

* CU-8694dpy1c: Fix empty generator returns

* CU-8694dpy1c: Simplify empty generator returns

* Relation extraction (#173)

* Added files.

* More additions to rel extraction.

* Rel base.

* Update.

* Updates.

* Dependency parsing.

* Updates.

* Added pre-training steps.

* Added training & model utils.

* Cleanup & fixes.

* Update.

* Evaluation updates for pretraining.

* Removed duplicate relation storage.

* Moved RE model file location.

* Structure revisions.

* Added custom config for RE.

* Implemented custom dataset loader for RE.

* More changes.

* Small fix.

* Latest additions to RelCAT (pipe + predictions)

* Setup.py fix.

* RE utils update.

* rel model update.

* rel dataset + tokenizer improvements.

* RelCAT updates.

* RelCAT saving/loading improvements.

* RelCAT saving/loading improvements.

* RelCAT model fixes.

* Attempted gpu learning fix. Dataset label generation fixes.

* Minor train dataset gen fix.

* Minor train dataset gen fix No.2.

* Config updates.

* Gpu support fixes. Added label stats.

* Evaluation stat fixes.

* Cleaned stat output mode during training.

* Build fix.

* removed unused dependencies and fixed code formatting

* Mypy compliance.

* Fixed linting.

* More Gpu mode train fixes.

* Fixed model saving/loading issues when using other base models.

* More fixes to stat evaluation. Added proper CAT integration of RelCAT.

* Setup.py typo fix.

* RelCAT loading fix.

* RelCAT Config changes.

* Type fix. Minor additions to RelCAT model.

* Type fixes.

* Type corrections.

* RelCAT update.

* Type fixes.

* Fixed type issue.

* RelCATConfig: added seed param.

* Adaptations to the new codebase + type fixes..

* Doc/type fixes.

* Fixed input size issue for model.

* Fixed issue(s) with model size and config.

* RelCAT: updated configs to new style.

* RelCAT: removed old refs to logging.

* Fixed GPU training + added extra stat print for train set.

* Type fixes.

* Updated dev requirements.

* Linting.

* Fixed pin_memory issue when training on CPU.

* Updated RelCAT dataset get + default config.

* Updated RelDS generator + default config

* Linting.

* Updated RelDataset + config.

* Pushing updates to model

Made changes to:
1) Extracting given number of context tokens left and right of the entities
2) Extracting hidden state from bert for all the tokens of the entities and performing max pooling on them

* Fixing formatting

* Update rel_dataset.py

* Update rel_dataset.py

* Update rel_dataset.py

* RelCAT: added test resource files.

* RelCAT: Fixed model load/checkpointing.

* RelCAT: updated to pipe spacy doc call.

* RelCAT: added tests.

* Fixed lint/type issues & added rel tag to test DS.

* Fixed ann id to token issue.

* RelCAT: updated test dataset + tests.

* RelCAT: updates to requested changes + dataset improvements.

* RelCAT: updated docs/logs according to comments.

* RelCAT: type fix.

* RelCAT: mct export dataset updates.

* RelCAT: test updates + requested changes p2.

* RelCAT: log for MCT export train.

* Updated docs + split train_test & dataset for benchmarks.

* type fixes.

---------

Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com>
Co-authored-by: mart-r 

* CU-8694fae3r: Avoid publishing PyPI release when doing GH pre-releases (#424)

* CU-8694fae3r: Avoid publishing PyPI release when doing GH pre-releases

* CU-8694fae3r: Fix pre-releases tagging

* CU-8694fae3r: Allow actions to run on release edit

---------

Co-authored-by: Mart Ratas 
Co-authored-by: Vlad Dinu <62345326+vladd-bit@users.noreply.github.com>

* Pushing changed tests and removing empty change

* Pushing change for logging

* Revert "Pushing change for logging"

This reverts commit fbcdb704dddda2a36626c4f54ce0672ecfcf6321.

* CU-8694hukwm: Document the materialising of generator when multiproce… (#433)

* CU-8694hukwm: Document the materialising of generator when multiprocessing and batching for docs

* CU-8694hukwm: Add TODO note for where the generator is materialised

* CU-8694hukwm: Add warning from large amounts of generator data (10k items) is materialised by the docs size mp method

* CU-8694fk90t (almost) only primitive config (#425)

* CU-8694fk90r: Move backwards compatibility method from CDB to config utils

* CU-8694fk90r: Move weighted_average_function from config to CDB; create necessary backwards compatibility workarounds

* CU-8694fk90r: Move usage of weighted_average_function in tests

* CU-8694fk90r: Add JSON encode and decoder for re.Pattern

* CU-8694fk90r: Rebuild custom decoder if needed

* CU-8694fk90r: Add method to detect old style config

* CU-8694fk90r: Use regular json serialisation for config; Retain option to read old jsonpickled config

* CU-8694fk90r: Add test for config serialisation

* CU-8694fk90r: Make sure to fix weighted_average_function upon setting it

* CU-8694fk90t: Add missing tests for config utils

* CU-8694fk90t: Add tests for better raised exception upon old way of using weighted_average_function

* CU-8694fk90t: Fix exception type in an added test

* CU-8694fk90t: Add further tests for exception payload

* CU-8694fk90t: Add improved exceptions when using old/unsupported value of weighted_average_function in config

* CU-8694fk90t: Add typing fix exceptions

* CU-8694fk90t: Make custom exception derive from AttributeError to correctly handle hasattr calls

* CU-8694gza88: Create codeql.yml (#434)

Run CodeQL to identify vulnerabilities.
This will run on any push or pull request to `master`, but also runs once every day in case some new vulnerabilities are discovered (or something else changes).

* CU-8694mbn03: Remove the web app (#441)

* CU-8694n48uw better deprecation (#443)

* CU-8694n493m: Add deprecation and removal versions to deprecation decorator

* CU-8694n493m: Deprecation version to existing deprecated methods.

Made the removal version 2 minor versions from the minor version
in which the method was deprecated, or the next minor version if
the method had been deprecated for longer.

* CU-8694n4ff0: Raise exception upon deprecated method call at test time

* CU-8694n4ff0: Fix usage of deprecated methods call during test time

* CU-8694pey4u: extract cdb load to cls method, to be used in trainer for model pack loading

* CU-8694pey4u: extract meta cat loading also to a cls method

* CU-8694pey4u: docstrings

* CU-8694pey4u: typehints and mypy issues

* CU-8694pey4u: fix flake8

* CU-8694pey4u: fix flake8

* CU-8694pey4u: missing extra config if passed in

* CU-8694py1jr: Fix issue with reuse of opened file when loading old configs

* CU-8694py1jr: Make old config identifier more robust

* CU-8694py1jr: Add doc string to old config identifier

* CU-8694py1jr: Add test for old style MetaCAT config load

* CU-8694py1jr: Add test for old style main config load (functional)

* CU-8694py1jr: Refactor config utils load tests for more flexibility

* CU-8694py1jr: Add config utils load tests for NER and Rel CAT configs

* CU-8694vcvz7: Trust remote code when loading transformers NER dataset (#453)

* CU-8694vcvz7: Trust remote code when loading transformers NER dataset

* CU-8694vcvz7: Add support for older datasets without the remote code trusting kwarg

* CU-8694gzbn3 k fold metrics (#432)

* CU-8694gzbud: Add context manager that is able to snapshot CDB state

* CU-8694gzbud: Add tests to snapshotting CDB state

* CU-8694gzbud: Refactor tests for CDB state snapshotting

* CU-8694gzbud: Remove use of deprecated method in CDB utils and use non-deprecated one instead

* CU-8694gzbud: Add tests for training and CDB state capturing

* CU-8694gzbud: Small refactor in tests

* CU-8694gzbud: Add option to save state on disk

* CU-8694gzbud: Add debug logging output when saving state on disk

* CU-8694gzbud: Remove unused import

* CU-8694gzbud: Add tests for disk-based state save

* CU-8694gzbud: Move CDB state code to its own module

* CU-8694gzbud: Remove unused import

* CU-8694gzbud: Add doc strings to methods

* CU-8694gzbx4: Small optimisation for stats

* CU-8694gzbx4: Add MCTExport related module

* CU-8694gzbx4: Add MCTExport related tests

* CU-8694gzbx4: Add code for k-fold statistics

* CU-8694gzbx4: Add tests for k-fold statistics

* CU-8694gzbx4: Add test-MCT export with fake concepts

* CU-8694gzbx4: Fix a doc string

* CU-8694gzbx4: Fix types in MCT export module

* CU-8694gzbx4: Fix types in k-fold module

* CU-8694gzbx4: Remove accidentally committed test class

* CU-8694gzbn3: Add missing test helper file

* CU-8694gzbn3: Remove whitespace change from otherwise unchanged file

* CU-8694gzbn3: Allow 5 minutes longer for tests

* CU-8694gzbn3: Move to python 3.8-compatible typed dict

* CU-8694gzbn3: Add more time for tests in workflow (now 30 minutes)

* CU-8694gzbn3: Add more time for tests in workflow (now 45 minutes)

* CU-8694gzbn3: Update test-pypi timeout to 45 minutes

* CU-8694gzbn3: Remove timeout from unit tests in main workflow

* CU-8694gzbn3: Make tests stop upon first failure

* CU-8694gzbn3: Fix test stop upon first failure (arg/option order)

* CU-8694gzbn3: Remove debug code and old comments

* CU-8694gzbn3: Remove all timeouts from main workflow

* CU-8694gzbn3: Remove more old / useless comments in tests

* CU-8694gzbn3: Add debug output when running k-fold tests to see where it may be stalling

* CU-8694gzbn3: Add debug output when ANY tests to see where it may be stalling

* CU-8694gzbn3: Remove explicit debug output from k-fold test cases

* CU-8694gzbn3: Remove timeouts from DEID tests in case they're the ones creating issues

* GHA/test fixes (#437)

* Revert "CU-8694gzbn3: Remove timeouts from DEID tests in case they're the ones creating issues"

This reverts commit faaf7fb1c4b8b1a9c81ac6b81a464fdd2b55afdc.

* Revert "CU-8694gzbn3: Remove explicit debug output from k-fold test cases"

This reverts commit 9b0292517c3f442a57fdc04593f721309f0502f7.

* Revert "CU-8694gzbn3: Add debug output when ANY tests to see where it may be stalling"

This reverts commit 12c519aceb25b960f4b9ff3d3298d18aff8030b4.

* Revert "CU-8694gzbn3: Add debug output when running k-fold tests to see where it may be stalling"

This reverts commit 03531da0288eb3949807f6720c169625c4d3097c.

* Revert "CU-8694gzbn3: Remove all timeouts from main workflow"

This reverts commit e6debce71e053ac825f6c8191e5d0e454a77bb11.

* Revert "CU-8694gzbn3: Fix test stop upon first failure (arg/option order)"

This reverts commit 666c0139f48b7c1565bd250bf0b780069005623f.

* Revert "CU-8694gzbn3: Make tests stop upon first failure"

This reverts commit 94bce5650d967b0273e1546326108e50f611f687.

* Revert "CU-8694gzbn3: Remove timeout from unit tests in main workflow"

This reverts commit 3618b9c7cc5b755f430debf129727fab235d79ce.

* CU-8694gzbn3: Improve state copy code in CDB state tests

* CU-8694gzbn3: Fix a CDB state test issue

* CU-8694gzbn3: Split all tests into 2 halves

* CU-8694gzbn3: Remove legacy / archived / unused tests

* CU-8694gzbn3: Add doc strings for FoldCreator init

* CU-8694gzbn3: Move to a split-type enum

* CU-8694gzbn3: Add documentation to split-type enum

* CU-8694gzbn3: Create separate fold creators for different types of splitting strategies

* CU-8694gzbn3: Resort document order in test time nullification process

* CU-8694gzbn3: Add option to count number of annotations in doc for MCT export

* CU-8694gzbn3: Add weighted documents based split option along with relevant tests

* CU-8694gzbn3: Update default fold creation split type to weighted documents

* CU-8694gzbn3: Add test to ensure weighted documents split creates a reasonable number of annotations per split

* CU-8693n892x environment/dependency snapshots (#438)

* CU-8693n892x: Save environment/dependency snapshot upon model pack creation

* CU-8693n892x: Fix typing for env snapshot module

* CU-8693n892x: Add test for env file existence in .zip

* CU-8693n892x: Add doc strings

* CU-8693n892x: Centralise env snapshot file name

* CU-8693n892x: Add env snapshot file to exceptions in serialisation tests

* CU-8693n892x: Only list direct dependencies

* CU-8693n892x: Add test that verifies all direct dependencies are listed in environment

* CU-8693n892x: Move requirements to separate file and use that for environment snapshot

* CU-8693n892x: Remove unused constants

* CU-8693n892x: Allow URL based dependencies when using direct dependencies

* CU-8693n892x: Distribute install_requires.txt alongside the package; use correct path in distributed version

* CU-8694p8y0k deprecation GHA check (#445)

* CU-8694p8y0k: Add check for deprecations (code)

* CU-8694p8y0k: Add workflow check for deprecations

* CU-8694p8y0k: Fix (hopefully) workflow check for deprecations

* CU-8694p8y0k: Add option to remove version prefix when checking deprecation

* CU-8694p8y0k: Update deprecation checks with more detail (i.e current/next version).

* CU-8694p8y0k: Only run deprecation checking step when merging master into production

* CU-8694u3yd2 cleanup name removal (#450)

* CU-8694u3yd2: Add logged warning for when using full-unlink

* CU-8694u3yd2: Make CDB.remove_names simply expect an iterable of names

* CU-8694u3yd2: Improve CDB.remove_names doc string

* CU-8694u3yd2: Explicitly pass the keys to CDB.remove_names in CAT.unlink_concept_name

* CU-8694u3yd2: Add note regarding state (and order) dependent tests to some CDB maker tests

* CU-8694u3yd2: Rename/make protected CDB.remove_names method

* CU-8694u3yd2: Create deprecated CDB.remove_names method

* CU-8694vte2g 1.12 depr removal (#454)

* CU-8694vte2g: Remove CDB.add_concept method

* CU-8694vte2g: Remove unused import (deprecated decorator)

* CU-8694vte2g: Remove CAT.get_spacy_nlp method

* CU-8694vte2g: Remove CAT.train_supervised method

* CU-8694vte2g: Remove CAT multiprocessing methods

* CU-8694vte2g: Remove MetaCAT.train method

* CU-8694vte2g: Remove medcat.utils.ner.helper.deid_text method

* CU-8694vte2g: Remove use of deprecated method

* CU-8694vte2g: Add back removed deprecation import

---------

Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com>
Co-authored-by: Vlad Dinu <62345326+vladd-bit@users.noreply.github.com>
Co-authored-by: Tom Searle 
---
 .github/workflows/codeql.yml                  |  95 ++++
 .github/workflows/main.yml                    |  29 +-
 install_requires.txt                          |  24 +
 medcat/cat.py                                 | 173 ++++---
 medcat/cdb.py                                 |  91 ++--
 medcat/config.py                              |  57 ++-
 medcat/config_meta_cat.py                     |  19 +-
 medcat/linking/vector_context_model.py        |   4 +-
 medcat/meta_cat.py                            | 135 ++++--
 medcat/ner/transformers_ner.py                |  29 +-
 medcat/stats/kfold.py                         | 436 ++++++++++++++++++
 medcat/stats/mctexport.py                     |  66 +++
 medcat/stats/stats.py                         |   6 +-
 medcat/tokenizers/meta_cat_tokenizers.py      |  13 +-
 medcat/utils/cdb_state.py                     | 179 +++++++
 medcat/utils/cdb_utils.py                     |   2 +-
 medcat/utils/config_utils.py                  |  66 ++-
 medcat/utils/decorators.py                    |  20 +-
 medcat/utils/meta_cat/data_utils.py           | 169 +++++--
 medcat/utils/meta_cat/ml_utils.py             | 206 +++++++--
 medcat/utils/meta_cat/models.py               | 174 ++++---
 medcat/utils/ner/__init__.py                  |   2 +-
 medcat/utils/ner/deid.py                      |  12 +-
 medcat/utils/ner/helpers.py                   |  36 --
 medcat/utils/saving/coding.py                 |  32 +-
 medcat/utils/saving/envsnapshot.py            |  73 +++
 setup.py                                      |  38 +-
 tests/__init__.py                             |  25 +
 tests/archive_tests/test_cdb_maker_archive.py | 124 -----
 tests/archive_tests/test_ner_archive.py       | 139 ------
 tests/check_deprecations.py                   | 178 +++++++
 tests/medmentions/make_cdb.py                 | 120 -----
 tests/medmentions/prepare_data.py             |   7 -
 tests/resources/jsonpickle_config.json        | 274 +++++++++++
 .../resources/jsonpickle_meta_cat_config.json |  89 ++++
 .../resources/jsonpickle_rel_cat_config.json  |  91 ++++
 tests/resources/jsonpickle_tner_config.json   |  23 +
 .../medcat_trainer_export_FAKE_CONCEPTS.json  |  84 ++++
 .../webapp/demo => tests/stats}/__init__.py   |   0
 tests/stats/helpers.py                        |  17 +
 tests/stats/test_kfold.py                     | 298 ++++++++++++
 tests/stats/test_mctexport.py                 |  38 ++
 tests/test_cat.py                             |  28 +-
 tests/test_cdb_maker.py                       |  20 +-
 tests/test_config.py                          |  39 ++
 tests/test_meta_cat.py                        |  50 +-
 tests/utils/saving/test_envsnapshot.py        | 105 +++++
 tests/utils/saving/test_serialization.py      |   8 +-
 tests/utils/test_cdb_state.py                 | 113 +++++
 tests/utils/test_config_utils.py              | 121 +++++
 webapp/.gitignore                             |   6 -
 webapp/README.md                              |   1 -
 webapp/docker-compose.yml                     |  26 --
 webapp/envs/env_db_backup                     |   8 -
 webapp/envs/env_medmen                        |   1 -
 webapp/webapp/.dockerignore                   |   2 -
 webapp/webapp/Dockerfile                      |  37 --
 webapp/webapp/data/.keep                      |   0
 webapp/webapp/db/.keep                        |   0
 webapp/webapp/demo/admin.py                   |  16 -
 webapp/webapp/demo/apps.py                    |   5 -
 webapp/webapp/demo/db_backup.py               |  20 -
 webapp/webapp/demo/forms.py                   |  48 --
 webapp/webapp/demo/migrations/0001_initial.py |  22 -
 .../migrations/0002_downloader_medcatmodel.py |  38 --
 webapp/webapp/demo/migrations/__init__.py     |   0
 webapp/webapp/demo/models.py                  |  31 --
 webapp/webapp/demo/static/css/annotations.css | 110 -----
 webapp/webapp/demo/static/css/base.css        |  86 ----
 webapp/webapp/demo/static/css/home.css        |  23 -
 webapp/webapp/demo/static/image/favicon.ico   | Bin 4641 -> 0 bytes
 webapp/webapp/demo/static/js/.keep            |   0
 webapp/webapp/demo/static/js/anns.js          |  95 ----
 webapp/webapp/demo/templates/base.html        |  33 --
 .../demo/templates/train_annotations.html     | 147 ------
 .../demo/templates/umls_user_validation.html  |  67 ---
 webapp/webapp/demo/tests.py                   |   3 -
 webapp/webapp/demo/urls.py                    |   9 -
 webapp/webapp/demo/views.py                   | 129 ------
 webapp/webapp/etc/cron.d/db-backup-cron       |   1 -
 webapp/webapp/manage.py                       |  21 -
 webapp/webapp/models/.keep                    |   0
 webapp/webapp/requirements.txt                |   6 -
 webapp/webapp/webapp/__init__.py              |   0
 webapp/webapp/webapp/settings.py              | 146 ------
 webapp/webapp/webapp/urls.py                  |  26 --
 webapp/webapp/webapp/wsgi.py                  |  16 -
 87 files changed, 3316 insertions(+), 2040 deletions(-)
 create mode 100644 .github/workflows/codeql.yml
 create mode 100644 install_requires.txt
 create mode 100644 medcat/stats/kfold.py
 create mode 100644 medcat/stats/mctexport.py
 create mode 100644 medcat/utils/cdb_state.py
 create mode 100644 medcat/utils/saving/envsnapshot.py
 delete mode 100644 tests/archive_tests/test_cdb_maker_archive.py
 delete mode 100644 tests/archive_tests/test_ner_archive.py
 create mode 100644 tests/check_deprecations.py
 delete mode 100644 tests/medmentions/make_cdb.py
 delete mode 100644 tests/medmentions/prepare_data.py
 create mode 100644 tests/resources/jsonpickle_config.json
 create mode 100644 tests/resources/jsonpickle_meta_cat_config.json
 create mode 100644 tests/resources/jsonpickle_rel_cat_config.json
 create mode 100644 tests/resources/jsonpickle_tner_config.json
 create mode 100644 tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json
 rename {webapp/webapp/demo => tests/stats}/__init__.py (100%)
 create mode 100644 tests/stats/helpers.py
 create mode 100644 tests/stats/test_kfold.py
 create mode 100644 tests/stats/test_mctexport.py
 create mode 100644 tests/utils/saving/test_envsnapshot.py
 create mode 100644 tests/utils/test_cdb_state.py
 create mode 100644 tests/utils/test_config_utils.py
 delete mode 100644 webapp/.gitignore
 delete mode 100644 webapp/README.md
 delete mode 100644 webapp/docker-compose.yml
 delete mode 100644 webapp/envs/env_db_backup
 delete mode 100644 webapp/envs/env_medmen
 delete mode 100644 webapp/webapp/.dockerignore
 delete mode 100644 webapp/webapp/Dockerfile
 delete mode 100644 webapp/webapp/data/.keep
 delete mode 100644 webapp/webapp/db/.keep
 delete mode 100644 webapp/webapp/demo/admin.py
 delete mode 100644 webapp/webapp/demo/apps.py
 delete mode 100644 webapp/webapp/demo/db_backup.py
 delete mode 100644 webapp/webapp/demo/forms.py
 delete mode 100644 webapp/webapp/demo/migrations/0001_initial.py
 delete mode 100644 webapp/webapp/demo/migrations/0002_downloader_medcatmodel.py
 delete mode 100644 webapp/webapp/demo/migrations/__init__.py
 delete mode 100644 webapp/webapp/demo/models.py
 delete mode 100644 webapp/webapp/demo/static/css/annotations.css
 delete mode 100644 webapp/webapp/demo/static/css/base.css
 delete mode 100644 webapp/webapp/demo/static/css/home.css
 delete mode 100644 webapp/webapp/demo/static/image/favicon.ico
 delete mode 100644 webapp/webapp/demo/static/js/.keep
 delete mode 100644 webapp/webapp/demo/static/js/anns.js
 delete mode 100644 webapp/webapp/demo/templates/base.html
 delete mode 100644 webapp/webapp/demo/templates/train_annotations.html
 delete mode 100644 webapp/webapp/demo/templates/umls_user_validation.html
 delete mode 100644 webapp/webapp/demo/tests.py
 delete mode 100644 webapp/webapp/demo/urls.py
 delete mode 100644 webapp/webapp/demo/views.py
 delete mode 100644 webapp/webapp/etc/cron.d/db-backup-cron
 delete mode 100755 webapp/webapp/manage.py
 delete mode 100644 webapp/webapp/models/.keep
 delete mode 100644 webapp/webapp/requirements.txt
 delete mode 100644 webapp/webapp/webapp/__init__.py
 delete mode 100644 webapp/webapp/webapp/settings.py
 delete mode 100644 webapp/webapp/webapp/urls.py
 delete mode 100644 webapp/webapp/webapp/wsgi.py

diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
new file mode 100644
index 000000000..9984edc16
--- /dev/null
+++ b/.github/workflows/codeql.yml
@@ -0,0 +1,95 @@
+# For most projects, this workflow file will not need changing; you simply need
+# to commit it to your repository.
+#
+# You may wish to alter this file to override the set of languages analyzed,
+# or to provide custom queries or build logic.
+#
+# ******** NOTE ********
+# We have attempted to detect the languages in your repository. Please check
+# the `language` matrix defined below to confirm you have the correct set of
+# supported CodeQL languages.
+#
+name: "CodeQL"
+
+on:
+  push:
+    branches: [ "master" ]
+  pull_request:
+    branches: [ "master" ]
+  schedule:
+    - cron: '36 14 * * 0'
+
+jobs:
+  analyze:
+    name: Analyze (${{ matrix.language }})
+    # Runner size impacts CodeQL analysis time. To learn more, please see:
+    #   - https://gh.io/recommended-hardware-resources-for-running-codeql
+    #   - https://gh.io/supported-runners-and-hardware-resources
+    #   - https://gh.io/using-larger-runners (GitHub.com only)
+    # Consider using larger runners or machines with greater resources for possible analysis time improvements.
+    runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
+    timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }}
+    permissions:
+      # required for all workflows
+      security-events: write
+
+      # required to fetch internal or private CodeQL packs
+      packages: read
+
+      # only required for workflows in private repositories
+      actions: read
+      contents: read
+
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+        - language: javascript-typescript
+          build-mode: none
+        - language: python
+          build-mode: none
+        # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
+        # Use `c-cpp` to analyze code written in C, C++ or both
+        # Use 'java-kotlin' to analyze code written in Java, Kotlin or both
+        # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
+        # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
+        # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
+        # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
+        # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
+    steps:
+    - name: Checkout repository
+      uses: actions/checkout@v4
+
+    # Initializes the CodeQL tools for scanning.
+    - name: Initialize CodeQL
+      uses: github/codeql-action/init@v3
+      with:
+        languages: ${{ matrix.language }}
+        build-mode: ${{ matrix.build-mode }}
+        # If you wish to specify custom queries, you can do so here or in a config file.
+        # By default, queries listed here will override any specified in a config file.
+        # Prefix the list here with "+" to use these queries and those in the config file.
+
+        # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
+        # queries: security-extended,security-and-quality
+
+    # If the analyze step fails for one of the languages you are analyzing with
+    # "We were unable to automatically build your code", modify the matrix above
+    # to set the build mode to "manual" for that language. Then modify this step
+    # to build your code.
+    # ℹ️ Command-line programs to run using the OS shell.
+    # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
+    - if: matrix.build-mode == 'manual'
+      shell: bash
+      run: |
+        echo 'If you are using a "manual" build mode for one or more of the' \
+          'languages you are analyzing, replace this with the commands to build' \
+          'your code, for example:'
+        echo '  make bootstrap'
+        echo '  make release'
+        exit 1
+
+    - name: Perform CodeQL Analysis
+      uses: github/codeql-action/analyze@v3
+      with:
+        category: "/language:${{matrix.language}}"
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 7c7a2b742..d446160c9 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -33,7 +33,32 @@ jobs:
           flake8 medcat
       - name: Test
         run: |
-          timeout 17m python -m unittest discover
+          all_files=$(git ls-files | grep '^tests/.*\.py$' | grep -v '/__init__\.py$' | sed 's/\.py$//' | sed 's/\//./g')
+          num_files=$(echo "$all_files" | wc -l)
+          midpoint=$((num_files / 2))
+          first_half_nl=$(echo "$all_files" | head -n $midpoint)
+          second_half_nl=$(echo "$all_files" | tail -n +$(($midpoint + 1)))
+          timeout 25m python -m unittest ${first_half_nl[@]}
+          timeout 25m python -m unittest ${second_half_nl[@]}
+
+      - name: Get the latest release version
+        id: get_latest_release
+        uses: actions/github-script@v6
+        with:
+          script: |
+            const latestRelease = await github.rest.repos.getLatestRelease({
+              owner: context.repo.owner,
+              repo: context.repo.repo
+            });
+            core.setOutput('latest_version', latestRelease.data.tag_name);
+
+      - name: Make sure there's no deprecated methods that should be removed.
+        # only run this for production -> main PRs, i.e. just before doing a release.
+        if: github.event.pull_request.base.ref == 'main' && github.event.pull_request.head.ref == 'production'
+        env:
+          VERSION: ${{ steps.get_latest_release.outputs.latest_version }}
+        run: |
+          python tests/check_deprecations.py "$VERSION" --next-version --remove-prefix
 
   publish-to-test-pypi:
 
@@ -43,7 +68,7 @@ jobs:
       github.event_name == 'push' &&
       startsWith(github.ref, 'refs/tags') != true
     runs-on: ubuntu-20.04
-    timeout-minutes: 20
+    timeout-minutes: 45
     concurrency: publish-to-test-pypi
     needs: [build]
 
diff --git a/install_requires.txt b/install_requires.txt
new file mode 100644
index 000000000..da26267aa
--- /dev/null
+++ b/install_requires.txt
@@ -0,0 +1,24 @@
+'numpy>=1.22.0,<1.26.0'  # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy
+'pandas>=1.4.2' # first to support 3.11
+'gensim>=4.3.0,<5.0.0'  # 5.3.0 is first to support 3.11; avoid major version bump
+'spacy>=3.6.0,<4.0.0'  # Some later model packs (e.g HPO) are made with 3.6.0 spacy model; avoid major version bump
+'scipy~=1.9.2'  # 1.9.2 is first to support 3.11
+'transformers>=4.34.0,<5.0.0'  # avoid major version bump
+'accelerate>=0.23.0' # required by Trainer class in de-id
+'torch>=1.13.0,<3.0.0' # 1.13 is first to support 3.11; 2.1.2 has been compatible, but avoid major 3.0.0 for now
+'tqdm>=4.27'
+'scikit-learn>=1.1.3,<2.0.0'  # 1.1.3 is first to support 3.11; avoid major version bump
+'dill>=0.3.6,<1.0.0' # stuff saved in 0.3.6/0.3.7 is not always compatible with 0.3.4/0.3.5; avoid major bump
+'datasets>=2.2.2,<3.0.0' # avoid major bump
+'jsonpickle>=2.0.0' # allow later versions, tested with 3.0.0
+'psutil>=5.8.0'
+# 0.70.12 uses older version of dill (i.e less than 0.3.5) which is required for datasets
+'multiprocess~=0.70.12'  # 0.70.14 seemed to work just fine
+'aiofiles>=0.8.0' # allow later versions, tested with 22.1.0
+'ipywidgets>=7.6.5' # allow later versions, tested with 0.8.0
+'xxhash>=3.0.0' # allow later versions, tested with 3.1.0
+'blis>=0.7.5' # allow later versions, tested with 0.7.9
+'click>=8.0.4' # allow later versions, tested with 8.1.3
+'pydantic>=1.10.0,<2.0' # for spacy compatibility; avoid 2.0 due to breaking changes
+"humanfriendly~=10.0"  # for human readable file / RAM sizes
+"peft>=0.8.2"
\ No newline at end of file
diff --git a/medcat/cat.py b/medcat/cat.py
index 8df7526b7..2d83ccec5 100644
--- a/medcat/cat.py
+++ b/medcat/cat.py
@@ -16,7 +16,6 @@
 from datetime import date
 from tqdm.autonotebook import tqdm, trange
 from spacy.tokens import Span, Doc, Token
-from spacy.language import Language
 import humanfriendly
 
 from medcat import __version__
@@ -37,9 +36,9 @@
 from medcat.utils.meta_cat.data_utils import json_to_fake_spacy
 from medcat.config import Config
 from medcat.vocab import Vocab
-from medcat.utils.decorators import deprecated
 from medcat.ner.transformers_ner import TransformersNER
 from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY
+from medcat.utils.saving.envsnapshot import get_environment_info, ENV_SNAPSHOT_FILE_NAME
 from medcat.stats.stats import get_stats
 from medcat.utils.filters import set_project_filters
 
@@ -49,6 +48,8 @@
 
 HAS_NEW_SPACY = has_new_spacy()
 
+MIN_GEN_LEN_FOR_WARN = 10_000
+
 
 class CAT(object):
     """The main MedCAT class used to annotate documents, it is built on top of spaCy
@@ -144,15 +145,6 @@ def _create_pipeline(self, config: Config):
         # Set max document length
         self.pipe.spacy_nlp.max_length = config.preprocessing.max_document_length
 
-    @deprecated(message="Replaced with cat.pipe.spacy_nlp.")
-    def get_spacy_nlp(self) -> Language:
-        """Returns the spacy pipeline with MedCAT
-
-        Returns:
-            Language: The spacy Language being used.
-        """
-        return self.pipe.spacy_nlp
-
     def get_hash(self, force_recalc: bool = False) -> str:
         """Will not be a deep hash but will try to catch all the changing parts during training.
 
@@ -315,6 +307,12 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M
         with open(model_card_path, 'w') as f:
             json.dump(self.get_model_card(as_dict=True), f, indent=2)
 
+        # add a dependency snapshot
+        env_info = get_environment_info()
+        env_info_path = os.path.join(save_dir_path, ENV_SNAPSHOT_FILE_NAME)
+        with open(env_info_path, 'w') as f:
+            json.dump(env_info, f)
+
         # Zip everything
         shutil.make_archive(os.path.join(_save_dir_path, model_pack_name), 'zip', root_dir=save_dir_path)
 
@@ -387,12 +385,7 @@ def load_model_pack(cls,
         model_pack_path = cls.attempt_unpack(zip_path)
 
         # Load the CDB
-        cdb_path = os.path.join(model_pack_path, "cdb.dat")
-        nr_of_jsons_expected = len(SPECIALITY_NAMES) - len(ONE2MANY)
-        has_jsons = len(glob.glob(os.path.join(model_pack_path, '*.json'))) >= nr_of_jsons_expected
-        json_path = model_pack_path if has_jsons else None
-        logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format')
-        cdb = CDB.load(cdb_path, json_path)
+        cdb: CDB = cls.load_cdb(model_pack_path)
 
         # load config
         config_path = os.path.join(model_pack_path, "config.json")
@@ -419,11 +412,9 @@ def load_model_pack(cls,
             addl_ner.append(trf)
 
         # Find metacat models in the model_pack
-        meta_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
-        meta_cats = []
-        for meta_path in meta_paths:
-            meta_cats.append(MetaCAT.load(save_dir_path=meta_path,
-                                          config_dict=meta_cat_config_dict))
+        meta_cats: List[MetaCAT] = []
+        if load_meta_models:
+            meta_cats = [mc[1] for mc in cls.load_meta_cats(model_pack_path, meta_cat_config_dict)]
 
         # Find Rel models in model_pack
         rel_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('rel_')] if load_rel_models else []
@@ -436,6 +427,47 @@ def load_model_pack(cls,
 
         return cat
 
+    @classmethod
+    def load_cdb(cls, model_pack_path: str) -> CDB:
+        """
+        Loads the concept database from the provided model pack path
+
+        Args:
+            model_pack_path (str): path to model pack, zip or dir.
+
+        Returns:
+            CDB: The loaded concept database
+        """
+        cdb_path = os.path.join(model_pack_path, "cdb.dat")
+        nr_of_jsons_expected = len(SPECIALITY_NAMES) - len(ONE2MANY)
+        has_jsons = len(glob.glob(os.path.join(model_pack_path, '*.json'))) >= nr_of_jsons_expected
+        json_path = model_pack_path if has_jsons else None
+        logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format')
+        cdb = CDB.load(cdb_path, json_path)
+        return cdb
+
+    @classmethod
+    def load_meta_cats(cls, model_pack_path: str, meta_cat_config_dict: Optional[Dict] = None) -> List[Tuple[str, MetaCAT]]:
+        """
+        Load the MetaCAT models from the given model pack path.
+        Args:
+            model_pack_path (str): path to model pack, zip or dir.
+            meta_cat_config_dict (Optional[Dict]):
+                A config dict that will overwrite existing configs in meta_cat.
+                e.g. meta_cat_config_dict = {'general': {'device': 'cpu'}}.
+                Defaults to None.
+
+        Returns:
+            List[Tuple[str, MetaCAT]]: list of pairs of meta cat model names (i.e. the task name) and the MetaCAT models.
+        """
+        meta_paths = [os.path.join(model_pack_path, path)
+                      for path in os.listdir(model_pack_path) if path.startswith('meta_')]
+        meta_cats = []
+        for meta_path in meta_paths:
+            meta_cats.append(MetaCAT.load(save_dir_path=meta_path,
+                                          config_dict=meta_cat_config_dict))
+        return list(zip(meta_paths, meta_cats))
+
     def __call__(self, text: Optional[str], do_train: bool = False) -> Optional[Doc]:
         """Push the text through the pipeline.
 
@@ -645,13 +677,16 @@ def unlink_concept_name(self, cui: str, name: str, preprocessed_name: bool = Fal
             names = prepare_name(name, self.pipe.spacy_nlp, {}, self.config)
 
         # If full unlink find all CUIs
-        if self.config.general.get('full_unlink', False):
+        if self.config.general.full_unlink:
+            logger.warning("In the config `full_unlink` is set to `True`. "
+                           "Thus removing all CUIs linked to the specified name"
+                           " (%s)", name)
             for n in names:
                 cuis.extend(self.cdb.name2cuis.get(n, []))
 
         # Remove name from all CUIs
         for c in cuis:
-            self.cdb.remove_names(cui=c, names=names)
+            self.cdb._remove_names(cui=c, names=names.keys())
 
     def add_and_train_concept(self,
                               cui: str,
@@ -725,42 +760,6 @@ def add_and_train_concept(self,
                 for _cui in cuis:
                     self.linker.context_model.train(cui=_cui, entity=spacy_entity, doc=spacy_doc, negative=True)  # type: ignore
 
-    @deprecated(message="Use train_supervised_from_json to train based on data "
-                "loaded from a json file")
-    def train_supervised(self,
-                         data_path: str,
-                         reset_cui_count: bool = False,
-                         nepochs: int = 1,
-                         print_stats: int = 0,
-                         use_filters: bool = False,
-                         terminate_last: bool = False,
-                         use_overlaps: bool = False,
-                         use_cui_doc_limit: bool = False,
-                         test_size: int = 0,
-                         devalue_others: bool = False,
-                         use_groups: bool = False,
-                         never_terminate: bool = False,
-                         train_from_false_positives: bool = False,
-                         extra_cui_filter: Optional[Set] = None,
-                         retain_extra_cui_filter: bool = False,
-                         checkpoint: Optional[Checkpoint] = None,
-                         retain_filters: bool = False,
-                         is_resumed: bool = False) -> Tuple:
-        """Train supervised by reading data from a json file.
-
-        Refer to `train_supervvised_from_json` and/or `train_supervised_raw`
-        for further details.
-
-        # noqa: DAR101
-        # noqa: DAR201
-        """
-        return self.train_supervised_from_json(data_path, reset_cui_count, nepochs,
-                                               print_stats, use_filters, terminate_last,
-                                               use_overlaps, use_cui_doc_limit, test_size,
-                                               devalue_others, use_groups, never_terminate,
-                                               train_from_false_positives, extra_cui_filter,
-                                               retain_extra_cui_filter, checkpoint,
-                                               retain_filters, is_resumed)
 
     def train_supervised_from_json(self,
                                    data_path: str,
@@ -1226,25 +1225,6 @@ def _save_docs_to_file(self, docs: Iterable, annotated_ids: List[str], save_dir_
             pickle.dump((annotated_ids, part_counter), open(annotated_ids_path, 'wb'))
         return part_counter
 
-    @deprecated(message="Use `multiprocessing_batch_char_size` instead")
-    def multiprocessing(self,
-                        data: Union[List[Tuple], Iterable[Tuple]],
-                        nproc: int = 2,
-                        batch_size_chars: int = 5000 * 1000,
-                        only_cui: bool = False,
-                        addl_info: List[str] = ['cui2icd10', 'cui2ontologies', 'cui2snomed'],
-                        separate_nn_components: bool = True,
-                        out_split_size_chars: Optional[int] = None,
-                        save_dir_path: str = os.path.abspath(os.getcwd()),
-                        min_free_memory=0.1) -> Dict:
-        return self.multiprocessing_batch_char_size(data=data, nproc=nproc,
-                                                    batch_size_chars=batch_size_chars,
-                                                    only_cui=only_cui, addl_info=addl_info,
-                                                    separate_nn_components=separate_nn_components,
-                                                    out_split_size_chars=out_split_size_chars,
-                                                    save_dir_path=save_dir_path,
-                                                    min_free_memory=min_free_memory)
-
     def multiprocessing_batch_char_size(self,
                                         data: Union[List[Tuple], Iterable[Tuple]],
                                         nproc: int = 2,
@@ -1499,21 +1479,6 @@ def _multiprocessing_batch(self,
 
         return docs
 
-    @deprecated(message="Use `multiprocessing_batch_docs_size` instead")
-    def multiprocessing_pipe(self, in_data: Union[List[Tuple], Iterable[Tuple]],
-                             nproc: Optional[int] = None,
-                             batch_size: Optional[int] = None,
-                             only_cui: bool = False,
-                             addl_info: List[str] = [],
-                             return_dict: bool = True,
-                             batch_factor: int = 2) -> Union[List[Tuple], Dict]:
-        return self.multiprocessing_batch_docs_size(in_data=in_data, nproc=nproc,
-                                                    batch_size=batch_size,
-                                                    only_cui=only_cui,
-                                                    addl_info=addl_info,
-                                                    return_dict=return_dict,
-                                                    batch_factor=batch_factor)
-
     def multiprocessing_batch_docs_size(self,
                                         in_data: Union[List[Tuple], Iterable[Tuple]],
                                         nproc: Optional[int] = None,
@@ -1526,6 +1491,11 @@ def multiprocessing_batch_docs_size(self,
 
         This method batches the data based on the number of documents as specified by the user.
 
+        NOTE: When providing a generator for `data`, the generator is evaluated (`list(in_data)`)
+              and thus all the data is kept in memory and (potentially) duplicated for use in
+              multiple threads. So if you're using a lot of data, it may be better to use
+              `CAT.multiprocessing_batch_char_size` instead.
+
         PS:
         This method supports Windows.
 
@@ -1550,7 +1520,20 @@ def multiprocessing_batch_docs_size(self,
         if nproc == 0:
             raise ValueError("nproc cannot be set to zero")
 
-        in_data = list(in_data) if isinstance(in_data, Iterable) else in_data
+        # TODO: Surely there's a way to not materialise all of the incoming data in memory?
+        #       This is counter productive for allowing the passing of generators.
+        if isinstance(in_data, Iterable):
+            in_data = list(in_data)
+            in_data_len = len(in_data)
+            if in_data_len > MIN_GEN_LEN_FOR_WARN:
+                # only point this out when it's relevant, i.e over 10k items
+                logger.warning("The `CAT.multiprocessing_batch_docs_size` method just "
+                               f"materialised {in_data_len} items from the generator it "
+                               "was provided. This may use up a considerable amount of "
+                               "RAM, especially since the data may be duplicated across "
+                               "multiple threads when multiprocessing is used. If the "
+                               "process is killed after this warning, please use the "
+                               "alternative method `multiprocessing_batch_char_size` instead")
         n_process = nproc if nproc is not None else min(max(cpu_count() - 1, 1), math.ceil(len(in_data) / batch_factor))
         batch_size = batch_size if batch_size is not None else math.ceil(len(in_data) / (batch_factor * abs(n_process)))
 
diff --git a/medcat/cdb.py b/medcat/cdb.py
index 6ae15d3f5..e63843364 100644
--- a/medcat/cdb.py
+++ b/medcat/cdb.py
@@ -5,17 +5,20 @@
 import logging
 import aiofiles
 import numpy as np
-from typing import Dict, Set, Optional, List, Union, cast
-from functools import partial
+from typing import Dict, Set, Optional, List, Union, cast, Iterable
 import os
 
 from medcat import __version__
 from medcat.utils.hasher import Hasher
 from medcat.utils.matutils import unitvec
 from medcat.utils.ml_utils import get_lr_linking
+from medcat.config import Config, workers
 from medcat.utils.decorators import deprecated
-from medcat.config import Config, weighted_average, workers
 from medcat.utils.saving.serializer import CDBSerializer
+from medcat.utils.config_utils import get_and_del_weighted_average_from_config
+from medcat.utils.config_utils import default_weighted_average
+from medcat.utils.config_utils import ensure_backward_compatibility
+from medcat.utils.config_utils import fix_waf_lambda, attempt_fix_weighted_average_function
 
 
 logger = logging.getLogger(__name__)
@@ -98,6 +101,7 @@ def __init__(self, config: Union[Config, None] = None) -> None:
         self.vocab: Dict = {} # Vocabulary of all words ever in our cdb
         self._optim_params = None
         self.is_dirty = False
+        self._init_waf_from_config()
         self._hash: Optional[str] = None
         # the config hash is kept track of here so that
         # the CDB hash can be re-calculated when the config changes
@@ -107,6 +111,18 @@ def __init__(self, config: Union[Config, None] = None) -> None:
         self._config_hash: Optional[str] = None
         self._memory_optimised_parts: Set[str] = set()
 
+    def _init_waf_from_config(self):
+        waf = get_and_del_weighted_average_from_config(self.config)
+        if waf is not None:
+            logger.info("Using (potentially) custom value of weighted "
+                        "average function")
+            self.weighted_average_function = attempt_fix_weighted_average_function(waf)
+        elif hasattr(self, 'weighted_average_function'):
+            # keep existing
+            pass
+        else:
+            self.weighted_average_function = default_weighted_average
+
     def get_name(self, cui: str) -> str:
         """Returns preferred name if it exists, otherwise it will return
         the longest name assigned to the concept.
@@ -132,7 +148,12 @@ def update_cui2average_confidence(self, cui: str, new_sim: float) -> None:
                                             (self.cui2count_train.get(cui, 0) + 1)
         self.is_dirty = True
 
-    def remove_names(self, cui: str, names: Dict[str, Dict]) -> None:
+    @deprecated("Deprecated. For internal use only. Use CAT.unlink_concept_name instead",
+                depr_version=(1, 12, 0), removal_version=(1, 13, 0))
+    def remove_names(self, cui: str, names: Iterable[str]) -> None:
+        self._remove_names(cui, names)
+
+    def _remove_names(self, cui: str, names: Iterable[str]) -> None:
         """Remove names from an existing concept - effect is this name will never again be used to link to this concept.
         This will only remove the name from the linker (namely name2cuis and name2cuis2status), the name will still be present everywhere else.
         Why? Because it is bothersome to remove it from everywhere, but
@@ -141,10 +162,10 @@ def remove_names(self, cui: str, names: Dict[str, Dict]) -> None:
         Args:
             cui (str):
                 Concept ID or unique identifer in this database.
-            names (Dict[str, Dict]):
-                Names to be removed, should look like: `{'name': {'tokens': tokens, 'snames': snames, 'raw_name': raw_name}, ...}`
+            names (Iterable[str]):
+                Names to be removed (e.g list, set, or even a dict (in which case keys will be used)).
         """
-        for name in names.keys():
+        for name in names:
             if name in self.name2cuis:
                 if cui in self.name2cuis[name]:
                     self.name2cuis[name].remove(cui)
@@ -231,44 +252,6 @@ def add_names(self, cui: str, names: Dict[str, Dict], name_status: str = 'A', fu
 
         self._add_concept(cui=cui, names=names, ontologies=set(), name_status=name_status, type_ids=set(), description='', full_build=full_build)
 
-    @deprecated("Use `cdb._add_concept` as this will be removed in a future release.")
-    def add_concept(self,
-                    cui: str,
-                    names: Dict[str, Dict],
-                    ontologies: Set[str],
-                    name_status: str,
-                    type_ids: Set[str],
-                    description: str,
-                    full_build: bool = False) -> None:
-        """
-        Deprecated: Use `cdb._add_concept` as this will be removed in a future release.
-
-        Add a concept to internal Concept Database (CDB). Depending on what you are providing
-        this will add a large number of properties for each concept.
-
-        Args:
-            cui (str):
-                Concept ID or unique identifier in this database, all concepts that have
-                the same CUI will be merged internally.
-            names (Dict[str, Dict]):
-                Names for this concept, or the value that if found in free text can be linked to this concept.
-                Names is a dict like: `{name: {'tokens': tokens, 'snames': snames, 'raw_name': raw_name}, ...}`
-                Names should be generated by helper function 'medcat.preprocessing.cleaners.prepare_name'
-            ontologies (Set[str]):
-                ontologies in which the concept exists (e.g. SNOMEDCT, HPO)
-            name_status (str):
-                One of `P`, `N`, `A`
-            type_ids (Set[str]):
-                Semantic type identifier (have a look at TUIs in UMLS or SNOMED-CT)
-            description (str):
-                Description of this concept.
-            full_build (bool):
-                If True the dictionary self.addl_info will also be populated, contains a lot of extra information
-                about concepts, but can be very memory consuming. This is not necessary
-                for normal functioning of MedCAT (Default Value `False`).
-        """
-        self._add_concept(cui, names, ontologies, name_status, type_ids, description, full_build)
-
     def _add_concept(self,
                     cui: str,
                     names: Dict[str, Dict],
@@ -558,6 +541,8 @@ def load_config(self, config_path: str) -> None:
             # this should be the behaviour for all newer models
             self.config = cast(Config, Config.load(config_path))
             logger.debug("Loaded config from CDB from %s", config_path)
+            # new config, potentially new weighted_average_function to read
+            self._init_waf_from_config()
         # mark config read from file
         self._config_from_file = True
 
@@ -582,7 +567,8 @@ def load(cls, path: str, json_path: Optional[str] = None, config_dict: Optional[
         ser = CDBSerializer(path, json_path)
         cdb = ser.deserialize(CDB)
         cls._check_medcat_version(cdb.config.asdict())
-        cls._ensure_backward_compatibility(cdb.config)
+        fix_waf_lambda(cdb)
+        ensure_backward_compatibility(cdb.config, workers)
 
         # Overwrite the config with new data
         if config_dict is not None:
@@ -855,19 +841,6 @@ def most_similar(self,
 
         return res
 
-    @staticmethod
-    def _ensure_backward_compatibility(config: Config) -> None:
-        # Hacky way of supporting old CDBs
-        weighted_average_function = config.linking.weighted_average_function
-        if callable(weighted_average_function) and getattr(weighted_average_function, "__name__", None) == "":
-            # the following type ignoring is for mypy because it is unable to detect the signature
-            config.linking.weighted_average_function = partial(weighted_average, factor=0.0004) # type: ignore
-        if config.general.workers is None:
-            config.general.workers = workers()
-        disabled_comps = config.general.spacy_disabled_components
-        if 'tagger' in disabled_comps and 'lemmatizer' not in disabled_comps:
-            config.general.spacy_disabled_components.append('lemmatizer')
-
     @classmethod
     def _check_medcat_version(cls, config_data: Dict) -> None:
         cdb_medcat_version = config_data.get('version', {}).get('medcat_version', None)
diff --git a/medcat/config.py b/medcat/config.py
index 88c4aad14..cdb30d0fe 100644
--- a/medcat/config.py
+++ b/medcat/config.py
@@ -1,17 +1,19 @@
 from datetime import datetime
 from pydantic import BaseModel, Extra, ValidationError
 from pydantic.fields import ModelField
-from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union
+from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union, Type
 from multiprocessing import cpu_count
 import logging
 import jsonpickle
+import json
 from functools import partial
 import re
 
 from medcat.utils.hasher import Hasher
 from medcat.utils.matutils import intersect_nonempty_set
 from medcat.utils.config_utils import attempt_fix_weighted_average_function
-from medcat.utils.config_utils import weighted_average
+from medcat.utils.config_utils import weighted_average, is_old_type_config_dict
+from medcat.utils.saving.coding import CustomDelegatingEncoder, default_hook
 
 
 logger = logging.getLogger(__name__)
@@ -31,6 +33,7 @@ def __getitem__(self, arg: str) -> Any:
             raise KeyError from e
 
     def __setattr__(self, arg: str, val) -> None:
+        # TODO: remove this in the future when we stop supporting this in config
         if isinstance(self, Linking) and arg == "weighted_average_function":
             val = attempt_fix_weighted_average_function(val)
         super().__setattr__(arg, val)
@@ -103,8 +106,8 @@ def save(self, save_path: str) -> None:
             save_path(str): Where to save the created json file
         """
         # We want to save the dict here, not the whole class
-        json_string = jsonpickle.encode(
-            {field: getattr(self, field) for field in self.fields()})
+        json_string = json.dumps(self.asdict(), cls=cast(Type[json.JSONEncoder],
+                                                         CustomDelegatingEncoder.def_inst))
 
         with open(save_path, 'w') as f:
             f.write(json_string)
@@ -204,7 +207,12 @@ def load(cls, save_path: str) -> "MixingConfig":
 
         # Read the jsonpickle string
         with open(save_path) as f:
-            config_dict = jsonpickle.decode(f.read())
+            config_dict = json.load(f, object_hook=default_hook)
+        if is_old_type_config_dict(config_dict):
+            logger.warning("Loading an old type of config (jsonpickle) from '%s'",
+                            save_path)
+            with open(save_path) as f:
+                config_dict = jsonpickle.decode(f.read())
 
         config.merge_config(config_dict)
 
@@ -511,9 +519,6 @@ class Linking(MixingConfig, BaseModel):
     similarity calculation and will have a similarity of -1."""
     always_calculate_similarity: bool = False
     """Do we want to calculate context similarity even for concepts that are not ambigous."""
-    weighted_average_function: Callable[..., Any] = _DEFAULT_PARTIAL
-    """Weights for a weighted average
-    'weighted_average_function': partial(weighted_average, factor=0.02),"""
     calculate_dynamic_threshold: bool = False
     """Concepts below this similarity will be ignored. Type can be static/dynamic - if dynamic each CUI has a different TH
     and it is calcualted as the average confidence for that CUI * similarity_threshold. Take care that dynamic works only
@@ -597,3 +602,39 @@ def get_hash(self):
                         hasher.update(v2, length=True)
         self.hash = hasher.hexdigest()
         return self.hash
+
+
+class UseOfOldConfigOptionException(AttributeError):
+
+    def __init__(self, conf_type: Type[FakeDict], arg_name: str, advice: str) -> None:
+        super().__init__(f"Tried to use {conf_type.__name__}.{arg_name}. "
+                         f"Advice: {advice}")
+        self.conf_type = conf_type
+        self.arg_name = arg_name
+        self.advice = advice
+
+
+# NOTE: The following is for backwards compatibility and should be removed
+#       at some point in the future
+
+# wrapper for functions for a better error in case of weighted_average_function
+# access
+def _wrapper(func, check_type: Type[FakeDict], advice: str, exp_type: Type[Exception]):
+    def wrapper(*args, **kwargs):
+        try:
+            res = func(*args, **kwargs)
+        except exp_type as ex:
+            if ((len(args) == 2 and len(kwargs) == 0) and
+                    (isinstance(args[0], check_type) and
+                    args[1] == "weighted_average_function")):
+                raise UseOfOldConfigOptionException(Linking, args[1], advice) from ex
+            raise ex
+        return res
+    return wrapper
+
+
+# wrap Linking.__getattribute__ so that when getting weighted_average_function
+# we get a nicer exception
+_waf_advice = "You can use `cat.cdb.weighted_average_function` to access it directly"
+Linking.__getattribute__ = _wrapper(Linking.__getattribute__, Linking, _waf_advice, AttributeError)  # type: ignore
+Linking.__getitem__ = _wrapper(Linking.__getitem__, Linking, _waf_advice, KeyError)  # type: ignore
diff --git a/medcat/config_meta_cat.py b/medcat/config_meta_cat.py
index 6ddd71d56..686029052 100644
--- a/medcat/config_meta_cat.py
+++ b/medcat/config_meta_cat.py
@@ -1,5 +1,4 @@
 from typing import Dict, Any
-
 from medcat.config import MixingConfig, BaseModel, Optional, Extra
 
 
@@ -49,10 +48,20 @@ class Config:
 class Model(MixingConfig, BaseModel):
     """The model part of the metaCAT config"""
     model_name: str = 'lstm'
+    """NOTE: When changing model, make sure to change the tokenizer as well"""
+    model_variant: str = 'bert-base-uncased'
+    model_freeze_layers: bool = True
     num_layers: int = 2
     input_size: int = 300
     hidden_size: int = 300
     dropout: float = 0.5
+    phase_number: int = 0
+    """Indicates whether or not two phase learning is being performed.
+    1: Phase 1 - Train model on undersampled data
+    2: Phase 2 - Continue training on full data
+    0: None - 2 phase learning is not performed"""
+    category_undersample: str = ''
+    model_architecture_config: Dict = {'fc2': True, 'fc3': False,'lr_scheduler': True}
     num_directions: int = 2
     """2 - bidirectional model, 1 - unidirectional"""
     nclasses: int = 2
@@ -61,7 +70,7 @@ class Model(MixingConfig, BaseModel):
     emb_grad: bool = True
     """If True the embeddings will also be trained"""
     ignore_cpos: bool = False
-    """If set to True center positions will be ignored when calculating represenation"""
+    """If set to True center positions will be ignored when calculating representation"""
 
     class Config:
         extra = Extra.allow
@@ -77,6 +86,8 @@ class Train(MixingConfig, BaseModel):
     shuffle_data: bool = True
     """Used only during training, if set the dataset will be shuffled before train/test split"""
     class_weights: Optional[Any] = None
+    compute_class_weights: bool = False
+    """If true and if class weights are not provided, the class weights will be calculated based on the data"""
     score_average: str = 'weighted'
     """What to use for averaging F1/P/R across labels"""
     prerequisites: dict = {}
@@ -88,6 +99,10 @@ class Train(MixingConfig, BaseModel):
     """When was the last training run"""
     metric: Dict[str, str] = {'base': 'weighted avg', 'score': 'f1-score'}
     """What metric should be used for choosing the best model"""
+    loss_funct: str = 'cross_entropy'
+    """Loss function for the model"""
+    gamma: int = 2
+    """Focal Loss - how much the loss focuses on hard-to-classify examples."""
 
     class Config:
         extra = Extra.allow
diff --git a/medcat/linking/vector_context_model.py b/medcat/linking/vector_context_model.py
index 7c4c11a69..e4875c32f 100644
--- a/medcat/linking/vector_context_model.py
+++ b/medcat/linking/vector_context_model.py
@@ -71,7 +71,7 @@ def get_context_vectors(self, entity: Span, doc: Doc, cui=None) -> Dict:
 
             values = []
             # Add left
-            values.extend([self.config.linking['weighted_average_function'](step) * self.vocab.vec(tkn.lower_)
+            values.extend([self.cdb.weighted_average_function(step) * self.vocab.vec(tkn.lower_)
                            for step, tkn in enumerate(tokens_left) if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None])
 
             if not self.config.linking['context_ignore_center_tokens']:
@@ -83,7 +83,7 @@ def get_context_vectors(self, entity: Span, doc: Doc, cui=None) -> Dict:
                     values.extend([self.vocab.vec(tkn.lower_) for tkn in tokens_center if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None])
 
             # Add right
-            values.extend([self.config.linking['weighted_average_function'](step) * self.vocab.vec(tkn.lower_)
+            values.extend([self.cdb.weighted_average_function(step) * self.vocab.vec(tkn.lower_)
                            for step, tkn in enumerate(tokens_right) if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None])
 
             if len(values) > 0:
diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py
index 78fcd9982..349b848ed 100644
--- a/medcat/meta_cat.py
+++ b/medcat/meta_cat.py
@@ -11,18 +11,17 @@
 from medcat.utils.hasher import Hasher
 from medcat.config_meta_cat import ConfigMetaCAT
 from medcat.utils.meta_cat.ml_utils import predict, train_model, set_all_seeds, eval_model
-from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values
+from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values, prepare_for_oversampled_data
 from medcat.pipeline.pipe_runner import PipeRunner
 from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
 from medcat.utils.meta_cat.data_utils import Doc as FakeDoc
-from medcat.utils.decorators import deprecated
+from peft import get_peft_model, LoraConfig, TaskType
 
 # It should be safe to do this always, as all other multiprocessing
 # will be finished before data comes to meta_cat
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
-
-logger = logging.getLogger(__name__) # separate logger from the package-level one
+logger = logging.getLogger(__name__)  # separate logger from the package-level one
 
 
 class MetaCAT(PipeRunner):
@@ -77,7 +76,7 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
                 The embedding densor
 
         Raises:
-            ValueError: If the meta model is not LSTM
+            ValueError: If the meta model is not LSTM or BERT
 
         Returns:
             nn.Module:
@@ -86,7 +85,22 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
         config = self.config
         if config.model['model_name'] == 'lstm':
             from medcat.utils.meta_cat.models import LSTM
-            model = LSTM(embeddings, config)
+            model: nn.Module = LSTM(embeddings, config)
+            logger.info("LSTM model used for classification")
+
+        elif config.model['model_name'] == 'bert':
+            from medcat.utils.meta_cat.models import BertForMetaAnnotation
+            model = BertForMetaAnnotation(config)
+
+            if not config.model.model_freeze_layers:
+                peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
+                                         target_modules=["query", "value"], lora_dropout=0.2)
+
+                model = get_peft_model(model, peft_config)
+                # model.print_trainable_parameters()
+
+            logger.info("BERT model used for classification")
+
         else:
             raise ValueError("Unknown model name %s" % config.model['model_name'])
 
@@ -106,24 +120,8 @@ def get_hash(self) -> str:
         hasher.update(self.config.get_hash())
         return hasher.hexdigest()
 
-    @deprecated(message="Use `train_from_json` or `train_raw` instead")
-    def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict:
-        """Train or continue training a model give a json_path containing a MedCATtrainer export. It will
-        continue training if an existing model is loaded or start new training if the model is blank/new.
-
-        Args:
-            json_path (Union[str, list]):
-                Path/Paths to a MedCATtrainer export containing the meta_annotations we want to train for.
-            save_dir_path (Optional[str]):
-                In case we have aut_save_model (meaning during the training the best model will be saved)
-                we need to set a save path. Defaults to `None`.
-
-        Returns:
-            Dict: The resulting report.
-        """
-        return self.train_from_json(json_path, save_dir_path)
-
-    def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict:
+    def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None,
+                        data_oversampled: Optional[list] = None) -> Dict:
         """Train or continue training a model give a json_path containing a MedCATtrainer export. It will
         continue training if an existing model is loaded or start new training if the model is blank/new.
 
@@ -133,6 +131,8 @@ def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[s
             save_dir_path (Optional[str]):
                 In case we have aut_save_model (meaning during the training the best model will be saved)
                 we need to set a save path. Defaults to `None`.
+            data_oversampled (Optional[list]):
+                If oversampling is performed, the oversampled data should be passed via this parameter.
 
         Returns:
             Dict: The resulting report.
@@ -157,9 +157,9 @@ def merge_data_loaded(base, other):
         for path in json_path:
             with open(path, 'r') as f:
                 data_loaded = merge_data_loaded(data_loaded, json.load(f))
-        return self.train_raw(data_loaded, save_dir_path)
+        return self.train_raw(data_loaded, save_dir_path, data_oversampled=data_oversampled)
 
-    def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> Dict:
+    def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data_oversampled: Optional[list] = None) -> Dict:
         """Train or continue training a model given raw data. It will
         continue training if an existing model is loaded or start new training if the model is blank/new.
 
@@ -187,6 +187,10 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
             save_dir_path (Optional[str]):
                 In case we have aut_save_model (meaning during the training the best model will be saved)
                 we need to set a save path. Defaults to `None`.
+            data_oversampled (Optional[list]):
+                If oversampling is performed, the oversampled data should be passed via this parameter.
+                The expected format is: [[['text','of','the','document'], [index of medical entity], "label"],
+                [['text','of','the','document'], [index of medical entity], "label"]]
 
         Returns:
             Dict: The resulting report.
@@ -194,6 +198,8 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
         Raises:
             Exception: If no save path is specified, or category name not in data.
             AssertionError: If no tokeniser is set
+            FileNotFoundError: If phase_number is set to 2 and model.dat file is not found
+            KeyError: If phase_number is set to 2 and model.dat file contains mismatched architecture
         """
         g_config = self.config.general
         t_config = self.config.train
@@ -212,7 +218,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
                                  replace_center=g_config['replace_center'], prerequisites=t_config['prerequisites'],
                                  lowercase=g_config['lowercase'])
 
-        # Check is the name there
+        # Check if the name is present
         category_name = g_config['category_name']
         if category_name not in data:
             raise Exception(
@@ -220,16 +226,22 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
                     category_name, " | ".join(list(data.keys()))))
 
         data = data[category_name]
+        if data_oversampled:
+            data_sampled = prepare_for_oversampled_data(data_oversampled, self.tokenizer)
+            data = data + data_sampled
 
         category_value2id = g_config['category_value2id']
         if not category_value2id:
             # Encode the category values
-            data, category_value2id = encode_category_values(data)
+            data_undersampled, full_data, category_value2id = encode_category_values(data,
+                                                                                     category_undersample=self.config.model.category_undersample)
             g_config['category_value2id'] = category_value2id
         else:
             # We already have everything, just get the data
-            data, _ = encode_category_values(data, existing_category_value2id=category_value2id)
-
+            data_undersampled, full_data, category_value2id = encode_category_values(data,
+                                                                                     existing_category_value2id=category_value2id,
+                                                                                     category_undersample=self.config.model.category_undersample)
+            g_config['category_value2id'] = category_value2id
         # Make sure the config number of classes is the same as the one found in the data
         if len(category_value2id) != self.config.model['nclasses']:
             logger.warning(
@@ -237,7 +249,29 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D
                     self.config.model['nclasses'], len(category_value2id)))
             logger.warning("Auto-setting the nclasses value in config and rebuilding the model.")
             self.config.model['nclasses'] = len(category_value2id)
-            self.model = self.get_model(embeddings=self.embeddings)
+
+        if self.config.model.phase_number == 2 and save_dir_path is not None:
+            model_save_path = os.path.join(save_dir_path, 'model.dat')
+            device = torch.device(g_config['device'])
+            try:
+                self.model.load_state_dict(torch.load(model_save_path, map_location=device))
+                logger.info("Model state loaded from dict for 2 phase learning")
+
+            except FileNotFoundError:
+                raise FileNotFoundError(f"\nError: Model file not found at path: {model_save_path}\nPlease run phase 1 training and then run phase 2.")
+
+            except KeyError:
+                raise KeyError("\nError: Missing key in loaded state dictionary. \nThis might be due to a mismatch between the model architecture and the saved state.")
+
+            except Exception as e:
+                raise Exception(f"\nError: Model state cannot be loaded from dict. {e}")
+
+        data = full_data
+        if self.config.model.phase_number == 1:
+            data = data_undersampled
+            if not t_config['auto_save_model']:
+                logger.info("For phase 1, model state has to be saved. Saving model...")
+                t_config['auto_save_model'] = True
 
         report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path)
 
@@ -293,7 +327,7 @@ def eval(self, json_path: str) -> Dict:
 
         # We already have everything, just get the data
         category_value2id = g_config['category_value2id']
-        data, _ = encode_category_values(data, existing_category_value2id=category_value2id)
+        data, _, _ = encode_category_values(data, existing_category_value2id=category_value2id)
 
         # Run evaluation
         assert self.tokenizer is not None
@@ -317,8 +351,8 @@ def save(self, save_dir_path: str) -> None:
         # Save tokenizer
         assert self.tokenizer is not None
         self.tokenizer.save(save_dir_path)
-
         # Save config
+
         self.config.save(os.path.join(save_dir_path, 'config.json'))
 
         # Save the model
@@ -347,7 +381,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
         # Load config
         config = cast(ConfigMetaCAT, ConfigMetaCAT.load(os.path.join(save_dir_path, 'config.json')))
 
-        # Overwrite loaded paramters with something new
+        # Overwrite loaded parameters with something new
         if config_dict is not None:
             config.merge_config(config_dict)
 
@@ -358,7 +392,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
             tokenizer = TokenizerWrapperBPE.load(save_dir_path)
         elif config.general['tokenizer_name'] == 'bert-tokenizer':
             from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
-            tokenizer = TokenizerWrapperBERT.load(save_dir_path)
+            tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model['model_variant'])
 
         # Create meta_cat
         meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)
@@ -380,7 +414,8 @@ def get_ents(self, doc: Doc) -> Iterable[Span]:
             try:
                 return doc.spans[spangroup_name]
             except KeyError:
-                raise Exception(f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.")
+                raise Exception(
+                    f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.")
 
         # Should we annotate overlapping entities
         if self.config.general['annotate_overlapping']:
@@ -421,18 +456,26 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
             start = ent.start_char
             end = ent.end_char
 
-            ind = 0
-            # Start where the last ent was found, cannot be before it as we've sorted
+            # Updated implementation to extract all the tokens for the medical entity (rather than the one)
+            ctoken_idx = []
             for ind, pair in enumerate(offset_mapping[last_ind:]):
-                if start >= pair[0] and start < pair[1]:
-                    break
-            ind = last_ind + ind # If we did not start from 0 in the for loop
-            last_ind = ind
+                # Check whether we've reached the start of the entity
+                if start <= pair[0] or start <= pair[1]:
+                    if end <= pair[1]:
+                        ctoken_idx.append(ind) # End reached
+                        break
+                    else:
+                        ctoken_idx.append(ind) # Keep going
+
+            # Start where the last ent was found, cannot be before it as we've sorted
+            last_ind += ind  # If we did not start from 0 in the for loop
+
+            _start = max(0, ctoken_idx[0] - cntx_left)
+            _end = min(len(input_ids), ctoken_idx[-1] + 1 + cntx_right)
 
-            _start = max(0, ind - cntx_left)
-            _end = min(len(input_ids), ind + 1 + cntx_right)
             tkns = input_ids[_start:_end]
             cpos = cntx_left + min(0, ind - cntx_left)
+            cpos_new = [x - _start for x in ctoken_idx]
 
             if replace_center is not None:
                 if lowercase:
@@ -447,8 +490,7 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
                 ln = e_ind - s_ind  # Length of the concept in tokens
                 assert self.tokenizer is not None
                 tkns = tkns[:cpos] + self.tokenizer(replace_center)['input_ids'] + tkns[cpos + ln + 1:]
-
-            samples.append([tkns, cpos])
+            samples.append([tkns, cpos_new])
             ent_id2ind[ent._.id] = len(samples) - 1
 
         return ent_id2ind, samples
@@ -544,7 +586,6 @@ def _set_meta_anns(self,
                     for i, doc in enumerate(docs):
                         data.extend(doc._.share_tokens[0])
                         doc_ind2positions[i] = doc._.share_tokens[1]
-
                 all_predictions, all_confidences = predict(self.model, data, config)
                 for i, doc in enumerate(docs):
                     start_ind, end_ind, ent_id2ind = doc_ind2positions[i]
diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py
index 7aabceda2..9d4700df9 100644
--- a/medcat/ner/transformers_ner.py
+++ b/medcat/ner/transformers_ner.py
@@ -4,8 +4,10 @@
 import datasets
 from spacy.tokens import Doc
 from datetime import datetime
-from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple
+from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable
 from spacy.tokens import Span
+import inspect
+from functools import partial
 
 from medcat.cdb import CDB
 from medcat.utils.meta_cat.ml_utils import set_all_seeds
@@ -178,10 +180,21 @@ def train(self,
             json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels,
                                               meta_requirements=meta_requirements, file_name='data_eval.json')
             # Load dataset
-            dataset = datasets.load_dataset(os.path.abspath(transformers_ner.__file__),
-                                            data_files={'train': json_path}, # type: ignore
-                                            split='train',
-                                            cache_dir='/tmp/')
+
+            # NOTE: The following is for backwards compatibility
+            #       in datasets==2.20.0 `trust_remote_code=True` must be explicitly
+            #       specified, otherwise an error is raised.
+            #       On the other hand, the keyword argument was added in datasets==2.16.0
+            #       yet we support datasets>=2.2.0.
+            #       So we need to use the kwarg if applicable and omit its use otherwise.
+            if func_has_kwarg(datasets.load_dataset, 'trust_remote_code'):
+                ds_load_dataset = partial(datasets.load_dataset, trust_remote_code=True)
+            else:
+                ds_load_dataset = datasets.load_dataset
+            dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
+                                      data_files={'train': json_path}, # type: ignore
+                                      split='train',
+                                      cache_dir='/tmp/')
             # We split before encoding so the split is document level, as encoding
             #does the document spliting into max_seq_len
             dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore
@@ -422,3 +435,9 @@ def __call__(self, doc: Doc) -> Doc:
         doc = next(self.pipe(iter([doc])))
 
         return doc
+
+
+# NOTE: Only needed for datasets backwards compatibility
+def func_has_kwarg(func: Callable, keyword: str):
+    sig = inspect.signature(func)
+    return keyword in sig.parameters
diff --git a/medcat/stats/kfold.py b/medcat/stats/kfold.py
new file mode 100644
index 000000000..491173c23
--- /dev/null
+++ b/medcat/stats/kfold.py
@@ -0,0 +1,436 @@
+from typing import Protocol, Tuple, List, Dict, Optional, Set, Iterable, Callable, cast, Any
+
+from abc import ABC, abstractmethod
+from enum import Enum, auto
+from copy import deepcopy
+
+import numpy as np
+
+from medcat.utils.checkpoint import Checkpoint
+from medcat.utils.cdb_state import captured_state_cdb
+
+from medcat.stats.stats import get_stats
+from medcat.stats.mctexport import MedCATTrainerExport, MedCATTrainerExportProject
+from medcat.stats.mctexport import MedCATTrainerExportDocument, MedCATTrainerExportAnnotation
+from medcat.stats.mctexport import count_all_annotations, count_all_docs, get_nr_of_annotations
+from medcat.stats.mctexport import iter_anns, iter_docs, MedCATTrainerExportProjectInfo
+
+
+
+class CDBLike(Protocol):
+    pass
+
+
+class CATLike(Protocol):
+
+    @property
+    def cdb(self) -> CDBLike:
+        pass
+
+    def train_supervised_raw(self,
+                             data: Dict[str, List[Dict[str, dict]]],
+                             reset_cui_count: bool = False,
+                             nepochs: int = 1,
+                             print_stats: int = 0,
+                             use_filters: bool = False,
+                             terminate_last: bool = False,
+                             use_overlaps: bool = False,
+                             use_cui_doc_limit: bool = False,
+                             test_size: float = 0,
+                             devalue_others: bool = False,
+                             use_groups: bool = False,
+                             never_terminate: bool = False,
+                             train_from_false_positives: bool = False,
+                             extra_cui_filter: Optional[Set] = None,
+                             retain_extra_cui_filter: bool = False,
+                             checkpoint: Optional[Checkpoint] = None,
+                             retain_filters: bool = False,
+                             is_resumed: bool = False) -> Tuple:
+        pass
+
+
+class SplitType(Enum):
+    """The split type."""
+    DOCUMENTS = auto()
+    """Split over number of documents."""
+    ANNOTATIONS = auto()
+    """Split over number of annotations."""
+    DOCUMENTS_WEIGHTED = auto()
+    """Split over number of documents based on the number of annotations.
+    So essentially this ensures that the same document isn't in 2 folds
+    while trying to more equally distribute documents with different number
+    of annotations.
+    For example:
+        If we have 6 documents that we want to split into 3 folds.
+        The number of annotations per document are as follows:
+           [40, 40, 20, 10, 5, 5]
+        If we were to split this trivially over documents, we'd end up
+        with the 3 folds with number of annotations that are far from even:
+           [80, 30, 10]
+        However, if we use the annotations as weights, we would be able to
+        create folds that have more evenly distributed annotations, e.g:
+           [[D1,], [D2], [D3, D4, D5, D6]]
+        where D# denotes the number of the documents, with the number of
+        annotations being equal:
+           [ 40, 40, 20 + 10 + 5 + 5 = 40]
+    """
+
+
+class FoldCreator(ABC):
+    """The FoldCreator based on a MCT export.
+
+    Args:
+        mct_export (MedCATTrainerExport): The MCT export dict.
+        nr_of_folds (int): Number of folds to create.
+            The splitting strategy (documents vs annotations) is determined by the concrete subclass.
+    """
+
+    def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int) -> None:
+        self.mct_export = mct_export
+        self.nr_of_folds = nr_of_folds
+
+    def _find_or_add_doc(self, project: MedCATTrainerExportProject, orig_doc: MedCATTrainerExportDocument
+                         ) -> MedCATTrainerExportDocument:
+        for existing_doc in project['documents']:
+            if existing_doc['name'] == orig_doc['name']:
+                return existing_doc
+        new_doc: MedCATTrainerExportDocument = deepcopy(orig_doc)
+        new_doc['annotations'].clear()
+        project['documents'].append(new_doc)
+        return new_doc
+
+    def _create_new_project(self, proj_info: MedCATTrainerExportProjectInfo) -> MedCATTrainerExportProject:
+        (proj_name, proj_id, proj_cuis, proj_tuis) = proj_info
+        cur_project = cast(MedCATTrainerExportProject, {
+            'name': proj_name,
+            'id': proj_id,
+            'cuis': proj_cuis,
+            'documents': [],
+            })
+        # NOTE: Some MCT exports don't declare TUIs
+        if proj_tuis is not None:
+            cur_project['tuis'] = proj_tuis
+        return cur_project
+
+    def _create_export_with_documents(self, relevant_docs: Iterable[Tuple[MedCATTrainerExportProjectInfo,
+                                                                          MedCATTrainerExportDocument]]) -> MedCATTrainerExport:
+        export: MedCATTrainerExport = {
+            "projects": []
+        }
+        # helper for finding projects per name
+        used_projects: Dict[str, MedCATTrainerExportProject] = {}
+        for proj_info, doc in relevant_docs:
+            proj_name = proj_info[0]
+            if proj_name not in used_projects:
+                cur_project = self._create_new_project(proj_info) # TODO - make sure it's available
+                export['projects'].append(cur_project)
+                used_projects[proj_name] = cur_project
+            else:
+                cur_project = used_projects[proj_name]
+            cur_project['documents'].append(doc)
+        return export
+
+
+    @abstractmethod
+    def create_folds(self) -> List[MedCATTrainerExport]:
+        """Create folds.
+
+        Raises:
+            ValueError: If something went wrong.
+
+        Returns:
+            List[MedCATTrainerExport]: The created folds.
+        """
+
+
+class SimpleFoldCreator(FoldCreator):
+
+    def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int,
+                 counter: Callable[[MedCATTrainerExport], int]) -> None:
+        super().__init__(mct_export, nr_of_folds)
+        self._counter = counter
+        self.total = self._counter(mct_export)
+        self.per_fold = self._init_per_fold()
+
+    def _init_per_fold(self) -> List[int]:
+        per_fold = [self.total // self.nr_of_folds for _ in range(self.nr_of_folds)]
+        total = sum(per_fold)
+        if total < self.total:
+            per_fold[-1] += self.total - total
+        if any(pf <= 0 for pf in per_fold):
+            raise ValueError(f"Failed to calculate per-fold items. Got: {per_fold}")
+        return per_fold
+
+    @abstractmethod
+    def _create_fold(self, fold_nr: int) -> MedCATTrainerExport:
+        pass
+
+    def create_folds(self) -> List[MedCATTrainerExport]:
+        return [
+            self._create_fold(fold_nr) for fold_nr in range(self.nr_of_folds)
+        ]
+
+
+
+class PerDocsFoldCreator(FoldCreator):
+
+    def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int) -> None:
+        super().__init__(mct_export, nr_of_folds)
+        self.nr_of_docs = count_all_docs(self.mct_export)
+        self.per_doc_simple = self.nr_of_docs // self.nr_of_folds
+        self._all_docs = list(iter_docs(self.mct_export))
+
+    def _create_fold(self, fold_nr: int) -> MedCATTrainerExport:
+        start_nr = self.per_doc_simple * fold_nr
+        # until the end for last fold, otherwise just the next set of docs
+        end_nr = self.nr_of_docs if fold_nr == self.nr_of_folds - 1 else start_nr + self.per_doc_simple
+        relevant_docs = self._all_docs[start_nr: end_nr]
+        return self._create_export_with_documents(relevant_docs)
+
+    def create_folds(self) -> List[MedCATTrainerExport]:
+        return [
+            self._create_fold(fold_nr) for fold_nr in range(self.nr_of_folds)
+        ]
+
+
+class PerAnnsFoldCreator(SimpleFoldCreator):
+
+    def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int) -> None:
+        super().__init__(mct_export, nr_of_folds, count_all_annotations)
+
+    def _add_target_ann(self, project: MedCATTrainerExportProject,
+                        orig_doc: MedCATTrainerExportDocument,
+                        ann: MedCATTrainerExportAnnotation) -> None:
+        cur_doc: MedCATTrainerExportDocument = self._find_or_add_doc(project, orig_doc)
+        cur_doc['annotations'].append(ann)
+
+    def _targets(self) -> Iterable[Tuple[MedCATTrainerExportProjectInfo,
+                                         MedCATTrainerExportDocument,
+                                         MedCATTrainerExportAnnotation]]:
+        return iter_anns(self.mct_export)
+
+    def _create_fold(self, fold_nr: int) -> MedCATTrainerExport:
+        per_fold = self.per_fold[fold_nr]
+        cur_fold: MedCATTrainerExport = {
+            'projects': []
+        }
+        cur_project: Optional[MedCATTrainerExportProject] = None
+        included = 0
+        for target in self._targets():
+            proj_info, cur_doc, cur_ann = target
+            proj_name = proj_info[0]
+            if not cur_project or cur_project['name'] != proj_name:
+                # first or new project
+                cur_project = self._create_new_project(proj_info)
+                cur_fold['projects'].append(cur_project)
+            self._add_target_ann(cur_project, cur_doc, cur_ann)
+            included += 1
+            if included == per_fold:
+                break
+            if included > per_fold:
+                raise ValueError("Got a larger fold than expected. "
+                                 f"Expected {per_fold}, got {included}")
+        return cur_fold
+
+
+class WeightedDocumentsCreator(FoldCreator):
+
+    def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int,
+                 weight_calculator: Callable[[MedCATTrainerExportDocument], int]) -> None:
+        super().__init__(mct_export, nr_of_folds)
+        self._weight_calculator = weight_calculator
+        docs = [(doc, self._weight_calculator(doc[1])) for doc in iter_docs(self.mct_export)]
+        # descending order in weight
+        self._weighted_docs = sorted(docs, key=lambda d: d[1], reverse=True)
+
+    def create_folds(self) -> List[MedCATTrainerExport]:
+        doc_folds: List[List[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument]]]
+        doc_folds = [[] for _ in range(self.nr_of_folds)]
+        fold_weights = [0] * self.nr_of_folds
+
+        for item, weight in self._weighted_docs:
+            # Find the subset with the minimum total weight
+            min_subset_idx = np.argmin(fold_weights)
+            # add the most heavily weighted document
+            doc_folds[min_subset_idx].append(item)
+            fold_weights[min_subset_idx] += weight
+
+        return [self._create_export_with_documents(docs) for docs in doc_folds]
+
+
+def get_fold_creator(mct_export: MedCATTrainerExport,
+                     nr_of_folds: int,
+                     split_type: SplitType) -> FoldCreator:
+    """Get the appropriate fold creator.
+
+    Args:
+        mct_export (MedCATTrainerExport): The MCT export.
+        nr_of_folds (int): Number of folds to use.
+        split_type (SplitType): The type of split to use.
+
+    Raises:
+        ValueError: In case of an unknown split type.
+
+    Returns:
+        FoldCreator: The corresponding fold creator.
+    """
+    if split_type is SplitType.DOCUMENTS:
+        return PerDocsFoldCreator(mct_export=mct_export, nr_of_folds=nr_of_folds)
+    elif split_type is SplitType.ANNOTATIONS:
+        return PerAnnsFoldCreator(mct_export=mct_export, nr_of_folds=nr_of_folds)
+    elif split_type is SplitType.DOCUMENTS_WEIGHTED:
+        return WeightedDocumentsCreator(mct_export=mct_export, nr_of_folds=nr_of_folds,
+                                        weight_calculator=get_nr_of_annotations)
+    else:
+        raise ValueError(f"Unknown Split Type: {split_type}")
+
+
+def get_per_fold_metrics(cat: CATLike, folds: List[MedCATTrainerExport],
+                         *args, **kwargs) -> List[Tuple]:
+    metrics = []
+    for fold_nr, cur_fold in enumerate(folds):
+        others = list(folds)
+        others.pop(fold_nr)
+        with captured_state_cdb(cat.cdb):
+            for other in others:
+                cat.train_supervised_raw(cast(Dict[str, Any], other), *args, **kwargs)
+            stats = get_stats(cat, cast(Dict[str, Any], cur_fold), do_print=False)
+            metrics.append(stats)
+    return metrics
+
+
+def _update_all_weighted_average(joined: List[Dict[str, Tuple[int, float]]],
+                single: List[Dict[str, float]], cui2count: Dict[str, int]) -> None:
+    if len(joined) != len(single):
+        raise ValueError(f"Incompatible lists. Joined {len(joined)} and single {len(single)}")
+    for j, s in zip(joined, single):
+        _update_one_weighted_average(j, s, cui2count)
+
+
+def _update_one_weighted_average(joined: Dict[str, Tuple[int, float]],
+                one: Dict[str, float],
+                cui2count: Dict[str, int]) -> None:
+    for k in one:
+        if k not in joined:
+            joined[k] = (0, 0)
+        prev_w, prev_val = joined[k]
+        new_w, new_val = cui2count[k], one[k]
+        total_w = prev_w + new_w
+        total_val = (prev_w * prev_val + new_w * new_val) / total_w
+        joined[k] = (total_w, total_val)
+
+
+def _update_all_add(joined: List[Dict[str, int]], single: List[Dict[str, int]]) -> None:
+    if len(joined) != len(single):
+        raise ValueError(f"Incompatible number of stuff: {len(joined)} vs {len(single)}")
+    for j, s in zip(joined, single):
+        for k, v in s.items():
+            j[k] = j.get(k, 0) + v
+
+
+def _merge_examples(all_examples: Dict, cur_examples: Dict) -> None:
+    for ex_type, ex_dict in cur_examples.items():
+        if ex_type not in all_examples:
+            all_examples[ex_type] = {}
+        per_type_examples = all_examples[ex_type]
+        for ex_cui, cui_examples_list in ex_dict.items():
+            if ex_cui not in per_type_examples:
+                per_type_examples[ex_cui] = []
+            per_type_examples[ex_cui].extend(cui_examples_list)
+
+
+def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]]
+                     ) -> Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]:
+    """Get the mean of the provided metrics.
+
+    Args:
+        metrics (List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]]): The metrics.
+
+    Returns:
+        fps (dict):
+            False positives for each CUI.
+        fns (dict):
+            False negatives for each CUI.
+        tps (dict):
+            True positives for each CUI.
+        cui_prec (dict):
+            Precision for each CUI.
+        cui_rec (dict):
+            Recall for each CUI.
+        cui_f1 (dict):
+            F1 for each CUI.
+        cui_counts (dict):
+            Number of occurrence for each CUI.
+        examples (dict):
+            Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][].
+    """
+    # additives
+    all_fps: Dict[str, int] = {}
+    all_fns: Dict[str, int] = {}
+    all_tps: Dict[str, int] = {}
+    # weighted-averages
+    all_cui_prec: Dict[str, Tuple[int, float]] = {}
+    all_cui_rec: Dict[str, Tuple[int, float]] = {}
+    all_cui_f1: Dict[str, Tuple[int, float]] = {}
+    # additive
+    all_cui_counts: Dict[str, int] = {}
+    # combined
+    all_additives = [
+        all_fps, all_fns, all_tps, all_cui_counts
+    ]
+    all_weighted_averages = [
+        all_cui_prec, all_cui_rec, all_cui_f1
+    ]
+    # examples
+    all_examples: dict = {}
+    for current in metrics:
+        cur_wa: list = list(current[3:-2])
+        cur_counts = current[-2]
+        _update_all_weighted_average(all_weighted_averages, cur_wa, cur_counts)
+        # update ones that just need to be added up
+        cur_adds = list(current[:3]) + [cur_counts]
+        _update_all_add(all_additives, cur_adds)
+        # merge examples
+        cur_examples = current[-1]
+        _merge_examples(all_examples, cur_examples)
+    cui_prec: Dict[str, float] = {}
+    cui_rec: Dict[str, float] = {}
+    cui_f1: Dict[str, float] = {}
+    final_wa = [
+        cui_prec, cui_rec, cui_f1
+    ]
+    # just remove the weight / count
+    for df, d in zip(final_wa, all_weighted_averages):
+        for k, v in d.items():
+            df[k] = v[1]  # only the value, ignore the weight
+    return (all_fps, all_fns, all_tps, final_wa[0], final_wa[1], final_wa[2],
+            all_cui_counts, all_examples)
+
+
+def get_k_fold_stats(cat: CATLike, mct_export_data: MedCATTrainerExport, k: int = 3,
+                     split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED, *args, **kwargs) -> Tuple:
+    """Get the k-fold stats for the model with the specified data.
+
+    First this will split the MCT export into `k` folds. You can do
+    this either per document or per-annotation.
+
+    For each of the `k` folds, it will start from the base model,
+    train it with the other `k-1` folds and record the metrics.
+    After that the base model state is restored before doing the next fold.
+    After all the folds have been done, the metrics are averaged.
+
+    Args:
+        cat (CATLike): The model pack.
+        mct_export_data (MedCATTrainerExport): The MCT export.
+        k (int): The number of folds. Defaults to 3.
+        split_type (SplitType): Whether to use annotations or docs. Defaults to DOCUMENTS_WEIGHTED.
+        *args: Arguments passed to the `CAT.train_supervised_raw` method.
+        **kwargs: Keyword arguments passed to the `CAT.train_supervised_raw` method.
+
+    Returns:
+        Tuple: The averaged metrics.
+    """
+    creator = get_fold_creator(mct_export_data, k, split_type=split_type)
+    folds = creator.create_folds()
+    per_fold_metrics = get_per_fold_metrics(cat, folds, *args, **kwargs)
+    return get_metrics_mean(per_fold_metrics)
diff --git a/medcat/stats/mctexport.py b/medcat/stats/mctexport.py
new file mode 100644
index 000000000..54f5a4443
--- /dev/null
+++ b/medcat/stats/mctexport.py
@@ -0,0 +1,66 @@
+from typing import List, Iterator, Tuple, Any, Optional
+from typing_extensions import TypedDict
+
+
+class MedCATTrainerExportAnnotation(TypedDict):
+    start: int
+    end: int
+    cui: str
+    value: str
+
+
+class MedCATTrainerExportDocument(TypedDict):
+    name: str
+    id: Any
+    last_modified: str
+    text: str
+    annotations: List[MedCATTrainerExportAnnotation]
+
+
+class MedCATTrainerExportProject(TypedDict):
+    name: str
+    id: Any
+    cuis: str
+    tuis: Optional[str]
+    documents: List[MedCATTrainerExportDocument]
+
+
+MedCATTrainerExportProjectInfo = Tuple[str, Any, str, Optional[str]]
+"""The project name, project ID, CUIs str, and TUIs str"""
+
+
+class MedCATTrainerExport(TypedDict):
+    projects: List[MedCATTrainerExportProject]
+
+
+def iter_projects(export: MedCATTrainerExport) -> Iterator[MedCATTrainerExportProject]:
+    yield from export['projects']
+
+
+def iter_docs(export: MedCATTrainerExport
+              ) -> Iterator[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument]]:
+    for project in iter_projects(export):
+        info: MedCATTrainerExportProjectInfo = (
+            project['name'], project['id'], project['cuis'], project.get('tuis', None)
+        )
+        for doc in project['documents']:
+            yield info, doc
+
+
+def iter_anns(export: MedCATTrainerExport
+              ) -> Iterator[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument, MedCATTrainerExportAnnotation]]:
+    for proj_info, doc in iter_docs(export):
+        for ann in doc['annotations']:
+            yield proj_info, doc, ann
+
+
+def count_all_annotations(export: MedCATTrainerExport) -> int:
+    return len(list(iter_anns(export)))
+
+
+def count_all_docs(export: MedCATTrainerExport) -> int:
+    return len(list(iter_docs(export)))
+
+
+def get_nr_of_annotations(doc: MedCATTrainerExportDocument) -> int:
+    return len(doc['annotations'])
diff --git a/medcat/stats/stats.py b/medcat/stats/stats.py
index 610d4d2a1..e467e0519 100644
--- a/medcat/stats/stats.py
+++ b/medcat/stats/stats.py
@@ -60,6 +60,9 @@ def process_project(self, project: dict) -> None:
         # Add extra filter if set
         set_project_filters(self.addl_info, self.filters, project, self.extra_cui_filter, self.use_project_filters)
 
+        project_name = cast(str, project.get('name'))
+        project_id = cast(str, project.get('id'))
+
         documents = project["documents"]
         for dind, doc in tqdm(
             enumerate(documents),
@@ -67,8 +70,7 @@ def process_project(self, project: dict) -> None:
             total=len(documents),
             leave=False,
         ):
-            self.process_document(cast(str, project.get('name')),
-                                  cast(str, project.get('id')), doc)
+            self.process_document(project_name, project_id, doc)
 
     def process_document(self, project_name: str, project_id: str, doc: dict) -> None:
         anns = self._get_doc_annotations(doc)
diff --git a/medcat/tokenizers/meta_cat_tokenizers.py b/medcat/tokenizers/meta_cat_tokenizers.py
index 7a4b07ac0..93d8b51ed 100644
--- a/medcat/tokenizers/meta_cat_tokenizers.py
+++ b/medcat/tokenizers/meta_cat_tokenizers.py
@@ -1,3 +1,4 @@
+import logging
 import os
 from abc import ABC, abstractmethod
 from typing import List, Dict, Optional, Union, overload
@@ -26,7 +27,7 @@ def save(self, dir_path: str) -> None: ...
 
     @classmethod
     @abstractmethod
-    def load(cls, dir_path: str, **kwargs) -> Tokenizer: ...
+    def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> Tokenizer: ...
 
     @abstractmethod
     def get_size(self) -> int: ...
@@ -112,7 +113,7 @@ def save(self, dir_path: str) -> None:
         self.hf_tokenizers.save_model(dir_path, prefix=self.name)
 
     @classmethod
-    def load(cls, dir_path: str, **kwargs) -> "TokenizerWrapperBPE":
+    def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "TokenizerWrapperBPE":
         tokenizer = cls()
         vocab_file = os.path.join(dir_path, f'{tokenizer.name}-vocab.json')
         merges_file = os.path.join(dir_path, f'{tokenizer.name}-merges.txt')
@@ -186,10 +187,14 @@ def save(self, dir_path: str) -> None:
         self.hf_tokenizers.save_pretrained(path)
 
     @classmethod
-    def load(cls, dir_path: str, **kwargs) -> "TokenizerWrapperBERT":
+    def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "TokenizerWrapperBERT":
         tokenizer = cls()
         path = os.path.join(dir_path, cls.name)
-        tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs)
+        try:
+            tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs)
+        except Exception as e:
+            logging.warning("Could not load tokenizer from path due to error: {}. Loading from library for model variant: {}".format(e,model_variant))
+            tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(model_variant)
 
         return tokenizer
 
diff --git a/medcat/utils/cdb_state.py b/medcat/utils/cdb_state.py
new file mode 100644
index 000000000..794a40109
--- /dev/null
+++ b/medcat/utils/cdb_state.py
@@ -0,0 +1,179 @@
+import logging
+import contextlib
+from typing import Dict, TypedDict, Set, List, cast
+import numpy as np
+import tempfile
+import dill
+
+from copy import deepcopy
+
+
+
+logger = logging.getLogger(__name__) # separate logger from the package-level one
+
+
+CDBState = TypedDict(
+    'CDBState',
+    {
+        'name2cuis': Dict[str, List[str]],
+        'snames': Set[str],
+        'cui2names': Dict[str, Set[str]],
+        'cui2snames': Dict[str, Set[str]],
+        'cui2context_vectors': Dict[str, Dict[str, np.ndarray]],
+        'cui2count_train': Dict[str, int],
+        'name_isupper': Dict,
+        'vocab': Dict[str, int],
+    })
+"""CDB State.
+
+This is a dictionary of the parts of the CDB that change during
+(supervised) training. It can be used to store and restore the
+state of a CDB after modifying it.
+
+Currently, the following fields are saved:
+ - name2cuis
+ - snames
+ - cui2names
+ - cui2snames
+ - cui2context_vectors
+ - cui2count_train
+ - name_isupper
+ - vocab
+"""
+
+
+def copy_cdb_state(cdb) -> CDBState:
+    """Creates a (deep) copy of the CDB state.
+
+    Grabs the fields that correspond to the state,
+    creates deep copies, and returns the copies.
+
+    Args:
+        cdb: The CDB from which to grab the state.
+
+    Returns:
+        CDBState: The copied state.
+    """
+    return cast(CDBState, {
+        k: deepcopy(getattr(cdb, k)) for k in CDBState.__annotations__
+    })
+
+
+def save_cdb_state(cdb, file_path: str) -> None:
+    """Saves CDB state in a file.
+
+    Currently uses `dill.dump` to save the relevant fields/values.
+
+    Args:
+        cdb: The CDB from which to grab the state.
+        file_path (str): The file to dump the state.
+    """
+    # NOTE: The difference is that we don't create a copy here.
+    #       That is so that we don't have to occupy the memory for
+    #       both copies
+    the_dict = {
+        k: getattr(cdb, k) for k in CDBState.__annotations__
+    }
+    logger.debug("Saving CDB state on disk at: '%s'", file_path)
+    with open(file_path, 'wb') as f:
+        dill.dump(the_dict, f)
+
+
+def apply_cdb_state(cdb, state: CDBState) -> None:
+    """Apply the specified state to the specified CDB.
+
+    This overwrites the current state of the CDB with one provided.
+
+    Args:
+        cdb: The CDB to apply the state to.
+        state (CDBState): The state to use.
+    """
+    for k, v in state.items():
+        setattr(cdb, k, v)
+
+
+def load_and_apply_cdb_state(cdb, file_path: str) -> None:
+    """Delete current CDB state and apply CDB state from file.
+
+    This first deletes the current state of the CDB.
+    This is to save memory. The idea is that saving the state
+    on disk will save on RAM usage. But it wouldn't really
+    work too well if upon load, two instances were still in
+    memory.
+
+    Args:
+        cdb: The CDB to apply the state to.
+        file_path (str): The file where the state has been saved to.
+    """
+    # clear existing data on CDB
+    # this is so that we don't occupy the memory for both the loaded
+    # and the on-CDB data
+    logger.debug("Clearing CDB state in memory")
+    for k in CDBState.__annotations__:
+        val = getattr(cdb, k)
+        setattr(cdb, k, None)
+        del val
+    logger.debug("Loading CDB state from disk from '%s'", file_path)
+    with open(file_path, 'rb') as f:
+        data = dill.load(f)
+    for k in CDBState.__annotations__:
+        setattr(cdb, k, data[k])
+
+
+@contextlib.contextmanager
+def captured_state_cdb(cdb, save_state_to_disk: bool = False):
+    """A context manager that captures and re-applies the initial CDB state.
+
+    The context manager captures/copies the initial state of the CDB when entering.
+    It then allows the user to modify the state (i.e training).
+    Upon exit re-applies the initial CDB state.
+
+    If RAM is an issue, it is recommended to use `save_state_to_disk`.
+    Otherwise the copy of the original state will be held in memory.
+    If saved on disk, a temporary file is used and removed afterwards.
+
+    Args:
+        cdb: The CDB to use.
+        save_state_to_disk (bool): Whether to save state on disk or hold it in memory.
+            Defaults to False.
+
+    Yields:
+        None
+    """
+    if save_state_to_disk:
+        with on_disk_memory_capture(cdb):
+            yield
+    else:
+        with in_memory_state_capture(cdb):
+            yield
+
+
+@contextlib.contextmanager
+def in_memory_state_capture(cdb):
+    """Capture the CDB state in memory.
+
+    Args:
+        cdb: The CDB to use.
+
+    Yields:
+        None
+    """
+    state = copy_cdb_state(cdb)
+    yield
+    apply_cdb_state(cdb, state)
+
+
+@contextlib.contextmanager
+def on_disk_memory_capture(cdb):
+    """Capture the CDB state in a temporary file.
+
+    Args:
+        cdb: The CDB to use
+
+    Yields:
+        None
+    """
+    with tempfile.NamedTemporaryFile() as tf:
+        save_cdb_state(cdb, tf.name)
+        yield
+        load_and_apply_cdb_state(cdb, tf.name)
diff --git a/medcat/utils/cdb_utils.py b/medcat/utils/cdb_utils.py
index c473ddba4..fefaf1273 100644
--- a/medcat/utils/cdb_utils.py
+++ b/medcat/utils/cdb_utils.py
@@ -63,7 +63,7 @@ def merge_cdb(cdb1: CDB,
                 ontologies.update(cdb2.addl_info['cui2ontologies'][cui])
             if 'cui2description' in cdb2.addl_info:
                 description = cdb2.addl_info['cui2description'][cui]
-        cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status,
+        cdb._add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status,
                         type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build)
         if cui in cdb1.cui2names:
             if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train): 
diff --git a/medcat/utils/config_utils.py b/medcat/utils/config_utils.py
index 1aafbf3f1..09989b258 100644
--- a/medcat/utils/config_utils.py
+++ b/medcat/utils/config_utils.py
@@ -1,15 +1,79 @@
 from functools import partial
-from typing import Callable
+from typing import Callable, Optional, Protocol
 import logging
+from pydantic import BaseModel
+
+
+class WAFCarrier(Protocol):
+
+    @property
+    def weighted_average_function(self) -> Callable[[float], int]:
+        pass
 
 
 logger = logging.getLogger(__name__)
 
 
+def is_old_type_config_dict(d: dict) -> bool:
+    """Checks if the dict provided is an old style (jsonpickle) config.
+
+    This checks for json-pickle specific keys such as py/object and py/state.
+    If both of those are keys somewhere within the 2 initial layers of the
+    nested dict, it's considered old style.
+
+    Args:
+        d (dict): Loaded config.
+
+    Returns:
+        bool: Whether it's an old style (jsonpickle) config.
+    """
+    # all 2nd level keys
+    all_keys = set(sub_key for key in d for sub_key in (d[key] if isinstance(d[key], dict) else [key]))
+    # add 1st level keys
+    all_keys.update(d.keys())
+    # is old if py/object and py/state somewhere in keys
+    return set(('py/object', 'py/state')) <= all_keys
+
+
+def fix_waf_lambda(carrier: WAFCarrier) -> None:
+    weighted_average_function = carrier.weighted_average_function  # type: ignore
+    if callable(weighted_average_function) and getattr(weighted_average_function, "__name__", None) == "<lambda>":
+        # the following type ignoring is for mypy because it is unable to detect the signature
+        carrier.weighted_average_function = partial(weighted_average, factor=0.0004) # type: ignore
+
+
+# NOTE: This method is a hacky workaround. The type ignores are because I cannot
+#       import config here since it would produce a circular import
+def ensure_backward_compatibility(config: BaseModel, workers: Callable[[], int]) -> None:
+    # Hacky way of supporting old CDBs
+    if hasattr(config.linking, 'weighted_average_function'):  # type: ignore
+        fix_waf_lambda(config.linking)  # type: ignore
+    if config.general.workers is None:  # type: ignore
+        config.general.workers = workers()  # type: ignore
+    disabled_comps = config.general.spacy_disabled_components  # type: ignore
+    if 'tagger' in disabled_comps and 'lemmatizer' not in disabled_comps:
+        config.general.spacy_disabled_components.append('lemmatizer')  # type: ignore
+
+
+def get_and_del_weighted_average_from_config(config: BaseModel) -> Optional[Callable[[int], float]]:
+    if not hasattr(config, 'linking'):
+        return None
+    linking = config.linking
+    if not hasattr(linking, 'weighted_average_function'):
+        return None
+    waf = linking.weighted_average_function
+    delattr(linking, 'weighted_average_function')
+    return waf
+
+
 def weighted_average(step: int, factor: float) -> float:
     return max(0.1, 1 - (step ** 2 * factor))
 
 
+def default_weighted_average(step: int) -> float:
+    return weighted_average(step, factor=0.0004)
+
+
 def attempt_fix_weighted_average_function(waf: Callable[[int], float]
                                           ) -> Callable[[int], float]:
     """Attempf fix weighted_average_function.
diff --git a/medcat/utils/decorators.py b/medcat/utils/decorators.py
index a98922360..ca473774b 100644
--- a/medcat/utils/decorators.py
+++ b/medcat/utils/decorators.py
@@ -1,14 +1,30 @@
 import warnings
 import functools
-from typing import Callable
+from typing import Callable, Tuple
 
 
-def deprecated(message: str) -> Callable:
+def _format_version(ver: Tuple[int, int, int]) -> str:
+    return ".".join(str(v) for v in ver)
+
+
+def deprecated(message: str, depr_version: Tuple[int, int, int], removal_version: Tuple[int, int, int]) -> Callable:
+    """Deprecate a method.
+
+    Args:
+        message (str): The deprecation message.
+        depr_version (Tuple[int, int, int]): The first version of MedCAT where this was deprecated.
+        removal_version (Tuple[int, int, int]): The first version of MedCAT where this will be removed.
+
+    Returns:
+        Callable: The decorator that wraps the function with deprecation warnings.
+    """
     def decorator(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapped(*args, **kwargs) -> Callable:
             warnings.simplefilter("always", DeprecationWarning)
             warnings.warn("Function {} has been deprecated.{}".format(func.__name__, " " + message if message else ""))
+            warnings.warn(f"The above function was deprecated in v{_format_version(depr_version)} "
+                          f"and will be removed in v{_format_version(removal_version)}")
             warnings.simplefilter("default", DeprecationWarning)
             return func(*args, **kwargs)
         return wrapped
diff --git a/medcat/utils/meta_cat/data_utils.py b/medcat/utils/meta_cat/data_utils.py
index 5d2060ca7..c4dc5f9c2 100644
--- a/medcat/utils/meta_cat/data_utils.py
+++ b/medcat/utils/meta_cat/data_utils.py
@@ -1,5 +1,8 @@
 from typing import Dict, Optional, Tuple, Iterable, List
 from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 def prepare_from_json(data: Dict,
@@ -23,8 +26,6 @@ def prepare_from_json(data: Dict,
             Size of context to get from the right of the concept
         tokenizer (TokenizerWrapperBase):
             Something to split text into tokens for the LSTM/BERT/whatever meta models.
-        cui_filter (Optional[set]):
-            CUI filter if set. Defaults to None.
         replace_center (Optional[str]):
             If not None the center word (concept) will be replaced with whatever this is.
         prerequisites (Dict):
@@ -33,6 +34,8 @@ def prepare_from_json(data: Dict,
                 {'Experiencer': 'Patient'} - Take care that the CASE has to match whatever is in the data. Defaults to `{}`.
         lowercase (bool):
             Should the text be lowercased before tokenization. Defaults to True.
+        cui_filter (Optional[set]):
+            CUI filter if set. Defaults to None.
 
     Returns:
         out_data (dict):
@@ -49,7 +52,8 @@ def prepare_from_json(data: Dict,
             if len(text) > 0:
                 doc_text = tokenizer(text)
 
-                for ann in document.get('annotations', document.get('entities', {}).values()): # A hack to suport entities and annotations
+                for ann in document.get('annotations', document.get('entities',
+                                                                    {}).values()):  # A hack to support entities and annotations
                     cui = ann['cui']
                     skip = False
                     if 'meta_anns' in ann and prerequisites:
@@ -61,21 +65,28 @@ def prepare_from_json(data: Dict,
                                 break
 
                     if not skip and (cui_filter is None or not cui_filter or cui in cui_filter):
-                        if ann.get('validated', True) and (not ann.get('deleted', False) and not ann.get('killed', False)
-                                                           and not ann.get('irrelevant', False)):
+                        if ann.get('validated', True) and (
+                                not ann.get('deleted', False) and not ann.get('killed', False)
+                                and not ann.get('irrelevant', False)):
                             start = ann['start']
                             end = ann['end']
 
-                            # Get the index of the center token
-                            ind = 0
+                            # Updated implementation to extract all the tokens for the medical entity (rather than the one)
+                            ctoken_idx = []
                             for ind, pair in enumerate(doc_text['offset_mapping']):
-                                if start >= pair[0] and start < pair[1]:
-                                    break
-
-                            _start = max(0, ind - cntx_left)
-                            _end = min(len(doc_text['input_ids']), ind + 1 + cntx_right)
+                                if start <= pair[0] or start <= pair[1]:
+                                    if end <= pair[1]:
+                                        ctoken_idx.append(ind)
+                                        break
+                                    else:
+                                        ctoken_idx.append(ind)
+
+                            _start = max(0, ctoken_idx[0] - cntx_left)
+                            _end = min(len(doc_text['input_ids']), ctoken_idx[-1] + 1 + cntx_right)
+
+                            cpos = cntx_left + min(0, ind - cntx_left)
+                            cpos_new = [x - _start for x in ctoken_idx]
                             tkns = doc_text['input_ids'][_start:_end]
-                            cpos = cntx_left + min(0, ind-cntx_left)
 
                             if replace_center is not None:
                                 if lowercase:
@@ -87,19 +98,19 @@ def prepare_from_json(data: Dict,
                                         e_ind = p_ind
 
                                 ln = e_ind - s_ind
-                                tkns = tkns[:cpos] + tokenizer(replace_center)['input_ids'] + tkns[cpos+ln+1:]
+                                tkns = tkns[:cpos] + tokenizer(replace_center)['input_ids'] + tkns[cpos + ln + 1:]
 
                             # Backward compatibility if meta_anns is a list vs dict in the new approach
                             meta_anns = []
                             if 'meta_anns' in ann:
-                                meta_anns = ann['meta_anns'].values() if type(ann['meta_anns']) is dict else ann['meta_anns']
+                                meta_anns = ann['meta_anns'].values() if isinstance(ann['meta_anns'],dict) else ann['meta_anns']
 
                             # If the annotation is validated
                             for meta_ann in meta_anns:
                                 name = meta_ann['name']
                                 value = meta_ann['value']
 
-                                sample = [tkns, cpos, value]
+                                sample = [tkns, cpos_new, value]
 
                                 if name in out_data:
                                     out_data[name].append(sample)
@@ -108,7 +119,41 @@ def prepare_from_json(data: Dict,
     return out_data
 
 
-def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None) -> Tuple:
+def prepare_for_oversampled_data(data: List,
+                                 tokenizer: TokenizerWrapperBase) -> List:
+    """Convert the data from a json format into a CSV-like format for training. This function is not very efficient (the one
+       working with spacy documents as part of the meta_cat.pipe method is much better). If your dataset is > 1M documents think
+       about rewriting this function - but would be strange to have more than 1M manually annotated documents.
+
+       Args:
+           data (List):
+               Oversampled data expected in the following format:
+               [[['text','of','the','document'], [index of medical entity], "label" ],
+                ['text','of','the','document'], [index of medical entity], "label" ]]
+           tokenizer (TokenizerWrapperBase):
+                Something to split text into tokens for the LSTM/BERT/whatever meta models.
+
+       Returns:
+            data_sampled (list):
+                The processed data in the format that can be merged with the output from prepare_from_json.
+                [[<[tokens]>, [index of medical entity], "label" ],
+                <[tokens]>, [index of medical entity], "label" ]]
+                """
+
+    data_sampled = []
+    for sample in data:
+        # Checking if the input is already tokenized
+        if isinstance(sample[0][0], str):
+            doc_text = tokenizer(sample[0])
+            data_sampled.append([doc_text[0]['input_ids'], sample[1], sample[2]])
+        else:
+            data_sampled.append([sample[0], sample[1], sample[2]])
+
+    return data_sampled
+
+
+def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None,
+                           category_undersample=None) -> Tuple:
     """Converts the category values in the data outputed by `prepare_from_json`
     into integere values.
 
@@ -117,10 +162,14 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
             Output of `prepare_from_json`.
         existing_category_value2id(Optional[Dict]):
             Map from category_value to id (old/existing).
+        category_undersample:
+            Name of class that should be used to undersample the data (for 2 phase learning)
 
     Returns:
         dict:
-            New data with integeres inplace of strings for categry values.
+            New undersampled data (for 2 phase learning) with integers in place of strings for category values
+        dict:
+            New data with integers in place of strings for category values.
         dict:
             Map rom category value to ID for all categories in the data.
     """
@@ -131,6 +180,23 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
         category_value2id = {}
 
     category_values = set([x[2] for x in data])
+    # Ensuring that each label has data and checking for class imbalance
+
+    label_data = {key: 0 for key in category_value2id}
+    for i in range(len(data)):
+        if data[i][2] in category_value2id:
+            label_data[data[i][2]] = label_data[data[i][2]] + 1
+
+    # If a label has no data, changing the mapping
+    if 0 in label_data.values():
+        category_value2id_: Dict = {}
+        keys_ls = [key for key, value in category_value2id.items() if value != 0]
+        for k in keys_ls:
+            category_value2id_[k] = len(category_value2id_)
+
+        logger.warning("Labels found with 0 data; updates made\nFinal label encoding mapping:", category_value2id_)
+        category_value2id = category_value2id_
+
     for c in category_values:
         if c not in category_value2id:
             category_value2id[c] = len(category_value2id)
@@ -139,30 +205,39 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
     for i in range(len(data)):
         data[i][2] = category_value2id[data[i][2]]
 
-    return data, category_value2id
+    # Creating dict with labels and its number of samples
+    label_data_ = {v: 0 for v in category_value2id.values()}
+    for i in range(len(data)):
+        if data[i][2] in category_value2id.values():
+            label_data_[data[i][2]] = label_data_[data[i][2]] + 1
+    # Undersampling data
+    if category_undersample is None or category_undersample == '':
+        min_label = min(label_data_.values())
 
+    else:
+        if category_undersample not in label_data_.keys() and category_undersample in category_value2id.keys():
+            min_label = label_data_[category_value2id[category_undersample]]
+        else:
+            min_label = label_data_[category_undersample]
 
-class Span(object):
-    def __init__(self, start_char: str, end_char: str, id_: str) -> None:
-        self._ = Empty()
-        self.start_char = start_char
-        self.end_char = end_char
-        self._.id = id_  # type: ignore
-        self._.meta_anns = None # type: ignore
+    data_undersampled = []
+    label_data_counter = {v: 0 for v in category_value2id.values()}
 
+    for sample in data:
+        if label_data_counter[sample[-1]] < min_label:
+            data_undersampled.append(sample)
+            label_data_counter[sample[-1]] += 1
 
-class Doc(object):
-    def __init__(self, text: str, id_: str) -> None:
-        self._ = Empty()
-        self._.share_tokens = None  # type: ignore
-        self.ents: List = []
-        # We do not have overlapps at this stage
-        self._ents = self.ents
-        self.text = text
-        self.id = id_
+    label_data = {v: 0 for v in category_value2id.values()}
+    for i in range(len(data_undersampled)):
+        if data_undersampled[i][2] in category_value2id.values():
+            label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1
+    logger.info(f"Updated label_data: {label_data}")
+
+    return data_undersampled, data, category_value2id
 
 
-def json_to_fake_spacy(data: Dict, id2text: Dict) -> Iterable[Doc]:
+def json_to_fake_spacy(data: Dict, id2text: Dict) -> Iterable:
     """Creates a generator of fake spacy documents, used for running
     meta_cat pipe separately from main cat pipeline.
 
@@ -173,7 +248,7 @@ def json_to_fake_spacy(data: Dict, id2text: Dict) -> Iterable[Doc]:
             Map from document id to text of that document.
 
     Yields:
-        Doc: spacy like documents that can be feed into meta_cat.pipe.
+        Generator: Generator of spacy-like documents that can be fed into meta_cat.pipe.
     """
     for id_ in data.keys():
         ents = data[id_]['entities'].values()
@@ -187,3 +262,23 @@ def json_to_fake_spacy(data: Dict, id2text: Dict) -> Iterable[Doc]:
 class Empty(object):
     def __init__(self) -> None:
         pass
+
+
+class Span(object):
+    def __init__(self, start_char: str, end_char: str, id_: str) -> None:
+        self._ = Empty()
+        self.start_char = start_char
+        self.end_char = end_char
+        self._.id = id_  # type: ignore
+        self._.meta_anns = None  # type: ignore
+
+
+class Doc(object):
+    def __init__(self, text: str, id_: str) -> None:
+        self._ = Empty()
+        self._.share_tokens = None  # type: ignore
+        self.ents: List = []
+        # We do not have overlapps at this stage
+        self._ents = self.ents
+        self.text = text
+        self.id = id_
diff --git a/medcat/utils/meta_cat/ml_utils.py b/medcat/utils/meta_cat/ml_utils.py
index 9f75fa69a..79cedb9f3 100644
--- a/medcat/utils/meta_cat/ml_utils.py
+++ b/medcat/utils/meta_cat/ml_utils.py
@@ -2,18 +2,22 @@
 import random
 import math
 import torch
+import torch.nn.functional as F
 import numpy as np
 import pandas as pd
 import torch.optim as optim
-from typing import List, Optional, Tuple, Any, Dict
+from typing import List, Optional, Tuple, Any, Dict, Union
 from torch import nn
 from scipy.special import softmax
 from medcat.config_meta_cat import ConfigMetaCAT
 from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
 from sklearn.metrics import classification_report, precision_recall_fscore_support, confusion_matrix
+from sklearn.model_selection import train_test_split
+from sklearn.utils.class_weight import compute_class_weight
+from transformers import AdamW, get_linear_schedule_with_warmup
 
-import logging
 
+import logging
 
 logger = logging.getLogger(__name__)
 
@@ -47,9 +51,13 @@ def create_batch_piped_data(data: List[Tuple[List[int], int, Optional[int]]],
             Same as data, but subsetted and as a tensor
         cpos ():
             Center positions for the data
+        attention_masks:
+            Attention masks marking padded positions in the data
+        y:
+            Class labels of the data
     """
     max_seq_len = max([len(x[0]) for x in data])
-    x = [x[0][0:max_seq_len] + [pad_id]*max(0, max_seq_len - len(x[0])) for x in data[start_ind:end_ind]]
+    x = [x[0][0:max_seq_len] + [pad_id] * max(0, max_seq_len - len(x[0])) for x in data[start_ind:end_ind]]
     cpos = [x[1] for x in data[start_ind:end_ind]]
     y = None
     if len(data[0]) == 3:
@@ -57,9 +65,9 @@ def create_batch_piped_data(data: List[Tuple[List[int], int, Optional[int]]],
         y = torch.tensor([x[2] for x in data[start_ind:end_ind]], dtype=torch.long).to(device)
 
     x = torch.tensor(x, dtype=torch.long).to(device)
-    cpos = torch.tensor(cpos, dtype=torch.long).to(device)
-
-    return x, cpos, y
+    # cpos = torch.tensor(cpos, dtype=torch.long).to(device)
+    attention_masks = (x != 0).type(torch.int)
+    return x, cpos, attention_masks, y
 
 
 def predict(model: nn.Module, data: List[Tuple[List[int], int, Optional[int]]],
@@ -94,8 +102,10 @@ def predict(model: nn.Module, data: List[Tuple[List[int], int, Optional[int]]],
 
     with torch.no_grad():
         for i in range(num_batches):
-            x, cpos, _ = create_batch_piped_data(data, i*batch_size, (i+1)*batch_size, device=device, pad_id=pad_id)
-            logits = model(x, cpos, ignore_cpos=ignore_cpos)
+            x, cpos, attention_masks, _ = create_batch_piped_data(data, i * batch_size, (i + 1) * batch_size,
+                                                                  device=device, pad_id=pad_id)
+
+            logits = model(x, center_positions=cpos, attention_mask=attention_masks, ignore_cpos=ignore_cpos)
             all_logits.append(logits.detach().cpu().numpy())
 
     predictions = []
@@ -111,7 +121,7 @@ def predict(model: nn.Module, data: List[Tuple[List[int], int, Optional[int]]],
 
 
 def split_list_train_test(data: List, test_size: float, shuffle: bool = True) -> Tuple:
-    """Shuffle and randomply split data
+    """Shuffle and randomly split data
 
     Args:
         data (List): The data.
@@ -124,9 +134,14 @@ def split_list_train_test(data: List, test_size: float, shuffle: bool = True) ->
     if shuffle:
         random.shuffle(data)
 
-    test_ind = int(len(data) * test_size)
-    test_data = data[:test_ind]
-    train_data = data[test_ind:]
+    X_features = [x[:-1] for x in data]
+    y_labels = [x[-1] for x in data]
+
+    X_train, X_test, y_train, y_test = train_test_split(X_features, y_labels, test_size=test_size,
+                                                        random_state=42)
+
+    train_data = [x + [y] for x, y in zip(X_train, y_train)]
+    test_data = [x + [y] for x, y in zip(X_test, y_test)]
 
     return train_data, test_data
 
@@ -142,12 +157,25 @@ def print_report(epoch: int, running_loss: List, all_logits: List, y: Any, name:
         name (str): The name of the report. Defaults to Train.
     """
     if all_logits:
-        logger.info('Epoch: %d %s %s', epoch, "*"*50, name)
+        logger.info('Epoch: %d %s %s', epoch, "*" * 50, name)
         logger.info(classification_report(y, np.argmax(np.concatenate(all_logits, axis=0), axis=1)))
 
 
+class FocalLoss(nn.Module):
+    def __init__(self, alpha=None, gamma=2):
+        super(FocalLoss, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+
+    def forward(self, inputs, targets):
+        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
+        pt = torch.exp(-ce_loss)
+        loss = (self.alpha[targets] * (1 - pt) ** self.gamma * ce_loss).mean()
+        return loss
+
+
 def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_path: Optional[str] = None) -> Dict:
-    """Trains a LSTM model (for now) with autocheckpoints
+    """Trains an LSTM or BERT model with autocheckpoints
 
     Args:
         model (nn.Module): The model
@@ -162,18 +190,82 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa
         Exception: If auto-save is enabled but no save dir path is provided.
     """
     # Get train/test from data
-    train_data, test_data = split_list_train_test(data, test_size=config.train['test_size'], shuffle=config.train['shuffle_data'])
-    device = torch.device(config.general['device']) # Create a torch device
+    train_data, test_data = split_list_train_test(data, test_size=config.train['test_size'],
+                                                  shuffle=config.train['shuffle_data'])
+    device = torch.device(config.general['device'])  # Create a torch device
 
     class_weights = config.train['class_weights']
-    if class_weights is not None:
-        class_weights = torch.FloatTensor(class_weights).to(device)
-        criterion = nn.CrossEntropyLoss(weight=class_weights) # Set the criterion to Cross Entropy Loss
+
+    if class_weights is None:
+        if config.train['compute_class_weights'] is True:
+            y_ = [x[2] for x in train_data]
+            class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_), y=y_)
+            config.train['class_weights'] = class_weights
+            logger.info(f"Class weights computed: {class_weights}")
+
+            class_weights = torch.FloatTensor(class_weights).to(device)
+            if config.train['loss_funct'] == 'cross_entropy':
+                criterion: Union[FocalLoss, nn.CrossEntropyLoss] = nn.CrossEntropyLoss(
+                    weight=class_weights)
+            elif config.train['loss_funct'] == 'focal_loss':
+                criterion = FocalLoss(alpha=class_weights, gamma=config.train['gamma'])
+
+        else:
+            logger.warning("Class weights not provided and compute_class_weights parameter is set to False. No class weights used for training.")
+            if config.train['loss_funct'] == 'cross_entropy':
+                criterion = nn.CrossEntropyLoss()
+            elif config.train['loss_funct'] == 'focal_loss':
+                criterion = FocalLoss(gamma=config.train['gamma'])
     else:
-        criterion = nn.CrossEntropyLoss() # Set the criterion to Cross Entropy Loss
+        class_weights = torch.FloatTensor(class_weights).to(device)
+        if config.train['loss_funct'] == 'cross_entropy':
+            criterion = nn.CrossEntropyLoss(
+                weight=class_weights)
+        elif config.train['loss_funct'] == 'focal_loss':
+            criterion = FocalLoss(alpha=class_weights, gamma=config.train['gamma'])
+
     parameters = filter(lambda p: p.requires_grad, model.parameters())
-    optimizer = optim.Adam(parameters, lr=config.train['lr'])
-    model.to(device) # Move the model to device
+
+    def initialize_model(classifier, data_, batch_size_, lr_, epochs=4):
+        """Initialize the Classifier, the optimizer and the learning rate scheduler.
+
+            Args:
+                classifier (nn.Module):
+                    The model to be trained
+                data_ (List):
+                    The data
+                batch_size_:
+                    Batch size
+                lr_:
+                    Learning rate for training
+                epochs:
+                    Number of training iterations
+
+            Returns:
+                classifier:
+                    model
+                optimizer_:
+                    optimizer
+                scheduler_:
+                    scheduler
+            """
+
+        # Create the optimizer
+        optimizer_ = AdamW(classifier.parameters(),
+                           lr=lr_,  # Default learning rate
+                           eps=1e-8,  # Default epsilon value
+                           weight_decay=1e-5
+                           )
+
+        # Total number of training steps
+        total_steps = int((len(data_) / batch_size_) * epochs)
+        logger.info('Total steps for optimizer: {}'.format(total_steps))
+
+        # Set up the learning rate scheduler
+        scheduler_ = get_linear_schedule_with_warmup(optimizer_,
+                                                     num_warmup_steps=0,  # Default value
+                                                     num_training_steps=total_steps)
+        return classifier, optimizer_, scheduler_
 
     batch_size = config.train['batch_size']
     batch_size_eval = config.general['batch_size_eval']
@@ -182,6 +274,13 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa
     ignore_cpos = config.model['ignore_cpos']
     num_batches = math.ceil(len(train_data) / batch_size)
     num_batches_test = math.ceil(len(test_data) / batch_size_eval)
+    optimizer = optim.Adam(parameters, lr=config.train['lr'], weight_decay=1e-5)
+    if config.model.model_architecture_config is not None:
+        if config.model.model_architecture_config['lr_scheduler'] is True:
+            model, optimizer, scheduler = initialize_model(model, train_data, batch_size, config.train['lr'],
+                                                           epochs=nepochs)
+
+    model.to(device)  # Move the model to device
 
     # Can be pre-calculated for the whole dataset
     y_test = [x[2] for x in test_data]
@@ -193,8 +292,11 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa
         all_logits = []
         model.train()
         for i in range(num_batches):
-            x, cpos, y = create_batch_piped_data(train_data, i*batch_size, (i+1)*batch_size, device=device, pad_id=pad_id)
-            logits = model(x, center_positions=cpos, ignore_cpos=ignore_cpos)
+            model.zero_grad()
+
+            x, cpos, attention_masks, y = create_batch_piped_data(train_data, i * batch_size, (i + 1) * batch_size,
+                                                                  device=device, pad_id=pad_id)
+            logits = model(x, attention_mask=attention_masks, center_positions=cpos, ignore_cpos=ignore_cpos)
             loss = criterion(logits, y)
             loss.backward()
             # Track loss and logits
@@ -202,8 +304,11 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa
             all_logits.append(logits.detach().cpu().numpy())
 
             parameters = filter(lambda p: p.requires_grad, model.parameters())
-            nn.utils.clip_grad_norm_(parameters, 0.25)
+            nn.utils.clip_grad_norm_(parameters, 0.15)
             optimizer.step()
+            if config.model.model_architecture_config is not None:
+                if config.model.model_architecture_config['lr_scheduler'] is True:
+                    scheduler.step()
 
         all_logits_test = []
         running_loss_test = []
@@ -211,8 +316,10 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa
 
         with torch.no_grad():
             for i in range(num_batches_test):
-                x, cpos, y = create_batch_piped_data(test_data, i*batch_size_eval, (i+1)*batch_size_eval, device=device, pad_id=pad_id)
-                logits = model(x, cpos, ignore_cpos=ignore_cpos)
+                x, cpos, attention_masks, y = create_batch_piped_data(test_data, i * batch_size_eval,
+                                                                      (i + 1) * batch_size_eval, device=device,
+                                                                      pad_id=pad_id)
+                logits = model(x, attention_mask=attention_masks, center_positions=cpos, ignore_cpos=ignore_cpos)
 
                 # Track loss and logits
                 running_loss_test.append(loss.item())
@@ -221,12 +328,20 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa
         print_report(epoch, running_loss, all_logits, y=y_train, name='Train')
         print_report(epoch, running_loss_test, all_logits_test, y=y_test, name='Test')
 
-        _report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), output_dict=True)
+        _report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1),
+                                        output_dict=True)
         if not winner_report or _report[config.train['metric']['base']][config.train['metric']['score']] > \
                 winner_report['report'][config.train['metric']['base']][config.train['metric']['score']]:
 
-            report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), output_dict=True)
+            report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1),
+                                           output_dict=True)
+            cm = confusion_matrix(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), normalize='true')
+            report_train = classification_report(y_train, np.argmax(np.concatenate(all_logits, axis=0), axis=1),
+                                                 output_dict=True)
+
+            winner_report['confusion_matrix'] = cm
             winner_report['report'] = report
+            winner_report['report_train'] = report_train
             winner_report['epoch'] = epoch
 
             # Save if needed
@@ -237,8 +352,11 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa
                 else:
                     path = os.path.join(save_dir_path, 'model.dat')
                     torch.save(model.state_dict(), path)
-                    logger.info("\n##### Model saved to %s at epoch: %d and %s/%s: %s #####\n", path, epoch, config.train['metric']['base'],
-                          config.train['metric']['score'], winner_report['report'][config.train['metric']['base']][config.train['metric']['score']])
+                    logger.info("\n##### Model saved to %s at epoch: %d and %s/%s: %s #####\n", path, epoch,
+                                config.train['metric']['base'],
+                                config.train['metric']['score'],
+                                winner_report['report'][config.train['metric']['base']][
+                                    config.train['metric']['score']])
 
     return winner_report
 
@@ -255,7 +373,7 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T
     Returns:
         Dict: Results (precision, recall, f1, examples, confusion matrix)
     """
-    device = torch.device(config.general['device']) # Create a torch device
+    device = torch.device(config.general['device'])  # Create a torch device
     batch_size_eval = config.general['batch_size_eval']
     pad_id = config.model['padding_idx']
     ignore_cpos = config.model['ignore_cpos']
@@ -263,9 +381,9 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T
 
     if class_weights is not None:
         class_weights = torch.FloatTensor(class_weights).to(device)
-        criterion = nn.CrossEntropyLoss(weight=class_weights) # Set the criterion to Cross Entropy Loss
+        criterion = nn.CrossEntropyLoss(weight=class_weights)  # Set the criterion to Cross Entropy Loss
     else:
-        criterion = nn.CrossEntropyLoss() # Set the criterion to Cross Entropy Loss
+        criterion = nn.CrossEntropyLoss()  # Set the criterion to Cross Entropy Loss
 
     y_eval = [x[2] for x in data]
     num_batches = math.ceil(len(data) / batch_size_eval)
@@ -276,8 +394,11 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T
 
     with torch.no_grad():
         for i in range(num_batches):
-            x, cpos, y = create_batch_piped_data(data, i*batch_size_eval, (i+1)*batch_size_eval, device=device, pad_id=pad_id)
-            logits = model(x, cpos, ignore_cpos=ignore_cpos)
+            x, cpos, attention_masks, y = create_batch_piped_data(data, i * batch_size_eval, (i + 1) * batch_size_eval,
+                                                                  device=device, pad_id=pad_id)
+
+            logits = model(x, center_positions=cpos, attention_mask=attention_masks, ignore_cpos=ignore_cpos)
+
             loss = criterion(logits, y)
 
             # Track loss and logits
@@ -290,24 +411,27 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T
     predictions = np.argmax(np.concatenate(all_logits, axis=0), axis=1)
     precision, recall, f1, support = precision_recall_fscore_support(y_eval, predictions, average=score_average)
 
-    labels = [name for (name, _) in sorted(config.general['category_value2id'].items(), key=lambda x:x[1])]
+    labels = [name for (name, _) in sorted(config.general['category_value2id'].items(), key=lambda x: x[1])]
     confusion = pd.DataFrame(
-        data=confusion_matrix(y_eval, predictions,),
+        data=confusion_matrix(y_eval, predictions, ),
         columns=["true " + label for label in labels],
         index=["predicted " + label for label in labels],
     )
 
-
     examples: Dict = {'FP': {}, 'FN': {}, 'TP': {}}
     id2category_value = {v: k for k, v in config.general['category_value2id'].items()}
     for i, p in enumerate(predictions):
         y = id2category_value[y_eval[i]]
         p = id2category_value[p]
         c = data[i][1]
+        if isinstance(c,list):
+            c = c[-1]
+
         tkns = data[i][0]
         assert tokenizer.hf_tokenizers is not None
-        text = tokenizer.hf_tokenizers.decode(tkns[0:c]) + " <<"+ tokenizer.hf_tokenizers.decode(tkns[c:c+1]).strip() + ">> " + \
-            tokenizer.hf_tokenizers.decode(tkns[c+1:])
+        text = tokenizer.hf_tokenizers.decode(tkns[0:c]) + " <<" + tokenizer.hf_tokenizers.decode(
+            tkns[c:c + 1]).strip() + ">> " + \
+               tokenizer.hf_tokenizers.decode(tkns[c + 1:])
         info = "Predicted: {}, True: {}".format(p, y)
         if p != y:
             # We made a mistake
diff --git a/medcat/utils/meta_cat/models.py b/medcat/utils/meta_cat/models.py
index c28a2c6ed..70e235316 100644
--- a/medcat/utils/meta_cat/models.py
+++ b/medcat/utils/meta_cat/models.py
@@ -1,11 +1,11 @@
 import torch
 from collections import OrderedDict
-from typing import Optional, Any, List
+from typing import Optional, Any, List, Iterable
 from torch import nn, Tensor
-from torch.nn import CrossEntropyLoss
-from transformers import BertPreTrainedModel, BertModel, BertConfig
-from transformers.modeling_outputs import TokenClassifierOutput
+from transformers import BertModel, AutoConfig
 from medcat.meta_cat import ConfigMetaCAT
+import logging
+logger = logging.getLogger(__name__)
 
 
 class LSTM(nn.Module):
@@ -24,7 +24,7 @@ def __init__(self, embeddings: Optional[Tensor], config: ConfigMetaCAT) -> None:
             # Disable training for the embeddings - IMPORTANT
             self.embeddings.weight.requires_grad = config.model['emb_grad']
 
-        # Create the RNN cell - devide 
+        # Create the RNN cell - divide
         self.rnn = nn.LSTM(input_size=config.model['input_size'],
                            hidden_size=config.model['hidden_size'] // config.model['num_directions'],
                            num_layers=config.model['num_layers'],
@@ -47,10 +47,11 @@ def forward(self,
             mask = attention_mask
 
         # Embed the input: from id -> vec
-        x = self.embeddings(x) # x.shape = batch_size x sequence_length x emb_size
+        x = self.embeddings(x)  # x.shape = batch_size x sequence_length x emb_size
 
         # Tell RNN to ignore padding and set the batch_first to True
-        x = nn.utils.rnn.pack_padded_sequence(x, mask.sum(1).int().view(-1).cpu(), batch_first=True, enforce_sorted=False)
+        x = nn.utils.rnn.pack_padded_sequence(x, mask.sum(1).int().view(-1).cpu(), batch_first=True,
+                                              enforce_sorted=False)
 
         # Run 'x' through the RNN
         x, hidden = self.rnn(x)
@@ -59,16 +60,22 @@ def forward(self,
         x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
 
         # Get what we need
-        row_indices = torch.arange(0, x.size(0)).long()
+        # row_indices = torch.arange(0, x.size(0)).long()
 
         # If this is  True we will always take the last state and not CPOS
         if ignore_cpos:
             x = hidden[0]
             x = x.view(self.config.model['num_layers'], self.config.model['num_directions'], -1,
-                       self.config.model['hidden_size']//self.config.model['num_directions'])
+                       self.config.model['hidden_size'] // self.config.model['num_directions'])
             x = x[-1, :, :, :].permute(1, 2, 0).reshape(-1, self.config.model['hidden_size'])
         else:
-            x = x[row_indices, center_positions, :]
+            x_all = []
+            for i, indices in enumerate(center_positions):
+                this_hidden = x[i, indices, :]
+                to_append, _ = torch.max(this_hidden, dim=0)
+                x_all.append(to_append)
+
+            x = torch.stack(x_all)
 
         # Push x through the fc network and add dropout
         x = self.d1(x)
@@ -77,34 +84,61 @@ def forward(self,
         return x
 
 
-class BertForMetaAnnotation(BertPreTrainedModel):
-
+class BertForMetaAnnotation(nn.Module):
     _keys_to_ignore_on_load_unexpected: List[str] = [r"pooler"]  # type: ignore
 
-    def __init__(self, config: BertConfig) -> None:
-        super().__init__(config)
-        self.num_labels = config.num_labels
-
-        self.bert = BertModel(config, add_pooling_layer=False)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+    def __init__(self, config):
+        super(BertForMetaAnnotation, self).__init__()
+        _bertconfig = AutoConfig.from_pretrained(config.model.model_variant,num_hidden_layers=config.model['num_layers'])
+        if config.model['input_size'] != _bertconfig.hidden_size:
+            logger.warning(f"\nInput size for {config.model.model_variant} model should be {_bertconfig.hidden_size}, provided input size is {config.model['input_size']} Input size changed to {_bertconfig.hidden_size}")
 
-        self.init_weights() # type: ignore
+        bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig)
+        self.config = config
+        self.config.use_return_dict = False
+        self.bert = bert
+        self.num_labels = config.model["nclasses"]
+        for param in self.bert.parameters():
+            param.requires_grad = not config.model.model_freeze_layers
+
+        hidden_size_2 = int(config.model.hidden_size / 2)
+        # dropout layer
+        self.dropout = nn.Dropout(config.model.dropout)
+        # relu activation function
+        self.relu = nn.ReLU()
+        # dense layer 1
+        self.fc1 = nn.Linear(_bertconfig.hidden_size*2, config.model.hidden_size)
+        # dense layer 2
+        self.fc2 = nn.Linear(config.model.hidden_size, hidden_size_2)
+        # dense layer 3
+        self.fc3 = nn.Linear(hidden_size_2, hidden_size_2)
+        # dense layer 3 (Output layer)
+        model_arch_config = config.model.model_architecture_config
+        if model_arch_config is not None:
+            if model_arch_config['fc2'] is True or model_arch_config['fc3'] is True:
+                self.fc4 = nn.Linear(hidden_size_2, self.num_labels)
+            else:
+                self.fc4 = nn.Linear(config.model.hidden_size, self.num_labels)
+        else:
+            self.fc4 = nn.Linear(hidden_size_2, self.num_labels)
+        # softmax activation function
+        self.softmax = nn.LogSoftmax(dim=1)
 
     def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        token_type_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        center_positions: Optional[Any] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> TokenClassifierOutput:
+            self,
+            input_ids: Optional[torch.LongTensor] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            center_positions: Iterable[Any] = [],
+            ignore_cpos: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None
+    ):
         """labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
             1]``.
@@ -119,47 +153,59 @@ def forward(
             labels (Optional[torch.LongTensor]): Labels. Defaults to None.
             center_positions (Optional[Any]): Cennter positions. Defaults to None.
             output_attentions (Optional[bool]): Output attentions. Defaults to None.
+            ignore_cpos: If center positions are to be ignored.
             output_hidden_states (Optional[bool]): Output hidden states. Defaults to None.
             return_dict (Optional[bool]): Whether to return a dict. Defaults to None.
 
         Returns:
             TokenClassifierOutput: The token classifier output.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict # type: ignore
+        # return_dict = return_dict if return_dict is not None else self.config.use_return_dict # type: ignore
 
-        outputs = self.bert( # type: ignore
+        outputs = self.bert(  # type: ignore
             input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            attention_mask=attention_mask, output_hidden_states=True
         )
 
-        sequence_output = outputs[0] # (batch_size, sequence_length, hidden_size)
-
-        row_indices = torch.arange(0, sequence_output.size(0)).long()
-        sequence_output = sequence_output[row_indices, center_positions, :]
+        x_all = []
+        for i, indices in enumerate(center_positions):
+            this_hidden: torch.Tensor = outputs.last_hidden_state[i, indices, :]
+            to_append, _ = torch.max(this_hidden, dim=0)
+            x_all.append(to_append)
 
-        sequence_output = self.dropout(sequence_output)
-        logits = self.classifier(sequence_output)
+        x = torch.stack(x_all)
 
-        loss = None
-        if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            # Only keep active parts of the loss
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+        pooled_output = outputs[1]
+        x = torch.cat((x, pooled_output), dim=1)
 
-        if not return_dict:
-            output = (logits,) + outputs[2:]
-            return ((loss,) + output) if loss is not None else output
-
-        return TokenClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+        # fc1
+        x = self.dropout(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+
+        if self.config.model.model_architecture_config is not None:
+            if self.config.model.model_architecture_config['fc2'] is True:
+                # fc2
+                x = self.fc2(x)
+                x = self.relu(x)
+                x = self.dropout(x)
+
+            if self.config.model.model_architecture_config['fc3'] is True:
+                # fc3
+                x = self.fc3(x)
+                x = self.relu(x)
+                x = self.dropout(x)
+        else:
+            # fc2
+            x = self.fc2(x)
+            x = self.relu(x)
+            x = self.dropout(x)
+
+            # fc3
+            x = self.fc3(x)
+            x = self.relu(x)
+            x = self.dropout(x)
+
+        # output layer
+        x = self.fc4(x)
+        return x
diff --git a/medcat/utils/ner/__init__.py b/medcat/utils/ner/__init__.py
index 2657c7df7..5d296dc3a 100644
--- a/medcat/utils/ner/__init__.py
+++ b/medcat/utils/ner/__init__.py
@@ -1,2 +1,2 @@
 from .metrics import metrics
-from .helpers import deid_text, make_or_update_cdb
+from .helpers import make_or_update_cdb
diff --git a/medcat/utils/ner/deid.py b/medcat/utils/ner/deid.py
index d71b52004..688bb1ea6 100644
--- a/medcat/utils/ner/deid.py
+++ b/medcat/utils/ner/deid.py
@@ -40,7 +40,7 @@
 from medcat.cat import CAT
 from medcat.utils.ner.model import NerModel
 
-from medcat.utils.ner.helpers import _deid_text as deid_text, replace_entities_in_text
+from medcat.utils.ner.helpers import replace_entities_in_text
 
 
 logger = logging.getLogger(__name__)
@@ -69,6 +69,12 @@ def train(self, json_path: Union[str, list, None],
     def deid_text(self, text: str, redact: bool = False) -> str:
         """Deidentify text and potentially redact information.
 
+        De-identified text.
+        If redaction is enabled, identifiable entities will be
+        replaced with stars (e.g `*****`).
+        Otherwise, the replacement will be the CUI or in other words,
+        the type of information that was hidden (e.g [PATIENT]).
+
         Args:
             text (str): The text to deidentify.
             redact (bool): Whether to redact the information.
@@ -76,8 +82,8 @@ def deid_text(self, text: str, redact: bool = False) -> str:
         Returns:
             str: The deidentified text.
         """
-        self.cat.get_entities
-        return deid_text(self.cat, text, redact=redact)
+        entities = self.cat.get_entities(text)['entities']
+        return replace_entities_in_text(text, entities, self.cat.cdb.get_name, redact=redact)
 
     def deid_multi_texts(self,
                          texts: Union[Iterable[str], Iterable[Tuple]],
diff --git a/medcat/utils/ner/helpers.py b/medcat/utils/ner/helpers.py
index 65ae7050c..bea1e45ca 100644
--- a/medcat/utils/ner/helpers.py
+++ b/medcat/utils/ner/helpers.py
@@ -3,35 +3,6 @@
 from medcat.utils.data_utils import count_annotations
 from medcat.cdb import CDB
 
-from medcat.utils.decorators import deprecated
-
-
-# For now, we will keep this method separate from the above class
-# This is so that we wouldn't need to create a thorwaway object
-# when calling the method from .helpers where it used to be.
-# After the deprecated method in .helpers is removed, we can
-# move this to a proper class method.
-def _deid_text(cat, text: str, redact: bool = False) -> str:
-    """De-identify text.
-
-    De-identified text.
-    If redaction is enabled, identifiable entities will be
-    replaced with starts (e.g `*****`).
-    Otherwise, the replacement will be the CUI or in other words,
-    the type of information that was hidden (e.g [PATIENT]).
-
-
-    Args:
-        cat (CAT): The CAT object to use for deid.
-        text (str): The input document.
-        redact (bool): Whether to redact. Defaults to False.
-
-    Returns:
-        str: The de-identified document.
-    """
-    entities = cat.get_entities(text)['entities']
-    return replace_entities_in_text(text, entities, cat.cdb.get_name, redact=redact)
-
 
 def replace_entities_in_text(text: str,
                              entities: Dict,
@@ -45,13 +16,6 @@ def replace_entities_in_text(text: str,
     return new_text
 
 
-@deprecated("API now allows creating a DeId model (medcat.utils.ner.deid.DeIdModel). "
-            "It aims to simplify the usage of DeId models. "
-            "The use of this model is encouraged over the use of this method.")
-def deid_text(*args, **kwargs) -> str:
-    return _deid_text(*args, **kwargs)
-
-
 def make_or_update_cdb(json_path: str, cdb: Optional[CDB] = None,
                        min_count: int = 0) -> CDB:
     """Creates a new CDB or updates an existing one with new
diff --git a/medcat/utils/saving/coding.py b/medcat/utils/saving/coding.py
index 81a8420aa..89f9c0651 100644
--- a/medcat/utils/saving/coding.py
+++ b/medcat/utils/saving/coding.py
@@ -1,6 +1,7 @@
 from typing import Any, Protocol, runtime_checkable, List, Union, Type, Optional, Callable
 
 import json
+import re
 
 
 @runtime_checkable
@@ -35,6 +36,7 @@ def try_encode(self, obj: object) -> Any:
 
 
 SET_IDENTIFIER = '==SET=='
+PATTERN_IDENTIFIER = "==PATTERN=="
 
 
 class SetEncoder(PartEncoder):
@@ -79,10 +81,34 @@ def try_decode(self, dct: dict) -> Union[dict, set]:
         return dct
 
 
+class PatternEncoder(PartEncoder):
+
+    def try_encode(self, obj):
+        if isinstance(obj, re.Pattern):
+            return {PATTERN_IDENTIFIER: obj.pattern}
+        raise UnsuitableObject()
+
+
+class PatternDecoder(PartDecoder):
+
+    def try_decode(self, dct: dict) -> Union[dict, re.Pattern]:
+        """Decode re.Pattern from input dicts.
+
+        Args:
+            dct (dict): The input dict
+
+        Returns:
+            Union[dict, re.Pattern]: The original dict if this was not a serialized pattern, the pattern otherwise
+        """
+        if PATTERN_IDENTIFIER in dct:
+            return re.compile(dct[PATTERN_IDENTIFIER])
+        return dct
+
+
 PostProcessor = Callable[[Any], None]  # CDB -> None
 
-DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, ]
-DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, ]
+DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, PatternEncoder]
+DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, PatternDecoder]
 LOADING_POSTPROCESSORS: List[PostProcessor] = []
 
 
@@ -133,6 +159,8 @@ def object_hook(self, dct: dict) -> Any:
     def def_inst(cls) -> 'CustomDelegatingDecoder':
         if cls._def_inst is None:
             cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS])
+        elif len(cls._def_inst._delegates) < len(DEFAULT_DECODERS):
+            cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS])
         return cls._def_inst
 
 
diff --git a/medcat/utils/saving/envsnapshot.py b/medcat/utils/saving/envsnapshot.py
new file mode 100644
index 000000000..262c48410
--- /dev/null
+++ b/medcat/utils/saving/envsnapshot.py
@@ -0,0 +1,73 @@
+from typing import List, Dict, Any, Set
+
+import os
+import re
+import pkg_resources
+import platform
+
+
+ENV_SNAPSHOT_FILE_NAME = "environment_snapshot.json"
+
+INSTALL_REQUIRES_FILE_PATH = os.path.join(os.path.dirname(__file__),
+                                          "..", "..", "..",
+                                          "install_requires.txt")
+# NOTE: The install_requires.txt file is copied into the wheel during build
+#       so that it can be included in the distributed package.
+#       However, that means it's 1 folder closer to this file since it'll now
+#       be in the root of the package rather than the root of the project.
+INSTALL_REQUIRES_FILE_PATH_PIP = os.path.join(os.path.dirname(__file__),
+                                              "..", "..",
+                                              "install_requires.txt")
+
+
+def get_direct_dependencies() -> Set[str]:
+    """Get the set of direct dependency names.
+
+    The current implementation reads install_requires.txt for dependencies,
+    removes comments, whitespace, quotes; removes the versions and returns
+    the names as a set.
+
+    Returns:
+        Set[str]: The set of direct dependency names.
+    """
+    req_file = INSTALL_REQUIRES_FILE_PATH
+    if not os.path.exists(req_file):
+        # When pip-installed. See note above near constant definition
+        req_file = INSTALL_REQUIRES_FILE_PATH_PIP
+    with open(req_file) as f:
+        # read every line, strip quotes and comments
+        dep_lines = [line.split("#")[0].replace("'", "").replace('"', "").strip() for line in f.readlines()]
+        # remove comment-only (or empty) lines
+        deps = [dep for dep in dep_lines if dep]
+    return set(re.split("[@<=>~]", dep)[0].strip() for dep in deps)
+
+
+def get_installed_packages() -> List[List[str]]:
+    """Get the installed packages and their versions.
+
+    Returns:
+        List[List[str]]: List of lists. Each item consists of a dependency name and version.
+    """
+    direct_deps = get_direct_dependencies()
+    installed_packages = []
+    for package in pkg_resources.working_set:
+        if package.project_name not in direct_deps:
+            continue
+        installed_packages.append([package.project_name, package.version])
+    return installed_packages
+
+
+def get_environment_info() -> Dict[str, Any]:
+    """Get the current environment information.
+
+    This includes dependency versions, the OS, the CPU architecture and the python version.
+
+    Returns:
+        Dict[str, Any]: The environment information (dependencies, OS, CPU architecture, python version).
+    """
+    return {
+        "dependencies": get_installed_packages(),
+        "os": platform.platform(),
+        "cpu_architecture": platform.machine(),
+        "python_version": platform.python_version()
+    }
diff --git a/setup.py b/setup.py
index cfb824727..549e7c091 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,18 @@
 import setuptools
+import shutil
 
 with open("./README.md", "r") as fh:
     long_description = fh.read()
 
+# make a copy of install requirements so that it gets distributed with the wheel
+shutil.copy('install_requires.txt', 'medcat/install_requires.txt')
+
+with open("install_requires.txt") as f:
+    # read every line, strip quotes and comments
+    dep_lines = [l.split("#")[0].replace("'", "").replace('"', "").strip() for l in f.readlines()]
+    # remove comment-only (or empty) lines
+    install_requires = [dep for dep in dep_lines if dep]
+
 
 setuptools.setup(
     name="medcat",
@@ -17,31 +27,9 @@
     packages=['medcat', 'medcat.utils', 'medcat.preprocessing', 'medcat.ner', 'medcat.linking', 'medcat.datasets',
               'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner', 'medcat.utils.relation_extraction',
               'medcat.utils.saving', 'medcat.utils.regression', 'medcat.stats'],
-    install_requires=[
-        'numpy>=1.22.0,<1.26.0',  # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy
-        'pandas>=1.4.2', # first to support 3.11
-        'gensim>=4.3.0,<5.0.0',  # 5.3.0 is first to support 3.11; avoid major version bump
-        'spacy>=3.6.0,<4.0.0',  # Some later model packs (e.g HPO) are made with 3.6.0 spacy model; avoid major version bump
-        'scipy~=1.9.2',  # 1.9.2 is first to support 3.11
-        'transformers>=4.34.0,<5.0.0',  # avoid major version bump
-        'accelerate>=0.23.0', # required by Trainer class in de-id
-        'torch>=1.13.0,<3.0.0', # 1.13 is first to support 3.11; 2.1.2 has been compatible, but avoid major 3.0.0 for now
-        'tqdm>=4.27',
-        'scikit-learn>=1.1.3,<2.0.0',  # 1.1.3 is first to supporrt 3.11; avoid major version bump
-        'dill>=0.3.6,<1.0.0', # stuff saved in 0.3.6/0.3.7 is not always compatible with 0.3.4/0.3.5; avoid major bump
-        'datasets>=2.2.2,<3.0.0', # avoid major bump
-        'jsonpickle>=2.0.0', # allow later versions, tested with 3.0.0
-        'psutil>=5.8.0',
-        # 0.70.12 uses older version of dill (i.e less than 0.3.5) which is required for datasets
-        'multiprocess~=0.70.12',  # 0.70.14 seemed to work just fine
-        'aiofiles>=0.8.0', # allow later versions, tested with 22.1.0
-        'ipywidgets>=7.6.5', # allow later versions, tested with 0.8.0
-        'xxhash>=3.0.0', # allow later versions, tested with 3.1.0
-        'blis>=0.7.5', # allow later versions, tested with 0.7.9
-        'click>=8.0.4', # allow later versions, tested with 8.1.3
-        'pydantic>=1.10.0,<2.0', # for spacy compatibility; avoid 2.0 due to breaking changes
-        "humanfriendly~=10.0",  # for human readable file / RAM sizes
-        ],
+    install_requires=install_requires,
+    include_package_data=True,
+    package_data={"medcat": ["install_requires.txt"]},
     classifiers=[
         "Programming Language :: Python :: 3",
         "Programming Language :: Python :: 3.8",
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29bb..7fbc9f3b2 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,25 @@
+from typing import Callable, Tuple
+
+from medcat.utils import decorators
+
+
+class DeprecatedMethodCallException(ValueError):
+
+    def __init__(self, func: Callable, msg: str,
+                 depr_version: Tuple[int, int, int],
+                 removal_version: Tuple[int, int, int]) -> None:
+        super().__init__(f"A deprecated method {func.__name__} was called. Deprecation message:\n{msg}\n"
+                         f"The method was deprecated in v{depr_version} and is scheduled for "
+                         f"removal in v{removal_version}")
+
+
+def deprecation_exception_raiser(message: str, depr_version: Tuple[int, int, int],
+                     removal_version: Tuple[int, int, int]):
+    def decorator(func: Callable) -> Callable:
+        def wrapper(*_, **__):
+            raise DeprecatedMethodCallException(func, message, depr_version, removal_version)
+        return wrapper
+    return decorator
+
+
+decorators.deprecated = deprecation_exception_raiser
diff --git a/tests/archive_tests/test_cdb_maker_archive.py b/tests/archive_tests/test_cdb_maker_archive.py
deleted file mode 100644
index 9e2fc2d72..000000000
--- a/tests/archive_tests/test_cdb_maker_archive.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import logging
-import unittest
-import numpy as np
-from medcat.cdb import CDB
-from medcat.cdb_maker import CDBMaker
-from medcat.config import Config
-from medcat.preprocessing.cleaners import prepare_name
-
-
-class CdbMakerArchiveTests(unittest.TestCase):
-
-    def setUp(self):
-        self.config = Config()
-        self.config.general['log_level'] = logging.DEBUG
-        self.maker = CDBMaker(self.config)
-
-        # Building a new CDB from two files (full_build)
-        csvs = ['../examples/cdb.csv', '../examples/cdb_2.csv']
-        self.cdb = self.maker.prepare_csvs(csvs, full_build=True)
-
-    def test_prepare_csvs(self):
-        assert len(self.cdb.cui2names) == 3
-        assert len(self.cdb.cui2snames) == 3
-        assert len(self.cdb.name2cuis) == 5
-        assert len(self.cdb.cui2tags) == 3
-        assert len(self.cdb.cui2preferred_name) == 2
-        assert len(self.cdb.cui2context_vectors) == 3
-        assert len(self.cdb.cui2count_train) == 3
-        assert self.cdb.name2cuis2status['virus']['C0000039'] == 'P'
-        assert self.cdb.cui2type_ids['C0000039'] == {'T234', 'T109', 'T123'}
-        assert self.cdb.addl_info['cui2original_names']['C0000039'] == {'Virus', 'Virus K', 'Virus M', 'Virus Z'}
-        assert self.cdb.addl_info['cui2description']['C0000039'].startswith("Synthetic")
-
-    def test_name_addition(self):
-        self.cdb.add_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.get_spacy_nlp(), {}, self.config), name_status='P', full_build=True)
-        assert self.cdb.addl_info['cui2original_names']['C0000239'] == {'MY: new,-_! Name.', 'Second csv'}
-        assert 'my:newname.' in self.cdb.name2cuis
-        assert 'my:new' in self.cdb.snames
-        assert 'my:newname.' in self.cdb.name2cuis2status
-        assert self.cdb.name2cuis2status['my:newname.'] == {'C0000239': 'P'}
-
-    def test_name_removal(self):
-        self.cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.get_spacy_nlp(), {}, self.config))
-        # Run again to make sure it does not break anything
-        self.cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.get_spacy_nlp(), {}, self.config))
-        assert len(self.cdb.name2cuis) == 5
-        assert 'my:newname.' not in self.cdb.name2cuis2status
-
-    def test_filtering(self):
-        cuis_to_keep = {'C0000039'} # Because of transition 2 will be kept
-        self.cdb.filter_by_cui(cuis_to_keep=cuis_to_keep)
-        assert len(self.cdb.cui2names) == 2
-        assert len(self.cdb.name2cuis) == 4
-        assert len(self.cdb.snames) == 4
-
-    def test_vector_addition(self):
-        self.cdb.reset_training()
-        np.random.seed(11)
-        cuis = list(self.cdb.cui2names.keys())
-        for i in range(2):
-            for cui in cuis:
-                vectors = {}
-                for cntx_type in self.config.linking['context_vector_sizes']:
-                    vectors[cntx_type] = np.random.rand(300)
-                self.cdb.update_context_vector(cui, vectors, negative=False)
-
-        assert self.cdb.cui2count_train['C0000139'] == 2
-        assert self.cdb.cui2context_vectors['C0000139']['long'].shape[0] == 300
-
-
-    def test_negative(self):
-        cuis = list(self.cdb.cui2names.keys())
-        for cui in cuis:
-            vectors = {}
-            for cntx_type in self.config.linking['context_vector_sizes']:
-                vectors[cntx_type] = np.random.rand(300)
-            self.cdb.update_context_vector(cui, vectors, negative=True)
-
-        assert self.cdb.cui2count_train['C0000139'] == 2
-        assert self.cdb.cui2context_vectors['C0000139']['long'].shape[0] == 300
-
-    def test_save_and_load(self):
-        self.cdb.save("./tmp_cdb.dat")
-        cdb2 = CDB.load('./tmp_cdb.dat')
-        # Check a random thing
-        assert cdb2.cui2context_vectors['C0000139']['long'][7] == self.cdb.cui2context_vectors['C0000139']['long'][7]
-
-    def test_training_import(self):
-        cdb2 = CDB.load('./tmp_cdb.dat')
-        self.cdb.reset_training()
-        cdb2.reset_training()
-        np.random.seed(11)
-        cuis = list(self.cdb.cui2names.keys())
-        for i in range(2):
-            for cui in cuis:
-                vectors = {}
-                for cntx_type in self.config.linking['context_vector_sizes']:
-                    vectors[cntx_type] = np.random.rand(300)
-                self.cdb.update_context_vector(cui, vectors, negative=False)
-
-        cdb2.import_training(cdb=self.cdb, overwrite=True)
-        assert cdb2.cui2context_vectors['C0000139']['long'][7] == self.cdb.cui2context_vectors['C0000139']['long'][7]
-        assert cdb2.cui2count_train['C0000139'] == self.cdb.cui2count_train['C0000139']
-
-    def test_concept_similarity(self):
-        cdb = CDB(config=self.config)
-        np.random.seed(11)
-        for i in range(500):
-            cui = "C" + str(i)
-            type_ids = {'T-' + str(i%10)}
-            cdb._add_concept(cui=cui, names=prepare_name('Name: ' + str(i), self.maker.pipe.get_spacy_nlp(), {}, self.config), ontologies=set(),
-                            name_status='P', type_ids=type_ids, description='', full_build=True)
-
-            vectors = {}
-            for cntx_type in self.config.linking['context_vector_sizes']:
-                vectors[cntx_type] = np.random.rand(300)
-            cdb.update_context_vector(cui, vectors, negative=False)
-        res = cdb.most_similar('C200', 'long', type_id_filter=['T-0'], min_cnt=1, topn=10, force_build=True)
-        assert len(res) == 10
-
-    def test_training_reset(self):
-        self.cdb.reset_training()
-        assert len(self.cdb.cui2context_vectors['C0']) == 0
-        assert self.cdb.cui2count_train['C0'] == 0
diff --git a/tests/archive_tests/test_ner_archive.py b/tests/archive_tests/test_ner_archive.py
deleted file mode 100644
index d41ccd0c7..000000000
--- a/tests/archive_tests/test_ner_archive.py
+++ /dev/null
@@ -1,139 +0,0 @@
-import logging
-import unittest
-import numpy as np
-from timeit import default_timer as timer
-from medcat.cdb import CDB
-from medcat.preprocessing.tokenizers import spacy_split_all
-from medcat.ner.vocab_based_ner import NER
-from medcat.preprocessing.taggers import tag_skip_and_punct
-from medcat.pipe import Pipe
-from medcat.utils.normalizers import BasicSpellChecker
-from medcat.vocab import Vocab
-from medcat.preprocessing.cleaners import prepare_name
-from medcat.linking.vector_context_model import ContextModel
-from medcat.linking.context_based_linker import Linker
-from medcat.config import Config
-
-from ..helper import VocabDownloader
-
-
-class NerArchiveTests(unittest.TestCase):
-
-    def setUp(self) -> None:
-        self.config = Config()
-        self.config.general['log_level'] = logging.INFO
-        cdb = CDB(config=self.config)
-
-        self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config)
-        self.nlp.add_tagger(tagger=tag_skip_and_punct,
-                       name='skip_and_punct',
-                       additional_fields=['is_punct'])
-
-        # Add a couple of names
-        cdb.add_names(cui='S-229004', names=prepare_name('Movar', self.nlp, {}, self.config))
-        cdb.add_names(cui='S-229004', names=prepare_name('Movar viruses', self.nlp, {}, self.config))
-        cdb.add_names(cui='S-229005', names=prepare_name('CDB', self.nlp, {}, self.config))
-        # Check
-        #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}}
-
-        downloader = VocabDownloader()
-        self.vocab_path = downloader.vocab_path
-        downloader.check_or_download()
-
-        vocab = Vocab.load(self.vocab_path)
-        # Make the pipeline
-        self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config)
-        self.nlp.add_tagger(tagger=tag_skip_and_punct,
-                       name='skip_and_punct',
-                       additional_fields=['is_punct'])
-        spell_checker = BasicSpellChecker(cdb_vocab=cdb.vocab, config=self.config, data_vocab=vocab)
-        self.nlp.add_token_normalizer(spell_checker=spell_checker, config=self.config)
-        ner = NER(cdb, self.config)
-        self.nlp.add_ner(ner)
-
-        # Add Linker
-        link = Linker(cdb, vocab, self.config)
-        self.nlp.add_linker(link)
-
-        self.text = "CDB - I was running and then Movar    Virus attacked and CDb"
-
-    def tearDown(self) -> None:
-        self.nlp.destroy()
-
-    def test_limits_for_tokens_and_uppercase(self):
-        self.config.ner['max_skip_tokens'] = 1
-        self.config.ner['upper_case_limit_len'] = 4
-        self.config.linking['disamb_length_limit'] = 2
-
-        d = self.nlp(self.text)
-
-        assert len(d._.ents) == 2
-        assert d._.ents[0]._.link_candidates[0] == 'S-229004'
-
-    def test_change_limit_for_skip(self):
-        self.config.ner['max_skip_tokens'] = 3
-        d = self.nlp(self.text)
-        assert len(d._.ents) == 3
-
-    def test_change_limit_for_upper_case(self):
-        self.config.ner['upper_case_limit_len'] = 3
-        d = self.nlp(self.text)
-        assert len(d._.ents) == 4
-
-    def test_check_name_length_limit(self):
-        self.config.ner['min_name_len'] = 4
-        d = self.nlp(self.text)
-        assert len(d._.ents) == 2
-
-    def test_speed(self):
-        text = "CDB - I was running and then Movar    Virus attacked and CDb"
-        text = text * 300
-        self.config.general['spell_check'] = True
-        start = timer()
-        for i in range(50):
-            d = self.nlp(text)
-        end = timer()
-        print("Time: ", end - start)
-
-    def test_without_spell_check(self):
-        # Now without spell check
-        self.config.general['spell_check'] = False
-        start = timer()
-        for i in range(50):
-            d = self.nlp(self.text)
-        end = timer()
-        print("Time: ", end - start)
-
-
-    def test_for_linker(self):
-        self.config = Config()
-        self.config.general['log_level'] = logging.DEBUG
-        cdb = CDB(config=self.config)
-
-        # Add a couple of names
-        cdb.add_names(cui='S-229004', names=prepare_name('Movar', self.nlp, {}, self.config))
-        cdb.add_names(cui='S-229004', names=prepare_name('Movar viruses', self.nlp, {}, self.config))
-        cdb.add_names(cui='S-229005', names=prepare_name('CDB', self.nlp, {}, self.config))
-        cdb.add_names(cui='S-2290045', names=prepare_name('Movar', self.nlp, {}, self.config))
-        # Check
-        #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}, 'S-2290045': {'movar'}}
-
-        cuis = list(cdb.cui2names.keys())
-        for cui in cuis[0:50]:
-            vectors = {'short': np.random.rand(300),
-                      'long': np.random.rand(300),
-                      'medium': np.random.rand(300)
-                      }
-            cdb.update_context_vector(cui, vectors, negative=False)
-
-        d = self.nlp(self.text)
-        vocab = Vocab.load(self.vocab_path)
-        cm = ContextModel(cdb, vocab, self.config)
-        cm.train_using_negative_sampling('S-229004')
-        self.config.linking['train_count_threshold'] = 0
-
-        cm.train('S-229004', d._.ents[1], d)
-
-        cm.similarity('S-229004', d._.ents[1], d)
-
-        cm.disambiguate(['S-2290045', 'S-229004'], d._.ents[1], 'movar', d)
diff --git a/tests/check_deprecations.py b/tests/check_deprecations.py
new file mode 100644
index 000000000..4d10fba97
--- /dev/null
+++ b/tests/check_deprecations.py
@@ -0,0 +1,178 @@
+from typing import List, Dict, Optional, Tuple, Callable
+import ast
+import os
+from sys import argv as sys_argv
+from sys import exit as sys_exit
+from medcat.utils.decorators import deprecated
+
+
+def get_decorator_args(decorator: ast.expr, decorator_name: str) -> Tuple[Optional[List[str]], Optional[Dict[str, str]]]:
+    if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name) and decorator.func.id == decorator_name:
+        return decorator.args, {kw.arg: kw.value for kw in decorator.keywords}
+    return None, None
+
+
+def is_decorated_with(node: ast.FunctionDef, decorator_name: str) -> Tuple[bool, List[str], Dict[str, str]]:
+    for decorator in node.decorator_list:
+        if isinstance(decorator, ast.Name) and decorator.id == decorator_name:
+            return True, [], {}
+        elif isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name) and decorator.func.id == decorator_name:
+            args, kwargs = get_decorator_args(decorator, decorator_name)
+            return True, args, kwargs
+    return False, [], {}
+
+
+class FunctionVisitor(ast.NodeVisitor):
+    def __init__(self, decorator_name: str):
+        self.decorator_name = decorator_name
+        self.decorated_functions: List[Dict[str, Optional[List[str]]]] = []
+        self.context: List[str] = []
+
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
+        self.context.append(node.name)
+        is_decorated, args, kwargs = is_decorated_with(node, self.decorator_name)
+        if is_decorated:
+            self.decorated_functions.append({
+                'name': '.'.join(self.context),
+                'args': args,
+                'kwargs': kwargs
+            })
+        self.generic_visit(node)
+        self.context.pop()
+
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
+        self.visit_FunctionDef(node)
+
+    def visit_ClassDef(self, node: ast.ClassDef) -> None:
+        self.context.append(node.name)
+        self.generic_visit(node)
+        self.context.pop()
+
+
+def find_decorated_functions_in_file(filepath: str, decorator_name: str) -> List[Dict[str, Optional[List[str]]]]:
+    with open(filepath, "r") as source:
+        tree = ast.parse(source.read())
+
+    visitor = FunctionVisitor(decorator_name)
+    visitor.visit(tree)
+    return visitor.decorated_functions
+
+
+def find_decorated_functions_in_codebase(codebase_path: str, decorator_name: str) -> Dict[str, List[Dict[str, Optional[List[str]]]]]:
+    decorated_functions: Dict[str, List[Dict[str, Optional[List[str]]]]] = {}
+    for root, _, files in os.walk(codebase_path):
+        for file in files:
+            if file.endswith(".py"):
+                filepath = os.path.join(root, file)
+                decorated_funcs = find_decorated_functions_in_file(filepath, decorator_name)
+                if decorated_funcs:
+                    decorated_functions[filepath] = decorated_funcs
+    return decorated_functions
+
+
+def extract_version_from_tuple(tuple_node: ast.Tuple) -> Tuple[int, int, int]:
+    """Extract constant values from an ast.Tuple node.
+
+    Args:
+        tuple_node (ast.Tuple): The AST node representing the tuple.
+
+    Raises:
+        ValueError: If the tuple contains unsuitable values.
+
+    Returns:
+        Tuple[int, int, int]: The major, minor, and patch version.
+    """
+    values = []
+    for element in tuple_node.elts:
+        if isinstance(element, ast.Constant):
+            cur_value = element.value
+        else:
+            raise ValueError(f"Unsupported element type in tuple: {type(element)}")
+        values.append(cur_value)
+        if not isinstance(cur_value, int):
+            raise ValueError(f"Unknown type of value in version tuple: {type(cur_value)}: {cur_value}")
+    if len(values) != 3:
+        raise ValueError(f"Unexpected number of version elements ({len(values)}): {values}")
+    return tuple(values)
+
+
+def get_deprecated_methods_that_should_have_been_removed(codebase_path: str,
+                                                         decorator_name: str,
+                                                         medcat_version: Tuple[int, int, int]
+                                                         ) -> List[Tuple[str, str, Tuple[int, int, int]]]:
+    """Get deprecated methods that should have been removed.
+
+    Args:
+        codebase_path (str): Path to codebase.
+        decorator_name (str): Name of decorator to check.
+        medcat_version (Tuple[int, int, int]): The current MedCAT version.
+
+    Returns:
+        List[Tuple[str, str, Tuple[int, int, int]]]:
+            The list of file, method, and version in which the method should have been removed.
+    """
+    decorated_functions = find_decorated_functions_in_codebase(codebase_path, decorator_name)
+
+    should_be_removed = []
+    for filepath, funcs in decorated_functions.items():
+        for func in funcs:
+            func_name = func['name']
+            args, kwargs = func['args'], func['kwargs']
+            if 'removal_version' in kwargs:
+                rem_ver = kwargs['removal_version']
+            else:
+                rem_ver = args[-1]
+            rem_ver = extract_version_from_tuple(rem_ver)
+            if rem_ver <= medcat_version:
+                should_be_removed.append((filepath, func_name, rem_ver))
+    return should_be_removed
+
+
+def _ver2str(ver: Tuple[int, int, int]) -> str:
+    maj, min, patch = ver
+    return f"v{maj}.{min}.{patch}"
+
+
+def main(args: List[str] = sys_argv[1:],
+         deprecated_decorator: Callable[[], Callable] = deprecated):
+    decorator_name = deprecated_decorator.__name__
+    pos_args = [arg for arg in args if not arg.startswith("-")]
+    codebase_path = 'medcat' if len(pos_args) <= 1 else pos_args[1]
+    # NOTE: args[0] is expected to be the version string (e.g. "1.12.0" or "v1.12.0")
+    remove_ver_prefix = '--remove-prefix' in args
+    pure_ver = pos_args[0]
+    if remove_ver_prefix:
+        # remove v from (e.g) v1.12.0
+        pure_ver = pure_ver[1:]
+    medcat_version = tuple(int(s) for s in pure_ver.split("."))
+    compare_next_minor_release = '--next-version' in args
+
+    # pad out medcat versions
+    # NOTE: Mostly so that e.g (1, 12, 0) <= (1, 12, 0) would be True.
+    #       Otherwise (1, 12, 0) <= (1, 12) would equate to False.
+    if len(medcat_version) < 3:
+        medcat_version = tuple(list(medcat_version) + [0,] * (3 - len(medcat_version)))
+    # NOTE: In main GHA workflow we know the current minor release
+    #       but after that release has been done, we (generally, but not always!)
+    #       want to start removing deprecated methods due to be removed before
+    #       the next minor release.
+    if compare_next_minor_release:
+        l_ver = list(medcat_version)
+        l_ver[1] += 1
+        medcat_version = tuple(l_ver)
+
+    to_remove = get_deprecated_methods_that_should_have_been_removed(codebase_path, decorator_name, medcat_version)
+
+    ver_descr = "next" if compare_next_minor_release else "current"
+    for filepath, func_name, rem_ver in to_remove:
+        print("SHOULD ALREADY BE REMOVED")
+        print(f"In file: {filepath}")
+        print(f" Method: {func_name}")
+        print(f" Scheduled for removal in: {_ver2str(rem_ver)} ({ver_descr} version: {_ver2str(medcat_version)})")
+    if to_remove:
+        print("Found issues - see above")
+        sys_exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/medmentions/make_cdb.py b/tests/medmentions/make_cdb.py
deleted file mode 100644
index feb8629d2..000000000
--- a/tests/medmentions/make_cdb.py
+++ /dev/null
@@ -1,120 +0,0 @@
-from medcat.cdb_maker import CDBMaker
-from medcat.config import Config, weighted_average
-from functools import partial
-import numpy as np
-import logging
-
-from ..helper import VocabDownloader
-
-
-config = Config()
-config.general['log_level'] = logging.INFO
-config.general['spacy_model'] = 'en_core_sci_lg'
-maker = CDBMaker(config)
-
-# Building a new CDB from two files (full_build)
-csvs = ['./tmp_medmentions.csv']
-cdb = maker.prepare_csvs(csvs, full_build=True)
-
-cdb.save("./tmp_cdb.dat")
-
-
-from medcat.vocab import Vocab
-from medcat.cdb import CDB
-from medcat.cat import CAT
-
-downloader = VocabDownloader()
-vocab_path = downloader.vocab_path
-downloader.check_or_download()
-
-config = Config()
-cdb = CDB.load("./tmp_cdb.dat", config=config)
-vocab = Vocab.load(vocab_path)
-
-cdb.reset_training()
-
-cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)
-cat.config.ner['min_name_len'] = 3
-cat.config.ner['upper_case_limit_len'] = 3
-cat.config.linking['disamb_length_limit'] = 3
-cat.config.linking['filters'] = {'cuis': set()}
-cat.config.linking['train_count_threshold'] = -1
-cat.config.linking['context_vector_sizes'] = {'xlong': 27, 'long': 18, 'medium': 9, 'short': 3}
-cat.config.linking['context_vector_weights'] = {'xlong': 0, 'long': 0.4, 'medium': 0.4, 'short': 0.2}
-cat.config.linking['weighted_average_function'] = partial(weighted_average, factor=0.0004)
-cat.config.linking['similarity_threshold_type'] = 'dynamic'
-cat.config.linking['similarity_threshold'] = 0.35
-cat.config.linking['calculate_dynamic_threshold'] = True
-
-cat.train(df.text.values, fine_tune=True)
-
-
-cdb.config.general['spacy_disabled_components'] = ['ner', 'parser', 'vectors', 'textcat',
-                                                      'entity_linker', 'sentencizer', 'entity_ruler', 'merge_noun_chunks',
-                                                                                                    'merge_entities', 'merge_subtokens']
-
-%load_ext autoreload
-%autoreload 2
-
-# Train
-_ = cat.train(open("./tmp_medmentions_text_only.txt", 'r'), fine_tune=False)
-
-_ = cat.train_supervised("/home/ubuntu/data/medmentions/medmentions.json", reset_cui_count=True, nepochs=13, train_from_false_positives=True, print_stats=3, test_size=0.1)
-cdb.save("/home/ubuntu/data/umls/2020ab/cdb_trained_medmen.dat")
-
-
-_ = cat.train_supervised("/home/ubuntu/data/medmentions/medmentions.json", reset_cui_count=False, nepochs=13, train_from_false_positives=True, print_stats=3, test_size=0)
-
-cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)
-cat.config.linking['similarity_threshold'] = 0.1
-cat.config.ner['min_name_len'] = 2
-cat.config.ner['upper_case_limit_len'] = 1
-cat.config.linking['train_count_threshold'] = -2
-cat.config.linking['filters']['cuis'] = set()
-cat.config.linking['context_vector_sizes'] = {'xlong': 27, 'long': 18, 'medium': 9, 'short': 3}
-cat.config.linking['context_vector_weights'] = {'xlong': 0.1, 'long': 0.4, 'medium': 0.4, 'short': 0.1}
-cat.config.linking['similarity_threshold_type'] = 'static'
-
-cat.config.linking['similarity_threshold_type'] = 'dynamic'
-cat.config.linking['similarity_threshold'] = 0.35
-cat.config.linking['calculate_dynamic_threshold'] = True
-
-
-# Print some stats
-_ = cat._print_stats(data)
-
-#Epoch: 0, Prec: 0.4331506351144245, Rec: 0.5207520064957372, F1: 0.47292889758643175
-#p: 0.421  r: 0.507  f1: 0.460
-
-
-# Remove all names that are numbers
-for name in list(cdb.name2cuis.keys()):
-    if name.replace(".", '').replace("~", '').replace(",", '').replace(":", '').replace("-", '').isnumeric():
-        del cdb.name2cuis[name]
-        print(name)
-
-
-for name in list(cdb.name2cuis.keys()):
-    if len(name) < 7 and (not name.isalpha()) and len(re.sub("[^A-Za-z]*", '', name)) < 2:
-        del cdb.name2cuis[name]
-        print(name)
-
-
-
-
-# RUN SUPER
-cdb = CDB.load("./tmp_cdb.dat")
-vocab = Vocab.load(vocab_path)
-cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)
-
-
-# Train supervised
-cdb.reset_cui_count()
-cat.config.ner['uppe_case_limit_len'] = 1
-cat.config.ner['min_name_len'] = 1
-data_path = "./tmp_medmentions.json"
-_ = cat.train_supervised(data_path, use_cui_doc_limit=True, nepochs=30, devalue_others=True, test_size=0.2)
-
-
-cdb = maker.prepare_csvs(csv_paths=csvs)
-cdb.save("/home/ubuntu/data/umls/2020ab/cdb_vbg.dat")
diff --git a/tests/medmentions/prepare_data.py b/tests/medmentions/prepare_data.py
deleted file mode 100644
index 6e1bfdf2e..000000000
--- a/tests/medmentions/prepare_data.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from medcat.utils.medmentions import original2concept_csv
-from medcat.utils.medmentions import original2json
-from medcat.utils.medmentions import original2pure_text
-
-_ = original2json("../../examples/medmentions/medmentions.txt", '../../examples/medmentions/tmp_medmentions.json')
-_ = original2concept_csv("../../examples/medmentions/medmentions.txt", '../../examples/medmentions/tmp_medmentions.csv')
-original2pure_text("../../examples/medmentions/medmentions.txt", '../../examples/medmentions/tmp_medmentions_text_only.txt')
diff --git a/tests/resources/jsonpickle_config.json b/tests/resources/jsonpickle_config.json
new file mode 100644
index 000000000..784f933ce
--- /dev/null
+++ b/tests/resources/jsonpickle_config.json
@@ -0,0 +1,274 @@
+{
+    "version": {
+      "py/object": "medcat.config.VersionInfo",
+      "py/state": {
+        "__dict__": {
+          "history": ["0c0de303b6dc0020"],
+          "meta_cats": {},
+          "cdb_info": {},
+          "performance": {
+            "ner": {},
+            "meta": {}
+          },
+          "description": "No description",
+          "id": null,
+          "last_modified": null,
+          "location": null,
+          "ontology": null,
+          "medcat_version": null
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "cdb_maker": {
+      "py/object": "medcat.config.CDBMaker",
+      "py/state": {
+        "__dict__": {
+          "name_versions": [
+            "LOWER",
+            "CLEAN"
+          ],
+          "multi_separator": "|",
+          "remove_parenthesis": 5,
+          "min_letters_required": 2
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "annotation_output": {
+      "py/object": "medcat.config.AnnotationOutput",
+      "py/state": {
+        "__dict__": {
+          "doc_extended_info": false,
+          "context_left": -1,
+          "context_right": -1,
+          "lowercase_context": true,
+          "include_text_in_output": false
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "general": {
+      "py/object": "medcat.config.General",
+      "py/state": {
+        "__dict__": {
+          "spacy_disabled_components": [
+            "ner",
+            "parser",
+            "vectors",
+            "textcat",
+            "entity_linker",
+            "sentencizer",
+            "entity_ruler",
+            "merge_noun_chunks",
+            "merge_entities",
+            "merge_subtokens"
+          ],
+          "checkpoint": {
+            "py/object": "medcat.config.CheckPoint",
+            "py/state": {
+              "__dict__": {
+                "output_dir": "checkpoints",
+                "steps": null,
+                "max_to_keep": 1
+              },
+              "__fields_set__": {
+                "py/set": []
+              },
+              "__private_attribute_values__": {}
+            }
+          },
+          "log_level": 20,
+          "log_format": "%(levelname)s:%(name)s: %(message)s",
+          "log_path": "./medcat.log",
+          "spacy_model": "en_core_web_lg",
+          "separator": "~",
+          "spell_check": true,
+          "diacritics": false,
+          "spell_check_deep": false,
+          "spell_check_len_limit": 7,
+          "show_nested_entities": false,
+          "full_unlink": false,
+          "workers": 7,
+          "make_pretty_labels": null,
+          "map_cui_to_group": false
+        },
+        "__fields_set__": {
+          "py/set": [
+            "spacy_model"
+          ]
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "preprocessing": {
+      "py/object": "medcat.config.Preprocessing",
+      "py/state": {
+        "__dict__": {
+          "words_to_skip": {
+            "py/set": [
+              "nos"
+            ]
+          },
+          "keep_punct": {
+            "py/set": [
+              ".",
+              ":"
+            ]
+          },
+          "do_not_normalize": {
+            "py/set": [
+              "VBD",
+              "VBP",
+              "VBN",
+              "JJR",
+              "JJS",
+              "VBG"
+            ]
+          },
+          "skip_stopwords": false,
+          "min_len_normalize": 5,
+          "stopwords": {
+            "py/set": [
+              "three",
+              "two",
+              "one"
+            ]
+          },
+          "max_document_length": 1000000
+        },
+        "__fields_set__": {
+          "py/set": [
+            "stopwords"
+          ]
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "ner": {
+      "py/object": "medcat.config.Ner",
+      "py/state": {
+        "__dict__": {
+          "min_name_len": 3,
+          "max_skip_tokens": 2,
+          "check_upper_case_names": false,
+          "upper_case_limit_len": 4,
+          "try_reverse_word_order": false
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "linking": {
+      "py/object": "medcat.config.Linking",
+      "py/state": {
+        "__dict__": {
+          "optim": {
+            "type": "linear",
+            "base_lr": 1,
+            "min_lr": 0.00005
+          },
+          "context_vector_sizes": {
+            "xlong": 27,
+            "long": 18,
+            "medium": 9,
+            "short": 3
+          },
+          "context_vector_weights": {
+            "xlong": 0.1,
+            "long": 0.4,
+            "medium": 0.4,
+            "short": 0.1
+          },
+          "filters": {
+            "py/object": "medcat.config.LinkingFilters",
+            "py/state": {
+              "__dict__": {
+                "cuis": {
+                  "py/set": []
+                },
+                "cuis_exclude": {
+                  "py/set": []
+                }
+              },
+              "__fields_set__": {
+                "py/set": []
+              },
+              "__private_attribute_values__": {}
+            }
+          },
+          "train": true,
+          "random_replacement_unsupervised": 0.8,
+          "disamb_length_limit": 3,
+          "filter_before_disamb": false,
+          "train_count_threshold": 1,
+          "always_calculate_similarity": false,
+          "weighted_average_function": {
+            "py/object": "medcat.config._DefPartial",
+            "fun": {
+              "py/reduce": [
+                {
+                  "py/type": "functools.partial"
+                },
+                {
+                  "py/tuple": [
+                    {
+                      "py/function": "medcat.utils.config_utils.weighted_average"
+                    }
+                  ]
+                },
+                {
+                  "py/tuple": [
+                    {
+                      "py/function": "medcat.utils.config_utils.weighted_average"
+                    },
+                    {
+                      "py/tuple": []
+                    },
+                    {
+                      "factor": 0.0004
+                    },
+                    {}
+                  ]
+                }
+              ]
+            }
+          },
+          "calculate_dynamic_threshold": false,
+          "similarity_threshold_type": "static",
+          "similarity_threshold": 0.25,
+          "negative_probability": 0.5,
+          "negative_ignore_punct_and_num": true,
+          "prefer_primary_name": 0.35,
+          "prefer_frequent_concepts": 0.35,
+          "subsample_after": 30000,
+          "devalue_linked_concepts": false,
+          "context_ignore_center_tokens": false
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "word_skipper": {
+      "py/object": "re.Pattern",
+      "pattern": "^(nos)$"
+    },
+    "punct_checker": {
+      "py/object": "re.Pattern",
+      "pattern": "[^a-z0-9]+"
+    },
+    "hash": null
+  }
\ No newline at end of file
diff --git a/tests/resources/jsonpickle_meta_cat_config.json b/tests/resources/jsonpickle_meta_cat_config.json
new file mode 100644
index 000000000..4da001c6c
--- /dev/null
+++ b/tests/resources/jsonpickle_meta_cat_config.json
@@ -0,0 +1,89 @@
+{
+    "general": {
+      "py/object": "medcat.config_meta_cat.General",
+      "py/state": {
+        "__dict__": {
+          "device": "cpu",
+          "disable_component_lock": false,
+          "seed": -100,
+          "description": "No description",
+          "category_name": null,
+          "category_value2id": {},
+          "vocab_size": null,
+          "lowercase": true,
+          "cntx_left": 15,
+          "cntx_right": 10,
+          "replace_center": null,
+          "batch_size_eval": 5000,
+          "annotate_overlapping": false,
+          "tokenizer_name": "bbpe",
+          "save_and_reuse_tokens": false,
+          "pipe_batch_size_in_chars": 20000000,
+          "span_group": null
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "model": {
+      "py/object": "medcat.config_meta_cat.Model",
+      "py/state": {
+        "__dict__": {
+          "model_name": "lstm",
+          "model_variant": "bert-base-uncased",
+          "model_freeze_layers": true,
+          "num_layers": 2,
+          "input_size": 300,
+          "hidden_size": 300,
+          "dropout": 0.5,
+          "phase_number": 0,
+          "category_undersample": "",
+          "model_architecture_config": {
+            "fc2": true,
+            "fc3": false,
+            "lr_scheduler": true
+          },
+          "num_directions": 2,
+          "nclasses": 2,
+          "padding_idx": -1,
+          "emb_grad": true,
+          "ignore_cpos": false
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "train": {
+      "py/object": "medcat.config_meta_cat.Train",
+      "py/state": {
+        "__dict__": {
+          "batch_size": 100,
+          "nepochs": 50,
+          "lr": 0.001,
+          "test_size": 0.1,
+          "shuffle_data": true,
+          "class_weights": null,
+          "compute_class_weights": false,
+          "score_average": "weighted",
+          "prerequisites": {},
+          "cui_filter": null,
+          "auto_save_model": true,
+          "last_train_on": null,
+          "metric": {
+            "base": "weighted avg",
+            "score": "f1-score"
+          },
+          "loss_funct": "cross_entropy",
+          "gamma": 2
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    }
+  }
\ No newline at end of file
diff --git a/tests/resources/jsonpickle_rel_cat_config.json b/tests/resources/jsonpickle_rel_cat_config.json
new file mode 100644
index 000000000..411caaa52
--- /dev/null
+++ b/tests/resources/jsonpickle_rel_cat_config.json
@@ -0,0 +1,91 @@
+{
+    "general": {
+      "py/object": "medcat.config_rel_cat.General",
+      "py/state": {
+        "__dict__": {
+          "device": "cpu",
+          "relation_type_filter_pairs": [],
+          "vocab_size": null,
+          "lowercase": true,
+          "cntx_left": 15,
+          "cntx_right": 15,
+          "window_size": 300,
+          "mct_export_max_non_rel_sample_size": 200,
+          "mct_export_create_addl_rels": false,
+          "tokenizer_name": "bert",
+          "model_name": "bert-base-uncased",
+          "log_level": 20,
+          "max_seq_length": 512,
+          "tokenizer_special_tokens": false,
+          "annotation_schema_tag_ids": [],
+          "labels2idx": {},
+          "idx2labels": {},
+          "pin_memory": true,
+          "seed": 13,
+          "task": "train"
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "model": {
+      "py/object": "medcat.config_rel_cat.Model",
+      "py/state": {
+        "__dict__": {
+          "input_size": 300,
+          "hidden_size": 768,
+          "hidden_layers": 3,
+          "model_size": 5120,
+          "dropout": 0.2,
+          "num_directions": 2,
+          "padding_idx": -1,
+          "emb_grad": true,
+          "ignore_cpos": false
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    },
+    "train": {
+      "py/object": "medcat.config_rel_cat.Train",
+      "py/state": {
+        "__dict__": {
+          "nclasses": 2,
+          "batch_size": 25,
+          "nepochs": 1,
+          "lr": 100000,
+          "adam_epsilon": 0.0001,
+          "test_size": 0.2,
+          "gradient_acc_steps": 1,
+          "multistep_milestones": [
+            2,
+            4,
+            6,
+            8,
+            12,
+            15,
+            18,
+            20,
+            22,
+            24,
+            26,
+            30
+          ],
+          "multistep_lr_gamma": 0.8,
+          "max_grad_norm": 1,
+          "shuffle_data": true,
+          "class_weights": null,
+          "score_average": "weighted",
+          "auto_save_model": true
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    }
+  }
\ No newline at end of file
diff --git a/tests/resources/jsonpickle_tner_config.json b/tests/resources/jsonpickle_tner_config.json
new file mode 100644
index 000000000..eb3639453
--- /dev/null
+++ b/tests/resources/jsonpickle_tner_config.json
@@ -0,0 +1,23 @@
+{
+    "general": {
+      "py/object": "medcat.config_transformers_ner.General",
+      "py/state": {
+        "__dict__": {
+          "name": "deid",
+          "model_name": "roberta-base",
+          "seed": 13,
+          "description": "No description",
+          "pipe_batch_size_in_chars": -100,
+          "ner_aggregation_strategy": "simple",
+          "chunking_overlap_window": 5,
+          "test_size": 0.2,
+          "last_train_on": null,
+          "verbose_metrics": false
+        },
+        "__fields_set__": {
+          "py/set": []
+        },
+        "__private_attribute_values__": {}
+      }
+    }
+  }
\ No newline at end of file
diff --git a/tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json b/tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json
new file mode 100644
index 000000000..79f1a0ac4
--- /dev/null
+++ b/tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json
@@ -0,0 +1,84 @@
+{
+    "projects": [
+        {
+            "cuis": "",
+            "tuis": "",
+            "name": "TEST-PROJ",
+            "id": "PROJ_FAKE",
+            "documents": [
+                {
+                    "name": "fake_doc_0",
+                    "id": 100,
+                    "last_modified": "-1",
+                    "text": "This virus is called virus M and was read from the second CSV we could find.",
+                    "annotations": [
+                        {
+                            "cui": "C0000039",
+                            "start": 5,
+                            "end": 10,
+                            "value": "virus"
+                        },
+                        {
+                            "cui": "C0000139",
+                            "start": 21,
+                            "end": 28,
+                            "value": "virus M"
+                        },
+                        {
+                            "cui": "C0000239",
+                            "start": 51,
+                            "end": 62,
+                            "value": "second CSV"
+                        }
+                    ]
+                },
+                {
+                    "name": "fake_doc_1",
+                    "id": 101,
+                    "last_modified": "-1",
+                    "text": "We found a virus. Turned out it was virus M. This was the second CSV we looked at.",
+                    "annotations": [
+                        {
+                            "cui": "C0000039",
+                            "start": 11,
+                            "end": 16,
+                            "value": "virus"
+                        },
+                        {
+                            "cui": "C0000139",
+                            "start": 36,
+                            "end": 43,
+                            "value": "virus M"
+                        },
+                        {
+                            "cui": "C0000239",
+                            "start": 58,
+                            "end": 69,
+                            "value": "second CSV"
+                        }
+                    ]
+                },
+                {
+                    "name": "fake_doc_2",
+                    "id": 102,
+                    "last_modified": "-1",
+                    "text": "We opened second CSV and found virus M to be the culprit.",
+                    "annotations": [
+                        {
+                            "cui": "C0000239",
+                            "start": 10,
+                            "end": 21,
+                            "value": "second CSV"
+                        },
+                        {
+                            "cui": "C0000139",
+                            "start": 31,
+                            "end": 38,
+                            "value": "virus M"
+                        }
+                    ]
+                }
+            ]
+        }
+    ]
+}
\ No newline at end of file
diff --git a/webapp/webapp/demo/__init__.py b/tests/stats/__init__.py
similarity index 100%
rename from webapp/webapp/demo/__init__.py
rename to tests/stats/__init__.py
diff --git a/tests/stats/helpers.py b/tests/stats/helpers.py
new file mode 100644
index 000000000..80771b11c
--- /dev/null
+++ b/tests/stats/helpers.py
@@ -0,0 +1,17 @@
+from pydantic import create_model_from_typeddict
+
+from medcat.stats.mctexport import MedCATTrainerExport
+
+
+MCTExportPydanticModel = create_model_from_typeddict(MedCATTrainerExport)
+
+
+def nullify_doc_names_proj_ids(export: MedCATTrainerExport) -> MedCATTrainerExport:
+    return {'projects': [
+        {
+            'name': project['name'], 
+            'documents': sorted([
+                {k: v if k != 'name' else '' for k, v in doc.items()} for doc in project['documents']
+            ], key=lambda doc: doc['id'])
+        } for project in export['projects']
+    ]}
diff --git a/tests/stats/test_kfold.py b/tests/stats/test_kfold.py
new file mode 100644
index 000000000..87dcdd454
--- /dev/null
+++ b/tests/stats/test_kfold.py
@@ -0,0 +1,298 @@
+import os
+import json
+from typing import Dict, Union, Optional
+from copy import deepcopy
+
+from medcat.stats import kfold
+from medcat.cat import CAT
+from pydantic.error_wrappers import ValidationError as PydanticValidationError
+
+import unittest
+
+from .helpers import MCTExportPydanticModel, nullify_doc_names_proj_ids
+
+
+class MCTExportTests(unittest.TestCase):
+    EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..",
+                               "resources", "medcat_trainer_export.json")
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        with open(cls.EXPORT_PATH) as f:
+            cls.mct_export = json.load(f)
+
+    def assertIsMCTExport(self, obj):
+        try:
+            model = MCTExportPydanticModel(**obj)
+        except PydanticValidationError as e:
+            raise AssertionError("Not an MCT export") from e
+        self.assertIsInstance(model, MCTExportPydanticModel)
+
+
+class KFoldCreatorTests(MCTExportTests):
+    K = 3
+    SPLIT_TYPE = kfold.SplitType.DOCUMENTS
+
+
+    def setUp(self) -> None:
+        self.creator = kfold.get_fold_creator(self.mct_export, self.K, split_type=self.SPLIT_TYPE)
+        self.folds = self.creator.create_folds()
+
+    def test_folding_does_not_modify_initial_export(self):
+        with open(self.EXPORT_PATH) as f:
+            export_copy = json.load(f)
+        self.assertEqual(export_copy, self.mct_export)
+
+    def test_mct_export_has_correct_format(self):
+        self.assertIsMCTExport(self.mct_export)
+
+    def test_folds_have_docs(self):
+        for nr, fold in enumerate(self.folds):
+            with self.subTest(f"Fold-{nr}"):
+                self.assertGreater(kfold.count_all_docs(fold), 0)
+
+    def test_folds_have_anns(self):
+        for nr, fold in enumerate(self.folds):
+            with self.subTest(f"Fold-{nr}"):
+                self.assertGreater(kfold.count_all_annotations(fold), 0)
+
+    def test_folds_are_mct_exports(self):
+        for nr, fold in enumerate(self.folds):
+            with self.subTest(f"Fold-{nr}"):
+                self.assertIsMCTExport(fold)
+
+    def test_gets_correct_number_of_folds(self):
+        self.assertEqual(len(self.folds), self.K)
+
+    def test_folds_keep_all_docs(self):
+        total_docs = 0
+        for fold in self.folds:
+            docs = kfold.count_all_docs(fold)
+            total_docs += docs
+        count_all_once = kfold.count_all_docs(self.mct_export)
+        if self.SPLIT_TYPE is kfold.SplitType.ANNOTATIONS:
+            # NOTE: This may be greater if split in the middle of a document
+            #       because that document may then exist in both folds
+            self.assertGreaterEqual(total_docs, count_all_once)
+        else:
+            self.assertEqual(total_docs, count_all_once)
+
+    def test_folds_keep_all_anns(self):
+        total_anns = 0
+        for fold in self.folds:
+            anns = kfold.count_all_annotations(fold)
+            total_anns += anns
+        count_all_once = kfold.count_all_annotations(self.mct_export)
+        self.assertEqual(total_anns, count_all_once)
+
+    def test_1fold_same_as_orig(self):
+        folds = kfold.get_fold_creator(self.mct_export, 1, split_type=self.SPLIT_TYPE).create_folds()
+        self.assertEqual(len(folds), 1)
+        fold, = folds
+        self.assertIsInstance(fold, dict)
+        self.assertIsMCTExport(fold)
+        self.assertEqual(
+            nullify_doc_names_proj_ids(self.mct_export),
+            nullify_doc_names_proj_ids(fold),
+        )
+
+    def test_has_reasonable_annotations_per_folds(self):
+        anns_per_folds = [kfold.count_all_annotations(fold) for fold in self.folds]
+        print(f"ANNS per folds:\n{anns_per_folds}")
+        docs_per_folds = [kfold.count_all_docs(fold) for fold in self.folds]
+        print(f"DOCS per folds:\n{docs_per_folds}")
+
+
+# this is a taylor-made export that
+# just contains a few "documents"
+# with the fake CUIs "annotated"
+NEW_EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..",
+                               "resources", "medcat_trainer_export_FAKE_CONCEPTS.json")
+
+
+class KFoldCreatorPerAnnsTests(KFoldCreatorTests):
+    SPLIT_TYPE = kfold.SplitType.ANNOTATIONS
+
+
+class KFoldCreatorPerWeightedDocsTests(KFoldCreatorTests):
+    SPLIT_TYPE = kfold.SplitType.DOCUMENTS_WEIGHTED
+    # should have a total of 435, so 145 per fold in an ideal world
+    # but we'll allow the following deviation
+    PERMITTED_MAX_DEVIATION_IN_ANNS = 5
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        cls.total_anns = kfold.count_all_annotations(cls.mct_export)
+        cls.expected_anns_per_fold = cls.total_anns // cls.K
+        cls.expected_lower_bound = cls.expected_anns_per_fold - cls.PERMITTED_MAX_DEVIATION_IN_ANNS
+        cls.expected_upper_bound = cls.expected_anns_per_fold + cls.PERMITTED_MAX_DEVIATION_IN_ANNS
+
+    def test_has_reasonable_annotations_per_folds(self):
+        anns_per_folds = [kfold.count_all_annotations(fold) for fold in self.folds]
+        for nr, anns in enumerate(anns_per_folds):
+            with self.subTest(f"Fold-{nr}"):
+                self.assertGreater(anns, self.expected_lower_bound)
+                self.assertLess(anns, self.expected_upper_bound)
+        # NOTE: as of testing, this will split [146, 145, 144]
+        #       whereas regular per-docs split will have [140, 163, 132]
+
+
+class KFoldCreatorNewExportTests(KFoldCreatorTests):
+    EXPORT_PATH = NEW_EXPORT_PATH
+
+
+class KFoldCreatorNewExportAnnsTests(KFoldCreatorNewExportTests):
+    SPLIT_TYPE = kfold.SplitType.ANNOTATIONS
+
+
+class KFoldCreatorNewExportWeightedDocsTests(KFoldCreatorNewExportTests):
+    SPLIT_TYPE = kfold.SplitType.DOCUMENTS_WEIGHTED
+
+
+class KFoldCATTests(MCTExportTests):
+    _names = ['fps', 'fns', 'tps', 'prec', 'rec', 'f1', 'counts', 'examples']
+    EXPORT_PATH = NEW_EXPORT_PATH
+    CAT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
+    TOLERANCE_PLACES = 10  # tolerance of 10 digits
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        cls.cat = CAT.load_model_pack(cls.CAT_PATH)
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.reg_stats = self.cat._print_stats(self.mct_export, do_print=False)
+        # TODO - remove
+        self.maxDiff = 4000
+
+    # NOTE: Due to floating point errors, sometimes we may get slightly different results
+    def assertDictsAlmostEqual(self, d1: Dict[str, Union[int, float]], d2: Dict[str, Union[int, float]],
+                               tolerance_places: Optional[int] = None) -> None:
+        self.assertEqual(d1.keys(), d2.keys())
+        tol = tolerance_places if tolerance_places is not None else self.TOLERANCE_PLACES
+        for k in d1:
+            v1, v2 = d1[k], d2[k]
+            self.assertAlmostEqual(v1, v2, places=tol)
+
+
+class KFoldStatsConsistencyTests(KFoldCATTests):
+
+    def test_mct_export_valid(self):
+        self.assertIsMCTExport(self.mct_export)
+
+    def test_stats_consistent(self):
+        stats = self.cat._print_stats(self.mct_export, do_print=False)
+        for name, stats1, stats2 in zip(self._names, self.reg_stats, stats):
+            with self.subTest(name):
+                # NOTE: These should be EXACTLY equal since there shouldn't be
+                #       any different additions and the like
+                self.assertEqual(stats1, stats2)
+
+
+class KFoldMetricsTests(KFoldCATTests):
+    SPLIT_TYPE = kfold.SplitType.DOCUMENTS
+
+    def test_metrics_1_fold_same_as_normal(self):
+        stats = kfold.get_k_fold_stats(self.cat, self.mct_export, k=1,
+                                       split_type=self.SPLIT_TYPE)
+        for name, reg, folds1 in zip(self._names, self.reg_stats, stats):
+            with self.subTest(name):
+                if name != 'examples':
+                    # NOTE: These may not be exactly equal due to floating point errors
+                    self.assertDictsAlmostEqual(reg, folds1)
+                else:
+                    self.assertEqual(reg, folds1)
+
+
+class KFoldPerAnnsMetricsTests(KFoldMetricsTests):
+    SPLIT_TYPE = kfold.SplitType.ANNOTATIONS
+
+
+class KFoldWeightedDocsMetricsTests(KFoldMetricsTests):
+    SPLIT_TYPE = kfold.SplitType.DOCUMENTS_WEIGHTED
+
+
+class KFoldDuplicatedTests(KFoldCATTests):
+    COPIES = 3
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        cls.docs_in_orig = kfold.count_all_docs(cls.mct_export)
+        cls.anns_in_orig = kfold.count_all_annotations(cls.mct_export)
+        cls.data_copied: kfold.MedCATTrainerExport = deepcopy(cls.mct_export)
+        for project in cls.data_copied['projects']:
+            documents_list = project['documents']
+            copies = documents_list + [
+                {k: v if k != 'name' else f"{v}_cp_{nr}" for k, v in doc.items()} for nr in range(cls.COPIES - 1)
+                for doc in documents_list
+            ]
+            project['documents'] = copies
+        cls.docs_in_copy = kfold.count_all_docs(cls.data_copied)
+        cls.anns_in_copy = kfold.count_all_annotations(cls.data_copied)
+        cls.stats_copied = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES)
+        cls.stats_copied_2 = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES)
+
+    # some stats with real model/data will be e.g 0.99 vs 0.9747
+    # so in that case, lower it to 1 or so
+    _stats_consistency_tolerance = 8
+
+    def test_stats_consistent(self):
+        for name, one, two in zip(self._names, self.stats_copied, self.stats_copied_2):
+            with self.subTest(name):
+                if name == 'examples':
+                    # examples are hard
+                    # sometimes they differ by quite a lot
+                    for etype in one:
+                        ev1, ev2 = one[etype], two[etype]
+                        with self.subTest(f"{name}-{etype}"):
+                            self.assertEqual(ev1.keys(), ev2.keys())
+                            for cui in ev1:
+                                per_cui_examples1 = ev1[cui]
+                                per_cui_examples2 = ev2[cui]
+                                with self.subTest(f"{name}-{etype}-{cui}-[{self.cat.cdb.cui2preferred_name.get(cui, cui)}]"):
+                                    self.assertEqual(len(per_cui_examples1), len(per_cui_examples2), "INCORRECT NUMBER OF ITEMS")
+                                    for ex1, ex2 in zip(per_cui_examples1, per_cui_examples2):
+                                        self.assertDictsAlmostEqual(ex1, ex2, tolerance_places=self._stats_consistency_tolerance)
+                    continue
+                self.assertEqual(one, two)
+
+    def test_copy_has_correct_number_documents(self):
+        self.assertEqual(self.COPIES * self.docs_in_orig, self.docs_in_copy)
+
+    def test_copy_has_correct_number_annotations(self):
+        self.assertEqual(self.COPIES * self.anns_in_orig, self.anns_in_copy)
+
+    def test_3_fold_identical_folds(self):
+        folds = kfold.get_fold_creator(self.data_copied, nr_of_folds=self.COPIES,
+                                  split_type=kfold.SplitType.DOCUMENTS).create_folds()
+        self.assertEqual(len(folds), self.COPIES)
+        for nr, fold in enumerate(folds):
+            with self.subTest(f"Fold-{nr}"):
+                # if they're all equal to original, they're equal to each other
+                self.assertEqual(
+                    nullify_doc_names_proj_ids(fold),
+                    nullify_doc_names_proj_ids(self.mct_export)
+                )
+
+    def test_metrics_3_fold(self):
+        stats_simple = self.reg_stats
+        for name, old, new in zip(self._names, stats_simple, self.stats_copied):
+            if name == 'examples':
+                continue
+            # with self.subTest(name):
+            if name in ("fps", "fns", "tps", "counts"):
+                # count should be triples
+                pass
+            if name in ("prec", "rec", "f1"):
+                # these should average to the same ??
+                all_keys = old.keys() | new.keys()
+                for cui in all_keys:
+                    cuiname = self.cat.cdb.cui2preferred_name.get(cui, cui)
+                    with self.subTest(f"{name}-{cui} [{cuiname}]"):
+                        self.assertIn(cui, old.keys(), f"CUI '{cui}' ({cuiname}) not in old")
+                        self.assertIn(cui, new.keys(), f"CUI '{cui}' ({cuiname}) not in new")
+                        v1, v2 = old[cui], new[cui]
+                        self.assertEqual(v1, v2, f"Values not equal for {cui} ({self.cat.cdb.cui2preferred_name.get(cui, cui)})")
diff --git a/tests/stats/test_mctexport.py b/tests/stats/test_mctexport.py
new file mode 100644
index 000000000..8ef11f556
--- /dev/null
+++ b/tests/stats/test_mctexport.py
@@ -0,0 +1,38 @@
+import os
+import json
+
+from medcat.stats import mctexport
+
+import unittest
+
+from .helpers import MCTExportPydanticModel
+
+
+class MCTExportIterationTests(unittest.TestCase):
+    EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..",
+                               "resources", "medcat_trainer_export.json")
+    EXPECTED_DOCS = 27
+    EXPECTED_ANNS = 435
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        with open(cls.EXPORT_PATH) as f:
+            cls.mct_export: mctexport.MedCATTrainerExport = json.load(f)
+
+    def test_conforms_to_template(self):
+        # NOTE: This uses pydantic to make sure that the MedCATTrainerExport
+        #       type matches the actual export format
+        model_instance = MCTExportPydanticModel(**self.mct_export)
+        self.assertIsInstance(model_instance, MCTExportPydanticModel)
+
+    def test_iterates_over_all_docs(self):
+        self.assertEqual(mctexport.count_all_docs(self.mct_export), self.EXPECTED_DOCS)
+
+    def test_iterates_over_all_anns(self):
+        self.assertEqual(mctexport.count_all_annotations(self.mct_export), self.EXPECTED_ANNS)
+
+    def test_gets_correct_nr_of_annotations_per_doc(self):
+        for project in self.mct_export['projects']:
+            for doc in project["documents"]:
+                with self.subTest(f"Proj-{project['name']} ({project['id']})-{doc['name']} ({doc['id']})"):
+                    self.assertEqual(mctexport.get_nr_of_annotations(doc), len(doc["annotations"]))
diff --git a/tests/test_cat.py b/tests/test_cat.py
index ce1b62d98..780039473 100644
--- a/tests/test_cat.py
+++ b/tests/test_cat.py
@@ -301,9 +301,9 @@ def _test_train_superivsed(self, temp_file: str):
         data_path = self.SUPERVISED_TRAINING_JSON
         ckpt_dir_path = temp_file
         checkpoint = Checkpoint(dir_path=ckpt_dir_path, steps=1, max_to_keep=sys.maxsize)
-        fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised(data_path,
-                                                                                     checkpoint=checkpoint,
-                                                                                     nepochs=nepochs)
+        fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised_from_json(data_path,
+                                                                                               checkpoint=checkpoint,
+                                                                                               nepochs=nepochs)
         checkpoints = [f for f in os.listdir(ckpt_dir_path) if "checkpoint-" in f]
         self.assertEqual({}, fp)
         self.assertEqual({}, fn)
@@ -328,13 +328,11 @@ def _test_resume_supervised_training(self, temp_file: str):
         data_path = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export.json")
         ckpt_dir_path = temp_file
         checkpoint = Checkpoint(dir_path=ckpt_dir_path, steps=1, max_to_keep=sys.maxsize)
-        self.undertest.train_supervised(data_path,
-                                        checkpoint=checkpoint,
-                                        nepochs=nepochs_train)
-        fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised(data_path,
-                                                                                     checkpoint=checkpoint,
-                                                                                     nepochs=nepochs_train+nepochs_retrain,
-                                                                                     is_resumed=True)
+        self.undertest.train_supervised_from_json(data_path,
+                                                  checkpoint=checkpoint,
+                                                  nepochs=nepochs_train)
+        fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised_from_json(
+            data_path, checkpoint=checkpoint, nepochs=nepochs_train+nepochs_retrain, is_resumed=True)
         checkpoints = [f for f in os.listdir(ckpt_dir_path) if "checkpoint-" in f]
         self.assertEqual({}, fp)
         self.assertEqual({}, fn)
@@ -351,15 +349,15 @@ def _test_resume_supervised_training(self, temp_file: str):
     def test_train_supervised_does_not_retain_MCT_filters_default(self, extra_cui_filter=None):
         data_path = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export_filtered.json")
         before = str(self.undertest.config.linking.filters)
-        self.undertest.train_supervised(data_path, nepochs=1, use_filters=True, extra_cui_filter=extra_cui_filter)
+        self.undertest.train_supervised_from_json(data_path, nepochs=1, use_filters=True, extra_cui_filter=extra_cui_filter)
         after = str(self.undertest.config.linking.filters)
         self.assertEqual(before, after)
 
     def test_train_supervised_can_retain_MCT_filters(self, extra_cui_filter=None, retain_extra_cui_filter=False):
         data_path = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export_filtered.json")
         before = str(self.undertest.config.linking.filters)
-        self.undertest.train_supervised(data_path, nepochs=1, use_filters=True, retain_filters=True,
-                                        extra_cui_filter=extra_cui_filter, retain_extra_cui_filter=retain_extra_cui_filter)
+        self.undertest.train_supervised_from_json(data_path, nepochs=1, use_filters=True, retain_filters=True,
+                                                  extra_cui_filter=extra_cui_filter, retain_extra_cui_filter=retain_extra_cui_filter)
         after = str(self.undertest.config.linking.filters)
         self.assertNotEqual(before, after)
         with open(data_path, 'r') as f:
@@ -701,7 +699,7 @@ def _get_meta_cat(meta_cat_dir):
                        config=config)
     os.makedirs(meta_cat_dir, exist_ok=True)
     json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json")
-    meta_cat.train(json_path, save_dir_path=meta_cat_dir)
+    meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir)
     return meta_cat
 
 
@@ -712,7 +710,7 @@ class TestLoadingOldWeights(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         cls.cdb = CDB.load(cls.cdb_path)
-        cls.wf = cls.cdb.config.linking.weighted_average_function
+        cls.wf = cls.cdb.weighted_average_function
 
     def test_can_call_weights(self):
         res = self.wf(step=1)
diff --git a/tests/test_cdb_maker.py b/tests/test_cdb_maker.py
index f84e47b15..f454ebe7d 100644
--- a/tests/test_cdb_maker.py
+++ b/tests/test_cdb_maker.py
@@ -132,6 +132,24 @@ def setUpClass(cls):
     def tearDownClass(cls) -> None:
         cls.maker.destroy_pipe()
 
+    # NOTE: The following tests are state-dependent. That is to say,
+    #       if the order in which they're executed changes, they may fail.
+    #       They currently rely on the fact that test methods are executed
+    #       in alphabetic order. But this is overall not good test design
+    #       since failure of one unit could lead to the failure of another
+    #       in unexpected ways (since there's an expectation on the state).
+    #
+    #       e.g, if I run:
+    #           python -m unittest\
+    #               tests.test_cdb_maker.B_CDBMakerEditTests.test_bd_addition_of_context_vector_positive\
+    #               tests.test_cdb_maker.B_CDBMakerEditTests.test_bc_filter_by_cui\
+    #               tests.test_cdb_maker.B_CDBMakerEditTests.test_bb_removal_of_name\
+    #               tests.test_cdb_maker.B_CDBMakerEditTests.test_ba_addition_of_new_name
+    #       Then there will be failures in `test_ba_addition_of_new_name` and `test_bb_removal_of_name`
+    #       due to the changes in state.
+    #
+    #       Though to make it clear, in the standard configuration the tests run in the
+    #       "correct" order and are successful.
     def test_ba_addition_of_new_name(self):
         self.cdb.add_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.spacy_nlp, {}, self.config), name_status='P', full_build=True)
         self.assertEqual(len(self.cdb.name2cuis), 6, "Should equal 6")
@@ -142,7 +160,7 @@ def test_ba_addition_of_new_name(self):
         self.assertIn('my~:~new~name~.', self.cdb.name2cuis2status)
 
     def test_bb_removal_of_name(self):
-        self.cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.spacy_nlp, {}, self.config))
+        self.cdb._remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.spacy_nlp, {}, self.config))
         self.assertEqual(len(self.cdb.name2cuis), 5, "Should equal 5")
         self.assertNotIn('my:newname.', self.cdb.name2cuis2status)
 
diff --git a/tests/test_config.py b/tests/test_config.py
index ce6ed76eb..bfd440a78 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -2,6 +2,7 @@
 import pickle
 import tempfile
 from medcat.config import Config, MixingConfig, VersionInfo, General, LinkingFilters
+from medcat.config import UseOfOldConfigOptionException, Linking
 from pydantic import ValidationError
 import os
 
@@ -208,6 +209,13 @@ def test_config_hash_recalc_same_changed(self):
         h2 = config.get_hash()
         self.assertEqual(h1, h2)
 
+    def test_can_save_load(self):
+        config = Config()
+        with tempfile.NamedTemporaryFile() as file:
+            config.save(file.name)
+            config2 = Config.load(file.name)
+        self.assertEqual(config, config2)
+
 
 class ConfigLinkingFiltersTests(unittest.TestCase):
 
@@ -228,5 +236,36 @@ def test_not_allow_empty_dict_for_cuis_exclude(self):
             LinkingFilters(cuis_exclude={})
 
 
+class BackwardsCompatibilityTests(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.config = Config()
+
+    def test_use_weighted_average_function_identifier_nice_error(self):
+        with self.assertRaises(UseOfOldConfigOptionException):
+            self.config.linking.weighted_average_function(0)
+
+    def test_use_weighted_average_function_dict_nice_error(self):
+        with self.assertRaises(UseOfOldConfigOptionException):
+            self.config.linking['weighted_average_function'](0)
+
+
+class BackwardsCompatibilityWafPayloadTests(unittest.TestCase):
+    arg = 'weighted_average_function'
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.config = Config()
+        with cls.assertRaises(cls, UseOfOldConfigOptionException) as cls.context:
+            cls.config.linking.weighted_average_function(0)
+        cls.raised = cls.context.exception
+
+    def test_exception_has_correct_conf_type(self):
+        self.assertIs(self.raised.conf_type, Linking)
+
+    def test_exception_has_correct_arg(self):
+        self.assertEqual(self.raised.arg_name, self.arg)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/test_meta_cat.py b/tests/test_meta_cat.py
index 8cd444668..ead082c0b 100644
--- a/tests/test_meta_cat.py
+++ b/tests/test_meta_cat.py
@@ -10,6 +10,7 @@
 import spacy
 from spacy.tokens import Span
 
+
 class MetaCATTests(unittest.TestCase):
 
     @classmethod
@@ -17,7 +18,7 @@ def setUpClass(cls) -> None:
         tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained('prajjwal1/bert-tiny'))
         config = ConfigMetaCAT()
         config.general['category_name'] = 'Status'
-        config.train['nepochs'] = 1
+        config.train['nepochs'] = 2
         config.model['input_size'] = 100
 
         cls.meta_cat: MetaCAT = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)
@@ -29,14 +30,16 @@ def tearDown(self) -> None:
         shutil.rmtree(self.tmp_dir)
 
     def test_train(self):
-        json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'mct_export_for_meta_cat_test.json')
-        results = self.meta_cat.train(json_path, save_dir_path=self.tmp_dir)
-
-        self.assertEqual(results['report']['weighted avg']['f1-score'], 1.0)
+        json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources',
+                                 'mct_export_for_meta_cat_test.json')
+        results = self.meta_cat.train_from_json(json_path, save_dir_path=self.tmp_dir)
+        if self.meta_cat.config.model.phase_number != 1:
+            self.assertEqual(results['report']['weighted avg']['f1-score'], 1.0)
 
     def test_save_load(self):
-        json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'mct_export_for_meta_cat_test.json')
-        self.meta_cat.train(json_path, save_dir_path=self.tmp_dir)
+        json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources',
+                                 'mct_export_for_meta_cat_test.json')
+        self.meta_cat.train_from_json(json_path, save_dir_path=self.tmp_dir)
         self.meta_cat.save(self.tmp_dir)
         n_meta_cat = MetaCAT.load(self.tmp_dir)
 
@@ -53,17 +56,18 @@ def _prepare_doc_w_spangroup(self, spangroup_name: str):
         Span.set_extension('meta_anns', default=None, force=True)
         nlp = spacy.blank("en")
         doc = nlp("Pt has diabetes and copd.")
-        span_0 = doc.char_span(7,15, label="diabetes")
+        span_0 = doc.char_span(7, 15, label="diabetes")
         assert span_0.text == 'diabetes'
 
-        span_1 = doc.char_span(20,24, label="copd")
+        span_1 = doc.char_span(20, 24, label="copd")
         assert span_1.text == 'copd'
         doc.spans[spangroup_name] = [span_0, span_1]
         return doc
 
     def test_predict_spangroup(self):
-        json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'mct_export_for_meta_cat_test.json')
-        self.meta_cat.train(json_path, save_dir_path=self.tmp_dir)
+        json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources',
+                                 'mct_export_for_meta_cat_test.json')
+        self.meta_cat.train_from_json(json_path, save_dir_path=self.tmp_dir)
         self.meta_cat.save(self.tmp_dir)
         n_meta_cat = MetaCAT.load(self.tmp_dir)
 
@@ -90,5 +94,29 @@ def test_predict_spangroup(self):
         n_meta_cat.config.general.span_group = None
 
 
+class MetaCATBertTest(MetaCATTests):
+    @classmethod
+    def setUpClass(cls) -> None:
+        tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained('prajjwal1/bert-tiny'))
+        config = ConfigMetaCAT()
+        config.general['category_name'] = 'Status'
+        config.train['nepochs'] = 2
+        config.model['input_size'] = 100
+        config.train['batch_size'] = 64
+        config.model['model_name'] = 'bert'
+
+        cls.meta_cat: MetaCAT = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)
+        cls.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
+        os.makedirs(cls.tmp_dir, exist_ok=True)
+
+    def test_two_phase(self):
+        self.meta_cat.config.model['phase_number'] = 1
+        self.test_train()
+        self.meta_cat.config.model['phase_number'] = 2
+        self.test_train()
+
+        self.meta_cat.config.model['phase_number'] = 0
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/utils/saving/test_envsnapshot.py b/tests/utils/saving/test_envsnapshot.py
new file mode 100644
index 000000000..16bee1ffb
--- /dev/null
+++ b/tests/utils/saving/test_envsnapshot.py
@@ -0,0 +1,105 @@
+from typing import Any
+import platform
+import os
+import tempfile
+import json
+import zipfile
+
+from medcat.cat import CAT
+from medcat.utils.saving import envsnapshot
+
+import unittest
+
+
+def list_zip_contents(zip_file_path):
+    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+        return zip_ref.namelist()
+
+
+class DirectDependenciesTests(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.direct_deps = envsnapshot.get_direct_dependencies()
+
+    def test_nonempty(self):
+        self.assertTrue(self.direct_deps)
+
+    def test_does_not_contain_versions(self, version_starters: str = '<=>~'):
+        for dep in self.direct_deps:
+            for vs in version_starters:
+                with self.subTest(f"DEP '{dep}' check for '{vs}'"):
+                    self.assertNotIn(vs, dep)
+
+    def test_deps_are_installed_packages(self):
+        for dep in self.direct_deps:
+            with self.subTest(f"Has '{dep}'"):
+                envsnapshot.pkg_resources.require(dep)
+
+
+class EnvSnapshotAloneTests(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.env_info = envsnapshot.get_environment_info()
+
+    def test_info_is_dict(self):
+        self.assertIsInstance(self.env_info, dict)
+
+    def test_info_is_not_empty(self):
+        self.assertTrue(self.env_info)
+
+    def assert_has_target(self, target: str, expected: Any):
+        self.assertIn(target, self.env_info)
+        py_ver = self.env_info[target]
+        self.assertEqual(py_ver, expected)
+
+    def test_has_os(self):
+        self.assert_has_target("os", platform.platform())
+
+    def test_has_py_ver(self):
+        self.assert_has_target("python_version", platform.python_version())
+
+    def test_has_cpu_arch(self):
+        self.assert_has_target("cpu_architecture", platform.machine())
+
+    def test_has_dependencies(self, name: str = "dependencies"):
+        # NOTE: just making sure it's a non-empty list
+        self.assertIn(name, self.env_info)
+        deps = self.env_info[name]
+        self.assertTrue(deps)
+
+    def test_all_direct_dependencies_are_installed(self):
+        deps = self.env_info['dependencies']
+        direct_deps = envsnapshot.get_direct_dependencies()
+        self.assertEqual(len(deps), len(direct_deps))
+
+
+CAT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples")
+ENV_SNAPSHOT_FILE_NAME = envsnapshot.ENV_SNAPSHOT_FILE_NAME
+
+
+class EnvSnapshotInCATTests(unittest.TestCase):
+    expected_env = envsnapshot.get_environment_info()
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.cat = CAT.load_model_pack(CAT_PATH)
+        cls._temp_dir = tempfile.TemporaryDirectory()
+        mpn = cls.cat.create_model_pack(cls._temp_dir.name)
+        cls.cat_folder = os.path.join(cls._temp_dir.name, mpn)
+        cls.envrion_file_path = os.path.join(cls.cat_folder, ENV_SNAPSHOT_FILE_NAME)
+
+    def test_has_environment(self):
+        self.assertTrue(os.path.exists(self.envrion_file_path))
+
+    def test_eviron_saved(self):
+        with open(self.envrion_file_path) as f:
+            saved_info: dict = json.load(f)
+        self.assertEqual(saved_info.keys(), self.expected_env.keys())
+        for k in saved_info:
+            with self.subTest(k):
+                v1, v2 = saved_info[k], self.expected_env[k]
+                self.assertEqual(v1, v2)
+
+    def test_zip_has_env_snapshot(self):
+        filenames = list_zip_contents(self.cat_folder + ".zip")
+        self.assertIn(ENV_SNAPSHOT_FILE_NAME, filenames)
diff --git a/tests/utils/saving/test_serialization.py b/tests/utils/saving/test_serialization.py
index c2c44da16..cb26312f0 100644
--- a/tests/utils/saving/test_serialization.py
+++ b/tests/utils/saving/test_serialization.py
@@ -10,6 +10,7 @@
 from medcat.vocab import Vocab
 
 from medcat.utils.saving.serializer import JsonSetSerializer, CDBSerializer, SPECIALITY_NAMES, ONE2MANY
+from medcat.utils.saving.envsnapshot import ENV_SNAPSHOT_FILE_NAME
 
 import medcat.utils.saving.coding as _
 
@@ -60,6 +61,7 @@ class ModelCreationTests(unittest.TestCase):
     json_model_pack = tempfile.TemporaryDirectory()
     EXAMPLES = os.path.join(os.path.dirname(
         os.path.realpath(__file__)), "..", "..", "..", "examples")
+    EXCEPTIONAL_JSONS = ['model_card.json', ENV_SNAPSHOT_FILE_NAME]
 
     @classmethod
     def setUpClass(cls) -> None:
@@ -95,7 +97,7 @@ def test_dill_to_json(self):
             SPECIALITY_NAMES) - len(ONE2MANY))
         for json in jsons:
             with self.subTest(f'JSON {json}'):
-                if json.endswith('model_card.json'):
+                if any(json.endswith(exception) for exception in self.EXCEPTIONAL_JSONS):
                     continue  # ignore model card here
                 if any(name in json for name in ONE2MANY):
                     # ignore cui2many and name2many
@@ -117,10 +119,6 @@ def test_round_trip(self):
         # The spacy model has full path in the loaded model, thus won't be equal
         cat.config.general.spacy_model = os.path.basename(
             cat.config.general.spacy_model)
-        # There can also be issues with loading the config.linking.weighted_average_function from file
-        # This should be fixed with newer models,
-        # but the example model is older, so has the older functionalitys
-        cat.config.linking.weighted_average_function = self.undertest.config.linking.weighted_average_function
         self.assertEqual(cat.config.asdict(), self.undertest.config.asdict())
         self.assertEqual(cat.cdb.config, self.undertest.cdb.config)
         self.assertEqual(len(cat.vocab.vocab), len(self.undertest.vocab.vocab))
diff --git a/tests/utils/test_cdb_state.py b/tests/utils/test_cdb_state.py
new file mode 100644
index 000000000..068af128b
--- /dev/null
+++ b/tests/utils/test_cdb_state.py
@@ -0,0 +1,113 @@
+import unittest
+import os
+from unittest import mock
+from typing import Callable, Any, Dict
+import tempfile
+
+from medcat.utils.cdb_state import captured_state_cdb, CDBState, copy_cdb_state
+from medcat.cdb import CDB
+from medcat.vocab import Vocab
+from medcat.cat import CAT
+
+
+class StateTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat"))
+        cls.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "vocab.dat"))
+        cls.vocab.make_unigram_table()
+        cls.cdb.config.general.spacy_model = "en_core_web_md"
+        cls.meta_cat_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
+        cls.undertest = CAT(cdb=cls.cdb, config=cls.cdb.config, vocab=cls.vocab, meta_cats=[])
+        cls.initial_state = copy_cdb_state(cls.cdb)
+
+    @classmethod
+    def _set_info(cls, k: str, v: Any, info_dict: Dict):
+        info_dict[k] = (len(v), len(str(v)))
+
+    @classmethod
+    def do_smth_for_each_state_var(cls, cdb: CDB, callback: Callable[[str, Any], None]) -> None:
+        for k in CDBState.__annotations__:
+            v = getattr(cdb, k)
+            callback(k, v)
+
+
+class StateSavedTests(StateTests):
+    on_disk = False
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        # capture state
+        with captured_state_cdb(cls.cdb, save_state_to_disk=cls.on_disk):
+            # clear state
+            cls.do_smth_for_each_state_var(cls.cdb, lambda k, v: v.clear())
+            cls.cleared_state = copy_cdb_state(cls.cdb)
+        # save after state - should be equal to before
+        cls.restored_state = copy_cdb_state(cls.cdb)
+
+    def test_state_saved(self):
+        nr_of_targets = len(CDBState.__annotations__)
+        self.assertGreater(nr_of_targets, 0)
+        self.assertEqual(len(self.initial_state), nr_of_targets)
+        self.assertEqual(len(self.cleared_state), nr_of_targets)
+        self.assertEqual(len(self.restored_state), nr_of_targets)
+
+    def test_clearing_worked(self):
+        self.assertNotEqual(self.initial_state, self.cleared_state)
+        for k, v in self.cleared_state.items():
+            with self.subTest(k):
+                # length is 0
+                self.assertFalse(v)
+
+    def test_state_restored(self):
+        self.assertEqual(self.initial_state, self.restored_state)
+
+
+class StateSavedOnDiskTests(StateSavedTests):
+    on_disk = True
+    _named_tempory_file = tempfile.NamedTemporaryFile
+
+    @classmethod
+    def saved_name_temp_file(cls):
+        tf = cls._named_tempory_file()
+        cls.temp_file_name = tf.name
+        return tf
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        with mock.patch("builtins.open", side_effect=open) as cls.popen:
+            with mock.patch("tempfile.NamedTemporaryFile", side_effect=cls.saved_name_temp_file) as cls.pntf:
+                return super().setUpClass()
+
+    def test_temp_file_called(self):
+        self.pntf.assert_called_once()
+
+    def test_saved_on_disk(self):
+        self.popen.assert_called()
+        self.assertGreaterEqual(self.popen.call_count, 2)
+        self.popen.assert_has_calls([mock.call(self.temp_file_name, 'wb'),
+                                     mock.call(self.temp_file_name, 'rb')])
+
+
+class StateWithTrainingTests(StateTests):
+    SUPERVISED_TRAINING_JSON = os.path.join(os.path.dirname(__file__), "..", "resources", "medcat_trainer_export.json")
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        with captured_state_cdb(cls.cdb):
+            # do training
+            cls.undertest.train_supervised_from_json(cls.SUPERVISED_TRAINING_JSON)
+            cls.after_train_state = copy_cdb_state(cls.cdb)
+        cls.restored_state = copy_cdb_state(cls.cdb)
+
+
+class StateRestoredAfterTrain(StateWithTrainingTests):
+
+    def test_train_state_changed(self):
+        self.assertNotEqual(self.initial_state, self.after_train_state)
+
+    def test_restored_state_same(self):
+        self.assertDictEqual(self.initial_state, self.restored_state)
diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py
new file mode 100644
index 000000000..d1a7262e7
--- /dev/null
+++ b/tests/utils/test_config_utils.py
@@ -0,0 +1,121 @@
+from medcat.config import Config
+from medcat.utils.saving.coding import default_hook, CustomDelegatingEncoder
+from medcat.utils import config_utils
+from medcat import config as main_config
+from medcat import config_meta_cat
+from medcat import config_transformers_ner
+from medcat import config_rel_cat
+import json
+import os
+
+import unittest
+
+OLD_STYLE_DICT = {'py/object': 'medcat.config.VersionInfo',
+                  'py/state': {
+                      '__dict__': {
+                          'history': ['0c0de303b6dc0020',],
+                          'meta_cats': [],
+                          'cdb_info': {
+                              'Number of concepts': 785910,
+                              'Number of names': 2480049,
+                              'Number of concepts that received training': 378746,
+                              'Number of seen training examples in total': 1863973060,
+                              'Average training examples per concept': {
+                                  'py/reduce': [{'py/function': 'numpy.core.multiarray.scalar'},]
+                                  }
+                              },
+                          'performance': {'ner': {}, 'meta': {}},
+                          'description': 'No description',
+                          'id': 'ff4f4e00bc97de58',
+                          'last_modified': '26 April 2024',
+                          'location': None,
+                          'ontology': ['ONTOLOGY1'],
+                          'medcat_version': '1.10.2'
+                          },
+                      '__fields_set__': {
+                          'py/set': ['id', 'ontology', 'description', 'history',
+                                     'location', 'medcat_version', 'last_modified',
+                                     'meta_cats', 'cdb_info', 'performance']
+                                     },
+                      '__private_attribute_values__': {}
+                    }
+                 }
+
+
+NEW_STYLE_DICT = json.loads(json.dumps(Config().asdict(), cls=CustomDelegatingEncoder.def_inst),
+                            object_hook=default_hook)
+
+
+class ConfigUtilsTests(unittest.TestCase):
+
+    def test_identifies_old_style_dict(self):
+        self.assertTrue(config_utils.is_old_type_config_dict(OLD_STYLE_DICT))
+
+    def test_identifies_new_style_dict(self):
+        self.assertFalse(config_utils.is_old_type_config_dict(NEW_STYLE_DICT))
+
+
+class OldFormatJsonTests(unittest.TestCase):
+
+    def assert_knows_old_format(self, file_path: str):
+        with open(file_path) as f:
+            d = json.load(f)
+        self.assertTrue(config_utils.is_old_type_config_dict(d))
+
+
+class OldConfigLoadTests(OldFormatJsonTests):
+    JSON_PICKLE_FILE_PATH = os.path.join(
+        os.path.dirname(__file__), "..", "resources", "jsonpickle_config.json"
+    )
+    EXPECTED_VERSION_HISTORY = ['0c0de303b6dc0020',]
+
+    def test_knows_is_old_format(self):
+        self.assert_knows_old_format(self.JSON_PICKLE_FILE_PATH)
+
+    def test_loads_old_style_correctly(self):
+        cnf: main_config.Config = main_config.Config.load(self.JSON_PICKLE_FILE_PATH)
+        self.assertEqual(cnf.version.history, self.EXPECTED_VERSION_HISTORY)
+
+
+class MetaCATConfigTests(OldFormatJsonTests):
+    META_CAT_OLD_PATH = os.path.join(
+        os.path.dirname(__file__), "..", "resources", "jsonpickle_meta_cat_config.json"
+    )
+    EXPECTED_TARGET = -100
+    TARGET_CLASS = config_meta_cat.ConfigMetaCAT
+
+    @classmethod
+    def get_target(cls, cnf):
+        return cnf.general.seed
+
+    def test_knows_is_old_format(self):
+        self.assert_knows_old_format(self.META_CAT_OLD_PATH)
+
+    def test_can_load_old_format_correctly(self):
+        cnf = self.TARGET_CLASS.load(self.META_CAT_OLD_PATH)
+        self.assertIsInstance(cnf, self.TARGET_CLASS)
+        self.assertEqual(self.get_target(cnf), self.EXPECTED_TARGET)
+
+
+class TNERCATConfigTests(MetaCATConfigTests):
+    META_CAT_OLD_PATH = os.path.join(
+        os.path.dirname(__file__), "..", "resources", "jsonpickle_tner_config.json"
+    )
+    EXPECTED_TARGET = -100
+    TARGET_CLASS = config_transformers_ner.ConfigTransformersNER
+
+    @classmethod
+    def get_target(cls, cnf):
+        return cnf.general.pipe_batch_size_in_chars
+
+
+class RelCATConfigTests(MetaCATConfigTests):
+    META_CAT_OLD_PATH = os.path.join(
+        os.path.dirname(__file__), "..", "resources", "jsonpickle_rel_cat_config.json"
+    )
+    EXPECTED_TARGET = 100_000
+    TARGET_CLASS = config_rel_cat.ConfigRelCAT
+
+    @classmethod
+    def get_target(cls, cnf):
+        return cnf.train.lr
diff --git a/webapp/.gitignore b/webapp/.gitignore
deleted file mode 100644
index fc6ea2e67..000000000
--- a/webapp/.gitignore
+++ /dev/null
@@ -1,6 +0,0 @@
-webapp/data/*
-!webapp/data/.keep
-webapp/db/*
-!webapp/db/.keep
-webapp/models/*
-!webapp/models/.keep
\ No newline at end of file
diff --git a/webapp/README.md b/webapp/README.md
deleted file mode 100644
index 128741d05..000000000
--- a/webapp/README.md
+++ /dev/null
@@ -1 +0,0 @@
-This is a demo application for MedCAT, please note that it was made to be as unreadable as possible - I appologize to anyone that has to use this, it was not done on on purpose. 
diff --git a/webapp/docker-compose.yml b/webapp/docker-compose.yml
deleted file mode 100644
index 60a392990..000000000
--- a/webapp/docker-compose.yml
+++ /dev/null
@@ -1,26 +0,0 @@
-version: '3.4'
-
-services:
-  medcatweb:
-    build:
-        network: host
-        context: ./webapp
-    command: >
-      bash -c "/etc/init.d/cron start &&
-               python /webapp/manage.py runserver 0.0.0.0:8000"
-    volumes:
-      - ./webapp/data:/webapp/data
-      - ./webapp/db:/webapp/db
-      - ./webapp/models:/webapp/models
-      - ./envs/env_db_backup:/etc/environment
-      - medcat_data:/medcat_data
-    ports:
-      - "80:8000"
-    env_file:
-      - ./envs/env_medmen
-      - ./envs/env_db_backup
-    tty: true
-
-volumes:
-  medcat_data:
-    driver: local
diff --git a/webapp/envs/env_db_backup b/webapp/envs/env_db_backup
deleted file mode 100644
index 6071abf41..000000000
--- a/webapp/envs/env_db_backup
+++ /dev/null
@@ -1,8 +0,0 @@
-DB_BACKUP_ON_S3=False
-DB_BACKUP_LOCATION=demo-db-backup/
-DB_BACKUP_EVERY_MINS=720
-DB_BACKUP_RETRY_BACKOFF_MINS=5
-ACCESS_KEY=
-SECRET_KEY=
-BUCKET_NAME=
-DELETE_LOGS_OLDER_THAN=7
\ No newline at end of file
diff --git a/webapp/envs/env_medmen b/webapp/envs/env_medmen
deleted file mode 100644
index 9952ef961..000000000
--- a/webapp/envs/env_medmen
+++ /dev/null
@@ -1 +0,0 @@
-MODEL_PACK_PATH=//
diff --git a/webapp/webapp/.dockerignore b/webapp/webapp/.dockerignore
deleted file mode 100644
index 285877ca8..000000000
--- a/webapp/webapp/.dockerignore
+++ /dev/null
@@ -1,2 +0,0 @@
-data
-models
\ No newline at end of file
diff --git a/webapp/webapp/Dockerfile b/webapp/webapp/Dockerfile
deleted file mode 100644
index 21d19078f..000000000
--- a/webapp/webapp/Dockerfile
+++ /dev/null
@@ -1,37 +0,0 @@
-FROM python:3.7
-
-# Create the required folders
-RUN mkdir -p /webapp/models
-
-# Copy everything
-COPY . /webapp
-
-ENV VOCAB_URL=https://medcat.rosalind.kcl.ac.uk/media/vocab.dat
-ENV CDB_URL=https://medcat.rosalind.kcl.ac.uk/media/cdb-medmen-v1.dat
-
-ENV CDB_PATH=/webapp/models/cdb.dat
-ENV VOCAB_PATH=/webapp/models/vocab.dat
-
-# Create the data directory
-RUN mkdir -p /medcat_data
-
-# Set the pythonpath
-WORKDIR /webapp
-
-RUN pip install -r requirements.txt
-
-# Get the spacy model
-RUN python -m spacy download en_core_web_md
-
-# Build the db
-RUN python manage.py makemigrations && \
-    python manage.py makemigrations demo && \
-    python manage.py migrate && \
-    python manage.py migrate demo && \
-    python manage.py collectstatic --noinput
-
-# Create the db backup cron job
-RUN apt-get update && apt-get install -y --no-install-recommends apt-utils cron sqlite3 libsqlite3-dev
-COPY etc/cron.d/db-backup-cron /etc/cron.d/db-backup-cron
-RUN chmod 0644 /etc/cron.d/db-backup-cron
-RUN crontab /etc/cron.d/db-backup-cron
diff --git a/webapp/webapp/data/.keep b/webapp/webapp/data/.keep
deleted file mode 100644
index e69de29bb..000000000
diff --git a/webapp/webapp/db/.keep b/webapp/webapp/db/.keep
deleted file mode 100644
index e69de29bb..000000000
diff --git a/webapp/webapp/demo/admin.py b/webapp/webapp/demo/admin.py
deleted file mode 100644
index 7cce1297b..000000000
--- a/webapp/webapp/demo/admin.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from django.contrib import admin
-from .models import *
-
-admin.site.register(Downloader)
-admin.site.register(MedcatModel)
-
-def remove_text(modeladmin, request, queryset):
-    UploadedText.objects.all().delete()
-
-class UploadedTextAdmin(admin.ModelAdmin):
-    model = UploadedText
-    actions = [remove_text]
-
-# Register your models here.
-admin.site.register(UploadedText, UploadedTextAdmin)
-
diff --git a/webapp/webapp/demo/apps.py b/webapp/webapp/demo/apps.py
deleted file mode 100644
index 57920c332..000000000
--- a/webapp/webapp/demo/apps.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from django.apps import AppConfig
-
-
-class DemoConfig(AppConfig):
-    name = 'demo'
diff --git a/webapp/webapp/demo/db_backup.py b/webapp/webapp/demo/db_backup.py
deleted file mode 100644
index e637f249a..000000000
--- a/webapp/webapp/demo/db_backup.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import os
-from django.core import management
-from django.conf import settings
-from django_cron import CronJobBase, Schedule
-
-
-class DbBackup(CronJobBase):
-
-    RUN_EVERY_MINS = int(os.environ.get("DB_BACKUP_EVERY_MINS", 60 * 12))
-    RETRY_AFTER_FAILURE_MINS = int(os.environ.get("DB_BACKUP_RETRY_BACKOFF_MINS", 5))
-
-    schedule = Schedule(run_every_mins=RUN_EVERY_MINS, retry_after_failure_mins=RETRY_AFTER_FAILURE_MINS)
-    code = "demo.db_backup.DbBackup"
-
-    def __init__(self):
-        backup_location = settings.DBBACKUP_STORAGE_OPTIONS["location"]
-        os.makedirs(backup_location, exist_ok=True)
-
-    def do(self):
-        management.call_command("dbbackup", "--noinput", "-z")
diff --git a/webapp/webapp/demo/forms.py b/webapp/webapp/demo/forms.py
deleted file mode 100644
index 100efc438..000000000
--- a/webapp/webapp/demo/forms.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from email.policy import default
-from django import forms
-from .models import Downloader
-
-
-class DownloaderForm(forms.ModelForm):
-    consent = forms.BooleanField(required=True, label=(
-        f"I consent to MedCAT collecting and storing my names, email, company"
-        f" or academic institution name, funder and project title, and use"
-        f" case description. I am aware that MedCAT has been funded through"
-        f" academic research grants, and therefore funding bodies require its"
-        f" support team to report wider impact and usage of produced works"
-        f" with the above information."
-    ))
-
-    def __init__(self, models, *args, **kwargs):
-        super().__init__(*args, *kwargs)
-        self.fields["modelpack"] = forms.ChoiceField(label="Select a model for download",
-                                                     choices=[(
-                                                         model.model_name,
-                                                         f"{model.model_display_name}{' (' + model.model_description + ')' if model.model_description else ''}"
-                                                     ) for model in models],
-                                                     widget=forms.RadioSelect())
-
-    class Meta:
-        model = Downloader
-        exclude = ['downloaded_file']
-        fields = [
-            "first_name",
-            "last_name",
-            "email",
-            "affiliation",
-            "funder",
-            "use_case",
-        ]
-        labels = {
-            "first_name": "First Name",
-            "last_name": "Last Name",
-            "email": "Email",
-            "affiliation": "Company or Academic Institution",
-            "funder": "Funder and Project Title (optional)",
-            "use_case": "Please describe your use case",
-        }
-        widgets = {
-            "affiliation": forms.TextInput(attrs={"size": 40}),
-            "funder": forms.TextInput(attrs={"size": 40}),
-            "use_case": forms.Textarea(attrs={"rows": 5, "cols": 40}),
-        }
diff --git a/webapp/webapp/demo/migrations/0001_initial.py b/webapp/webapp/demo/migrations/0001_initial.py
deleted file mode 100644
index 4d63843b9..000000000
--- a/webapp/webapp/demo/migrations/0001_initial.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Generated by Django 2.2.3 on 2019-09-17 11:43
-
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
-    initial = True
-
-    dependencies = [
-    ]
-
-    operations = [
-        migrations.CreateModel(
-            name='UploadedText',
-            fields=[
-                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
-                ('text', models.TextField(blank=True, default='')),
-                ('create_time', models.DateTimeField(auto_now_add=True)),
-            ],
-        ),
-    ]
diff --git a/webapp/webapp/demo/migrations/0002_downloader_medcatmodel.py b/webapp/webapp/demo/migrations/0002_downloader_medcatmodel.py
deleted file mode 100644
index d5ce70fd4..000000000
--- a/webapp/webapp/demo/migrations/0002_downloader_medcatmodel.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Generated by Django 3.2.11 on 2022-04-06 16:42
-
-import django.core.files.storage
-from django.db import migrations, models
-import django.utils.timezone
-
-
-class Migration(migrations.Migration):
-
-    dependencies = [
-        ('demo', '0001_initial'),
-    ]
-
-    operations = [
-        migrations.CreateModel(
-            name='Downloader',
-            fields=[
-                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
-                ('first_name', models.CharField(max_length=20)),
-                ('last_name', models.CharField(max_length=20)),
-                ('email', models.EmailField(max_length=50)),
-                ('affiliation', models.CharField(max_length=100)),
-                ('funder', models.CharField(blank=True, default='', max_length=100)),
-                ('use_case', models.TextField(max_length=200)),
-                ('downloaded_file', models.CharField(default=django.utils.timezone.now, max_length=100)),
-            ],
-        ),
-        migrations.CreateModel(
-            name='MedcatModel',
-            fields=[
-                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
-                ('model_name', models.CharField(max_length=20, unique=True)),
-                ('model_file', models.FileField(storage=django.core.files.storage.FileSystemStorage(location='/medcat_data'), upload_to='')),
-                ('model_display_name', models.CharField(max_length=50)),
-                ('model_description', models.TextField(default=django.utils.timezone.now, max_length=200)),
-            ],
-        ),
-    ]
diff --git a/webapp/webapp/demo/migrations/__init__.py b/webapp/webapp/demo/migrations/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/webapp/webapp/demo/models.py b/webapp/webapp/demo/models.py
deleted file mode 100644
index da2c8aa5d..000000000
--- a/webapp/webapp/demo/models.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from django.db import models
-from django.core.files.storage import FileSystemStorage
-
-
-MODEL_FS = FileSystemStorage(location="/medcat_data")
-
-
-# Create your models here.
-class UploadedText(models.Model):
-    text = models.TextField(default="", blank=True)
-    create_time = models.DateTimeField(auto_now_add=True)
-
-
-class Downloader(models.Model):
-    first_name = models.CharField(max_length=20)
-    last_name = models.CharField(max_length=20)
-    email = models.EmailField(max_length=50)
-    affiliation = models.CharField(max_length=100)
-    funder = models.CharField(max_length=100, blank=True, default="")
-    use_case = models.TextField(max_length=200)
-    downloaded_file = models.CharField(max_length=100)
-
-    def __str__(self):
-        return f'{self.first_name} - {self.last_name}'
-
-
-class MedcatModel(models.Model):
-    model_name = models.CharField(max_length=20, unique=True)
-    model_file = models.FileField(storage=MODEL_FS)
-    model_display_name = models.CharField(max_length=50)
-    model_description = models.TextField(max_length=200)
diff --git a/webapp/webapp/demo/static/css/annotations.css b/webapp/webapp/demo/static/css/annotations.css
deleted file mode 100644
index b0dc44531..000000000
--- a/webapp/webapp/demo/static/css/annotations.css
+++ /dev/null
@@ -1,110 +0,0 @@
-.textbox{
-  margin-left: 5%;
-  margin-right: 5%;
-  margin-top: 5%;
-}
-
-.train-annotations{
-  margin-top: 3%;
-  min-height: 80vh;
-}
-
-.annotations{
-  padding-top: 1%;
-  overflow: scroll;
-  max-height: 82vh;
-  padding-left: 40px;
-  margin-left: -20px;
-  box-shadow: 5px -5px 5px -5px #9476518a;
-}
-
-.green:hover {
-  color: green;
-}
-
-.red:hover {
-  color: red;
-}
-
-
-div.icons{
-  margin-top: -15px;
-  margin-left: 30px;
-  font-size: 30px;
-  position: relative;
-  float: right;
-}
-
-table.info{
-  margin-top: 0%;
-}
-
-.concept-info{
-  box-shadow: -5px -5px 5px -5px #aaaaaa;
-  padding-top: 1%;
-  overflow: scroll;
-  max-height: 82vh;
-  margin-left: 20px;
-}
-
-.row{
-  min-height: 80vh;
-  max-height: 80vh;
-}
-
-mark{
-  cursor: pointer;
-}
-
-i{
-  cursor: pointer;
-}
-
-
-.btns-a{
-  margin-bottom:10px; 
-}
-
-.btns-a-r{
-  margin-left: 10px;
-  margin-bottom:10px; 
-  position: relative;
-  float: right;
-}
-
-td.first{
-  width: 50px;
-}
-
-td.second{
-  width: 100%;
-}
-
-.w100{
-  width: 100%
-}
-
-.flt-right{
-  float: right;
-  margin-right: 5px;
-}
-
-.posfed{
-  -webkit-animation: fadein 0.3s;
-  color: green;
-  float: right;
-  margin-right: 10px;
-  border-left: 4px solid black;
-  padding: 4px;
-  padding-top: 8px;
-  padding-bottom: 8px;
-  position: fixed;
-  top: 3px;
-  right: 3px;
-}
-
-/* Safari, Chrome and Opera > 12.1 */
-@-webkit-keyframes fadein {
-  from { opacity: 0; }
-  to   { opacity: 1; }
-}
diff --git a/webapp/webapp/demo/static/css/base.css b/webapp/webapp/demo/static/css/base.css
deleted file mode 100644
index 0c76c6904..000000000
--- a/webapp/webapp/demo/static/css/base.css
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Modal related styles */
-.modal-mask {
-  position: fixed;
-  z-index: 9998;
-  top: 0;
-  left: 0;
-  width: 100%;
-  height: 100%;
-  background-color: rgba(0, 0, 0, .5);
-  display: table;
-  transition: opacity .3s ease;
-}
-
-.modal-wrapper {
-  display: table-cell;
-  vertical-align: middle;
-}
-
-.modal-container {
-  width: 400px;
-  margin: 0px auto;
-  padding: 20px 30px;
-  background-color: #fff;
-  border-radius: 2px;
-  box-shadow: 0 2px 8px rgba(0, 0, 0, .33);
-  transition: all .3s ease;
-  font-family: Helvetica, Arial, sans-serif;
-}
-
-.modal-header div {
-  width: 100%;
-}
-
-.modal-header h3 {
-  margin-top: 0;
-  color: #42b983;
-  text-align: center;
-}
-
-.modal-header h4 {
-  display: inline-block;
-}
-
-.modal-header .close {
-  float: right;
-}
-
-.modal-body {
-  margin: 20px 0;
-}
-
-.modal-footer {
-  display: flex !important;
-  justify-content: center !important;
-}
-
-.modal-default-button {
-  float: right;
-}
-
-.modal-content {
-  border: 0;
-}
-
-/*
- * The following styles are auto-applied to elements with
- * transition="modal" when their visibility is toggled
- * by Vue.js.
- *
- * You can easily play with the modal transition by editing
- * these styles.
- */
-
-.modal-enter {
-  opacity: 0;
-}
-
-.modal-leave-active {
-  opacity: 0;
-}
-
-.modal-enter .modal-container,
-.modal-leave-active .modal-container {
-  -webkit-transform: scale(1.1);
-  transform: scale(1.1);
-}
diff --git a/webapp/webapp/demo/static/css/home.css b/webapp/webapp/demo/static/css/home.css
deleted file mode 100644
index 12ed6666f..000000000
--- a/webapp/webapp/demo/static/css/home.css
+++ /dev/null
@@ -1,23 +0,0 @@
-p.welcome{
-  margin-top: 30px;
-  margin-left: 40px;
-  font-size: 16px;
-}
-
-.usecase{
-  margin-top: 100px;
-}
-
-.file-upload {
-  display: inline;
-}
-
-.download {
-  display: inline;
-}
-
-html,
-body {
-  height: 100%;
-  margin: 0
-}
diff --git a/webapp/webapp/demo/static/image/favicon.ico b/webapp/webapp/demo/static/image/favicon.ico
deleted file mode 100644
index 5e8c53af57922a5280844f9001c03a1a1ef8a6de..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 4641
zcmeHLd0Z2B79T2#U_t6pC|dMDXC{;P-uK<_<(Fgx
z_<0!`m>B>7VCdt`3Bul&V;64zJna5tXOS!Rrh^1|xdA2h=G|%zEXfbA;*K(z}1
zUSlLx7XZYOF!8Mbz{~{z6M6pSKo-W(mwSgJ0I*=O`qBaN@=dV^4^c=M8pd7Egk%y2
zkT2uG4oZm}lLi2ml8HS^U=+kDC1NSURJvGdB$(KNm6R%M}QITNMk`x&W3k#5_<`Bsa
zB;rgkSSgwXrsm8F<3kfs$>S7aO-y`<2#a9}EJYCvN1o9fmRl}&0y@P_B`DXJlSpPn
zBB)0^7W!B!XG9_)iE@}21S7II1q6Gx_K?r8koSdjmV=yi9bcDuMxRxzCN!ykdM+?b^m;#wX
z0-|h@1QfzVxm3s^P7<{wHJ!r5f)$FSpqL|r;?$kx!(n3#2fm2Ubf%M>Nlp|RK>#^H
z1V<`^N8s^kAc5k{Ct>47q0yWKEaI$c-ZAAZLQt6^QLA!T{WQ}Y498BTjx)p}%~Z-k
z1p=#k&&6`Q`6kAAVoKE*O%F2>nE*`y6|k!i3yAHiUO%{i@$SO8yW=P%vNO(x3qm5P
zy5($H#F>DT0ZiB!GXO<{QX$M|6UTw$gxMhp(Xf<_*<}z&6avLDgiK+Q=u9$c)3^$a
z!0dRD6hT2L1dr86^r!1H8krV+b{in1nF9ojhxy}fjnSrqSVZ-<6pQ{}Hg#R6XJ@h%
zweqeWD|OlEbdV0xAc(-=gA{@z9pn)hj@Zpn;7o&QWExcfV$0!wS+>z)P94JW!G3pW
zS)&v0^KL`_xyasuPiAjSwbm{gJsI$25SF-!ZvdtfpH88ZogsoVoysSW)lVReMryY%oo|)Uv?F5@+{FspD}pSC!GVB#^J#n^Sl31Y=E@)Q`5YzU%_++kf6k
zZ9bdYn0NWUfdAaf&DQ2AB$Y^p+cp)y9}&@HwXuG4`}2a7%Fcq$-d_|tz2eS-p@PmK
z*6n^)=jMMMEwrG^ujvI{J?t$0Dc~~NeMk7Aeb|6BK7RP~q0ZZt>`#Y#0zj~>cVuMb
zOxJQ^)YX?qj;s!)EpyrKSg5#g;X=x>M?aQ}n<$~m-sTJ3s;VlJJ9qE88`FZHKYq*)
z4hSHfPkJ@j`J(^Pah{=Kacn%h8@J~{bMv6DuP-~Fft=bVfBjqc%GIm4(M{ivva_?>
zzKTj6;@>jcNxqu$_Vr4eD~6|SzOShv
zEiiFNO0COtDt>TL^1bP*-Mg<}cTG)BPEPdjsB$%Z%Y{G67#ti7wxb@J!z(|1`iYNO
zkvu8ssx5)=79IMvZmFHvYtO+;jeWed-;&=9Tb~OnYv0X4el0%)M9uSoKJHNcYFZdG
zZ?BE9AD8PSKN(SuUy}mZ>Dgv)%P>`2
zqoN*K7_ayB-4}i#W-BaVv9<;IP5+XG6tM<<|Q{8u@(bzYi7e{>S
zd!EyD!ehU~!R!4mK25k4xtGV(_wZQ48DiqlW#=qDiGS=>^T9_4qJ=`C83{46Wl1wX
zNEz_(1bPDdK3t%C-};PdUJY)~rcIj+2U}52RGCgIc>Xtwi2{=b~9USAypZb
z{UOX{s1EjB{?TCjnZrFNN=pau4Z9ZBBp*m}adCOJ$}iu$J#poMH~^Qt(=je@NdMz{
z*ktcYTie<~I~PCC>_Vjv9_g%qdnl<3%x`@_QqB#0S#a!1^<~x4RqYA>W*{-Suqs
zLZ2Z^{27yY<&h*(%0idS);$gVWr#&`4nFJNI@6|8Etl4q8+~OU#9cdFQ*9pn&F4jEdxaVL+KL6V?sCJrXGdFGn^nxNw%mPfIiBnM-zy@E
zhIYHHAY(iKuxZ+721%>@3r6(5Z8~?d=pnTeGV8Tmv?9wUDAZfmuX^C-$l4NO+O_=D
z4zMAz94;*>DT&Iq%8V}PHou}W{`qZR%_00{{m^u;{-11h!+H#*g%QQY#jpNyIyzKP
zOWaorVLxk~YqRTLB33}+o8&U#?PdC}bT=5D4K$qF5pI8U*Dl>OyWu6Z^TW?VmfOC*
z+obf)EOjWf%C65P-zeNuSse5H#GQU0uZlWjKYe_ZgNgkR-Ri0J4&N8%O2@{_Zz8@H
zTW_Q>V?Jw6_|3D!B5Qxqo|cEL+e{OBn>TEXjFbc{t?O&OawFZyf^sJGTt!59xJZ@Z
z|NG~$Jr%Bxw%9DPXWn70$Z0J!7=AH4Y?goUio3h}smjXZ2!d^aAF{HHJfUyShZjS7
z`sA2za>94&tzb*TN6h-TU(oI~yDfgw*GKETecV}iyTD0UpHEhn&2NsiNPQaU;?x|M
z(A(R4wA@~|R(E;9ExlsPSfgmR$3c+$osGlC$g_5jIe?;0H|=@=&KMW{W1wGf)vJ@-
zp$9E3SB-9ib{%z9Efpm@zwTWF)pwSvh6dVh2)1s$cIyjge?MOy?*x7Bsx22O!T00u
z)dl9Z)!^^zG<}eD59YLq==ji;%W`+hRkzJQiC?d*${dRh`0G~y>qNe)
z(Qg%+71&jNuO!CI^|%=tSZ$GAdbgn=VD}={TKBF5;52DV`N+l_x;75xwT#j)``
=bMn!EYzBEw<5=;1B<(!N5he@{*_a{Obrce!Wy t(q~)5OD - - - CAT Trainer - - - - - - {% load static %} - - - {% block style %} - {% endblock %} - - - - {% block body %} - {% endblock %} - - - - - {% block script %} - {% endblock %} - - diff --git a/webapp/webapp/demo/templates/train_annotations.html b/webapp/webapp/demo/templates/train_annotations.html deleted file mode 100644 index 25677cd21..000000000 --- a/webapp/webapp/demo/templates/train_annotations.html +++ /dev/null @@ -1,147 +0,0 @@ -{% extends 'base.html' %} -{% load static %} - -{% block style %} - - - - -{% endblock %} - -{% block body %} - - {% if not doc_html %} -
-
- {% csrf_token %} -
- - - -
- -
-
-
-
-
Disclaimer
-

This software is intended solely for the testing purposes and non-commercial use. THE SOFTWARE IS PROVIDED "AS IS", - WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. -

contact@cogstack.com for more information.

-

-
-
Sample text
-
-Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).
-CC: Left hand numbness on presentation; then developed lethargy later that day.
-
-HX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. He denied vertigo.
-
-He had been experiencing falling spells without associated LOC up to several times a month for the past year.
-
-MEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.
-
-PMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia.
-            
-
-
-
Please note this is a limited version of MedCAT and it is not trained or validated by clinicans.
-
- - {% else %} -
-
- -
-
- {{ doc_html|safe }} -
-
-
-
[[selected_concept.pretty_name]]
- - - - - - -
- [[name]] - - [[value]] -
-
-
-
Create a new Concept
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
- Name - - -
- CUI - - -
- TUI - - -
- Source Value - - -
- Synonyms - - -
- Context - - -
-
-
-
-
- {% endif %} - - - - - - -{% endblock %} diff --git a/webapp/webapp/demo/templates/umls_user_validation.html b/webapp/webapp/demo/templates/umls_user_validation.html deleted file mode 100644 index 7ceb50f8e..000000000 --- a/webapp/webapp/demo/templates/umls_user_validation.html +++ /dev/null @@ -1,67 +0,0 @@ -{% extends 'base.html' %} -{% load static %} - -{% block style %} - - - - -{% endblock %} - -{% block body %} - -
-
{{ message }}
-
    - {% if downloader_form.non_field_errors %} -
  • {{ downloader_form.non_field_errors }}
  • - {% endif %} - {% for field in downloader_form %} - {% if field.errors %} -
  • - {{ field.label }} -
      - {% for error in field.errors %} -
    • {{ error }}
    • - {% endfor %} -
    -
  • - {% endif %} - {% endfor %} -
-{% if is_valid %} -
For any update on your previous information, contact contact@cogstack.org.
-
- {% csrf_token %} - {% for field in downloader_form.visible_fields %} - {% if field.name != 'consent' and field.name != 'modelpack' %} -
-
- {{ field }} -
- {% endif %} - {% endfor %} -
- {{ downloader_form.modelpack.label }}:
- {% for radio in downloader_form.modelpack %} - {{ radio }}
- {% endfor %} -
- {{ downloader_form.consent }} {{ downloader_form.consent.label }}

- -
-{% endif %} -
- - - - - - - -{% endblock %} diff --git a/webapp/webapp/demo/tests.py b/webapp/webapp/demo/tests.py deleted file mode 100644 index 7ce503c2d..000000000 --- a/webapp/webapp/demo/tests.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.test import TestCase - -# Create your tests here. diff --git a/webapp/webapp/demo/urls.py b/webapp/webapp/demo/urls.py deleted file mode 100644 index 8919757d0..000000000 --- a/webapp/webapp/demo/urls.py +++ /dev/null @@ -1,9 +0,0 @@ -from django.contrib import admin -from django.urls import path -from .views import * - -urlpatterns = [ - path('', show_annotations, name='train_annotations'), - path('auth-callback', validate_umls_user, name='validate-umls-user'), - path('download-model', download_model, name="download-model") -] diff --git a/webapp/webapp/demo/views.py b/webapp/webapp/demo/views.py deleted file mode 100644 index 58d6bf6a7..000000000 --- a/webapp/webapp/demo/views.py +++ /dev/null @@ -1,129 +0,0 @@ -import sys -sys.path.insert(0, '/home/ubuntu/projects/MedCAT/') -import os -import json -from django.shortcuts import render -from django.http import StreamingHttpResponse, HttpResponse -from wsgiref.util import FileWrapper -from medcat.cat import CAT -from medcat.cdb import CDB -from medcat.utils.helpers import doc2html -from medcat.vocab import Vocab -from urllib.request import urlretrieve, urlopen -from urllib.error import HTTPError -#from medcat.meta_cat import MetaCAT -from .models import * -from .forms import DownloaderForm - -AUTH_CALLBACK_SERVICE = 'https://medcat.rosalind.kcl.ac.uk/auth-callback' -VALIDATION_BASE_URL = 'https://uts-ws.nlm.nih.gov/rest/isValidServiceValidate' -VALIDATION_LOGIN_URL = f'https://uts.nlm.nih.gov/uts/login?service={AUTH_CALLBACK_SERVICE}' - -model_pack_path = os.getenv('MODEL_PACK_PATH', 'models/medmen_wstatus_2021_oct.zip') - -try: - cat = CAT.load_model_pack(model_pack_path) -except Exception as e: - print(str(e)) - -def get_html_and_json(text): - doc = cat(text) - - a = 
json.loads(cat.get_json(text)) - for id, ent in a['annotations'].items(): - new_ent = {} - for key in ent.keys(): - if key == 'pretty_name': - new_ent['Pretty Name'] = ent[key] - if key == 'icd10': - icd10 = ent.get('icd10', []) - new_ent['ICD-10 Code'] = icd10[-1] if icd10 else '-' - if key == 'cui': - new_ent['Identifier'] = ent[key] - if key == 'types': - new_ent['Type'] = ", ".join(ent[key]) - if key == 'acc': - new_ent['Confidence Score'] = ent[key] - if key == 'start': - new_ent['Start Index'] = ent[key] - if key == 'end': - new_ent['End Index'] = ent[key] - if key == 'id': - new_ent['id'] = ent[key] - if key == 'meta_anns': - meta_anns = ent.get("meta_anns", {}) - if meta_anns: - for meta_ann in meta_anns.keys(): - new_ent[meta_ann] = meta_anns[meta_ann]['value'] - - a['annotations'][id] = new_ent - - doc_json = json.dumps(a) - uploaded_text = UploadedText() - uploaded_text.text = len(str(text))#str(text) no saving of text anymore - uploaded_text.save() - - return doc2html(doc), doc_json - - -def show_annotations(request): - context = {} - context['doc_json'] = '{"msg": "No documents yet"}' - - if request.POST and 'text' in request.POST: - doc_html, doc_json = get_html_and_json(request.POST['text']) - - context['doc_html'] = doc_html - context['doc_json'] = doc_json - context['text'] = request.POST['text'] - return render(request, 'train_annotations.html', context=context) - - -def validate_umls_user(request): - ticket = request.GET.get('ticket', '') - validate_url = f'{VALIDATION_BASE_URL}?service={AUTH_CALLBACK_SERVICE}&ticket={ticket}' - try: - is_valid = urlopen(validate_url, timeout=10).read().decode('utf-8') - context = { - 'is_valid': is_valid == 'true' - } - if is_valid == 'true': - context['message'] = 'License verified! Please fill in the following form before downloading models.' - context['downloader_form'] = DownloaderForm(MedcatModel.objects.all()) - else: - context['message'] = f'License not found. 
Please request or renew your UMLS Metathesaurus License. If you think you have got the license, try {VALIDATION_LOGIN_URL} again.' - except HTTPError: - context = { - 'is_valid': False, - 'message': 'Something went wrong. Please try again.' - } - finally: - return render(request, 'umls_user_validation.html', context=context) - - -def download_model(request): - if request.method == 'POST': - downloader_form = DownloaderForm(MedcatModel.objects.all(), request.POST) - if downloader_form.is_valid(): - mp_name = downloader_form.cleaned_data['modelpack'] - model = MedcatModel.objects.get(model_name=mp_name) - if model is not None: - mp_path = model.model_file.path - else: - return HttpResponse(f'Error: Unknown model "{downloader_form.modelpack}"') - resp = StreamingHttpResponse(FileWrapper(open(mp_path, 'rb'))) - resp['Content-Type'] = 'application/zip' - resp['Content-Length'] = os.path.getsize(mp_path) - resp['Content-Disposition'] = f'attachment; filename={os.path.basename(mp_path)}' - downloader_form.instance.downloaded_file = os.path.basename(mp_path) - downloader_form.save() - return resp - else: - context = { - 'is_valid': True, - 'downloader_form': downloader_form, - 'message': 'All non-optional fields must be filled out:' - } - return render(request, 'umls_user_validation.html', context=context) - else: - return HttpResponse('Erorr: Unknown HTTP method.') diff --git a/webapp/webapp/etc/cron.d/db-backup-cron b/webapp/webapp/etc/cron.d/db-backup-cron deleted file mode 100644 index 6c8fe22b1..000000000 --- a/webapp/webapp/etc/cron.d/db-backup-cron +++ /dev/null @@ -1 +0,0 @@ -* * * * * /usr/local/bin/python /webapp/manage.py runcrons >/dev/null 2>&1 diff --git a/webapp/webapp/manage.py b/webapp/webapp/manage.py deleted file mode 100755 index cdef512a0..000000000 --- a/webapp/webapp/manage.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env python -"""Django's command-line utility for administrative tasks.""" -import os -import sys - - -def main(): - 
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'webapp.settings') - try: - from django.core.management import execute_from_command_line - except ImportError as exc: - raise ImportError( - "Couldn't import Django. Are you sure it's installed and " - "available on your PYTHONPATH environment variable? Did you " - "forget to activate a virtual environment?" - ) from exc - execute_from_command_line(sys.argv) - - -if __name__ == '__main__': - main() diff --git a/webapp/webapp/models/.keep b/webapp/webapp/models/.keep deleted file mode 100644 index e69de29bb..000000000 diff --git a/webapp/webapp/requirements.txt b/webapp/webapp/requirements.txt deleted file mode 100644 index c525cf0e4..000000000 --- a/webapp/webapp/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -Django==3.2.25 -django-dbbackup==4.0.0b0 -django-storages[boto3]==1.12.3 -django-cron==0.5.1 -medcat==1.2.7 -urllib3==1.26.18 diff --git a/webapp/webapp/webapp/__init__.py b/webapp/webapp/webapp/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/webapp/webapp/webapp/settings.py b/webapp/webapp/webapp/settings.py deleted file mode 100644 index cd68965dc..000000000 --- a/webapp/webapp/webapp/settings.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -Django settings for webapp project. - -Generated by 'django-admin startproject' using Django 2.2.3. - -For more information on this file, see -https://docs.djangoproject.com/en/2.2/topics/settings/ - -For the full list of settings and their values, see -https://docs.djangoproject.com/en/2.2/ref/settings/ -""" - -import os - -# Build paths inside the project like this: os.path.join(BASE_DIR, ...) -BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - - -# Quick-start development settings - unsuitable for production -# See https://docs.djangoproject.com/en/2.2/howto/deployment/checklist/ - -# SECURITY WARNING: keep the secret key used in production secret! 
-SECRET_KEY = '2y3*wm_n52xyis_kaup96+5^*^*$h^!!na-$n%l9ppc0rhfea$' - -# SECURITY WARNING: don't run with debug turned on in production! -DEBUG = False - -ALLOWED_HOSTS = ['*'] - - -# Application definition - -INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'dbbackup', - 'django_cron', - 'demo', -] - -MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', -] - -ROOT_URLCONF = 'webapp.urls' - -TEMPLATES = [ - { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', - ], - }, - }, -] - -WSGI_APPLICATION = 'webapp.wsgi.application' - - -# Database -# https://docs.djangoproject.com/en/2.2/ref/settings/#databases - -DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': os.path.join(BASE_DIR, 'db/db.sqlite3'), - } -} - -DB_BACKUP_ON_S3 = os.environ.get('DB_BACKUP_ON_S3', 'False') -DB_BACKUP_LOCATION = os.environ.get('DB_BACKUP_LOCATION', 'demo-db-backup/') -if DB_BACKUP_ON_S3 == "False": - DBBACKUP_STORAGE = 'django.core.files.storage.FileSystemStorage' - DBBACKUP_STORAGE_OPTIONS = {'location': f'/tmp/{DB_BACKUP_LOCATION}'} -else: - DBBACKUP_STORAGE = 'storages.backends.s3boto3.S3Boto3Storage' - DBBACKUP_STORAGE_OPTIONS = { - 'region_name': 'eu-west-2', - 'access_key': 
os.environ.get('ACCESS_KEY', ''), - 'secret_key': os.environ.get('SECRET_KEY', ''), - 'bucket_name': os.environ.get('BUCKET_NAME', ''), - 'default_acl': 'bucket-owner-full-control', - 'location': DB_BACKUP_LOCATION, - } - -CRON_CLASSES = [ - 'demo.db_backup.DbBackup', -] -DJANGO_CRON_DELETE_LOGS_OLDER_THAN = int(os.environ.get('DELETE_LOGS_OLDER_THAN', '7')) - -# Password validation -# https://docs.djangoproject.com/en/2.2/ref/settings/#auth-password-validators - -AUTH_PASSWORD_VALIDATORS = [ - { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', - }, -] - - -# Internationalization -# https://docs.djangoproject.com/en/2.2/topics/i18n/ - -LANGUAGE_CODE = 'en-us' - -TIME_ZONE = 'UTC' - -USE_I18N = True - -USE_L10N = True - -USE_TZ = True - - -# Static files (CSS, JavaScript, Images) -# https://docs.djangoproject.com/en/2.2/howto/static-files/ - -STATIC_URL = '/static/' -STATIC_ROOT = os.path.join(BASE_DIR, 'demo', 'static') -MEDIA_URL = '/media/' -MEDIA_ROOT = os.path.join(BASE_DIR, 'data') diff --git a/webapp/webapp/webapp/urls.py b/webapp/webapp/webapp/urls.py deleted file mode 100644 index 703574edf..000000000 --- a/webapp/webapp/webapp/urls.py +++ /dev/null @@ -1,26 +0,0 @@ -"""webapp URL Configuration - -The `urlpatterns` list routes URLs to views. For more information please see: - https://docs.djangoproject.com/en/2.2/topics/http/urls/ -Examples: -Function views - 1. Add an import: from my_app import views - 2. Add a URL to urlpatterns: path('', views.home, name='home') -Class-based views - 1. Add an import: from other_app.views import Home - 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') -Including another URLconf - 1. 
Import the include() function: from django.urls import include, path - 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) -""" -from django.contrib import admin -from django.urls import path, include, re_path -from django.conf import settings -from django.views.static import serve - -urlpatterns = [ - path('admin/', admin.site.urls), - path('', include('demo.urls')), - re_path(r'^static/(?P.*)$', serve,{'document_root': settings.STATIC_ROOT}), - re_path(r'^media/(?P.*)$', serve,{'document_root': settings.MEDIA_ROOT}), -] diff --git a/webapp/webapp/webapp/wsgi.py b/webapp/webapp/webapp/wsgi.py deleted file mode 100644 index 420a2338a..000000000 --- a/webapp/webapp/webapp/wsgi.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -WSGI config for webapp project. - -It exposes the WSGI callable as a module-level variable named ``application``. - -For more information on this file, see -https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/ -""" - -import os - -from django.core.wsgi import get_wsgi_application - -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'webapp.settings') - -application = get_wsgi_application() From f6f3654f63eea0e1c6380cdee1d4db8f24eafd53 Mon Sep 17 00:00:00 2001 From: Mart Ratas Date: Wed, 19 Jun 2024 17:36:15 +0100 Subject: [PATCH 7/9] Split tests in production workflow to avoid Ouf Of Memory (137) (#456) --- .github/workflows/production.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml index 3d779371b..44fc53ebf 100644 --- a/.github/workflows/production.yml +++ b/.github/workflows/production.yml @@ -28,7 +28,13 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements-dev.txt - python -m unittest discover + all_files=$(git ls-files | grep '^tests/.*\.py$' | grep -v '/__init__\.py$' | sed 's/\.py$//' | sed 's/\//./g') + num_files=$(echo "$all_files" | wc -l) + midpoint=$((num_files / 2)) + first_half_nl=$(echo "$all_files" 
| head -n $midpoint) + second_half_nl=$(echo "$all_files" | tail -n +$(($midpoint + 1))) + timeout 25m python -m unittest ${first_half_nl[@]} + timeout 25m python -m unittest ${second_half_nl[@]} - name: Install pypa/build run: >- From 5c510b55245bef268e5affee75b2be38d8cbc04f Mon Sep 17 00:00:00 2001 From: Mart Ratas Date: Fri, 25 Apr 2025 13:30:22 +0100 Subject: [PATCH 8/9] v1.16.0 PR (#536) * CU-8693bc9kc: Add python 3.12 support (#511) * CU-8693bc9kc: Add python 3.12 support * CU-8693bc9kc: Amend dependencies so as to be compatible with python 3.12 * Bump default spacy model version (to 3.8) * CU-8693bc9kc: Fix some typing issues due to numpy2 * CU-8693bc9kc: Fix some typing issues due to numpy2 (try 2) * CU-8693bc9kc: Change spacy models to 3.7.2 * CU-8693bc9kc: Pin numpy to v1 * CU-8693bc9kc: Fix numpy requirement comment * CU-8693bc9kc: Fix usage of old/deprecated assert methods in tests * CU-8693bc9kc: Update some requirement comments * CU-8697c86rf: Update docs build requirements (#514) * CU-8697c86rf: Update docs build requirements * CU-8697c86rf: Fix docs build requirements (hopefully) * CU-8697c86rf: Fix docs build requirements (hopefully) x2 * CU-8697x7y9x: Fix issue with transformers 4.47+ affecting DeID (#517) * CU-8697x7y9x: Fix issue with transformers 4.47+ affecting DeID * CU-8697x7y9x: Add type-ignore to module unrelated to current change * Updates for MetaCAT (#515) * Pushing update for MetaCAT - Addressing the multiple zero-division-error warnings per epoch while training - Accommodating the variations in category name and class name across NHS sites * Adding comments * Pushing requested changes * Pushing type fix * Pushing updates to metacat config * Support expansion of transformers ner models to include new concepts (#519) * CU-8697v6qr2 support expansion of transformers ner models to include new concepts * CU-8697v6qr2 add logging suggested by the review * CU-869805t7e alt names fixes (#520) * CU-869805t7e: Move getting of applicable 
category name to the config * CU-869805t7e: Use alternative category names in eval method * CU-869805t7e: Reduce indentation * CU-869805t7e: Reduce indentation (again) * CU-869805t7e: Some comment fixing due to rearrangements before * CU-869805t7e: Fix usage of matched class name when encoding category values * CU-869805t7e: Avoid duplicating exception message * CU-8697qfvzz train metacat on sup train (#516) * CU-8697qfvzz: Add new optional keyword argumnet to allow training MetaCAT models during supervised training * CU-8697qfvzz: Add tests regarding training meta-cats during supervised training * CU-8697qfvzz: Fix small typo in comment * CU-8697qfvzz: Allow using alternative category names if/when training meta cats through CAT.train_supervised * CU-8698ek477: Fix AdamW import from tranformers to torch (#523) * CU-8698ek477: Add TODO to MetaCAT ML utils regarding AdamW import * CU-8698ek477: Fix AdamW import (trf->torch) * CU-8698f8fgc: Fix negative sampling including indices for words without a vector (#524) * CU-8698f8fgc: Add new test to check that the negative sampling indices do not include non-vectored indices * CU-8698f8fgc: Add fix for negative sampling including indices for words without a vector * CU-8698f8fgc: Update tests to make sure index frequencies are respected * CU-8698f8fgc: Add 3.9-friendly counter totalling method * CU-8698gkrqa: Add argument to allow specifying the changes warrenting a model save (#525) * CU-8698hfkch: Add eval method to deid model * CU-8698hfkch: lint checks * CU-8698gqumv: Fix regression test vocab vector sizes (#526) * CU-8698gqumv: Add tests for Vocab upon regression testing * CU-8698gqumv: Fix regression time vocab data * CU-86983ruw9 Fix test train split (#521) * CU-86983ruw9: Fix train-test splitter leaving train set empty for smaller datasets * CU-86983ruw9: Add additional optional arguments to test-train splitting for minimum concept count and maximum test fraction * CU-86983ruw9: Add a few tests for test-train 
splitting * CU-8698hfkch: Add eval method to deid model (#527) * CU-8698hfkch: Add eval method to deid model * CU-8698hfkch: lint checks --------- Co-authored-by: Tom Searle * CU-8698jzjj3: pass in extra param if ignore_extra_labels is set, and test * CU-8698mqu96 Transformers update (4.51.0) fix (#531) * CU-8698mqu96: Update special tokens lengths attribute * CU-8698mqu96: Update MetaCAT usage of BertTokenizer.from_pretrained for type safety * CU-8698mqu96: Ignore typing where mypy is wrong + add note in code * CU-8698mqu96: Ignore typing where mypy may be wrong + add comment * CU-8698mqu96: Fix tokenizer wrapper import for rel cat * CU-8698mqu96: Rename evaluation strategy keyword argument in line with changes * CU-8698mqu96: Type-ignore method where mypy says it does not exist * CU-8698mqu96: Fix TRF-NER output dir typing issue * CU-8698mqu96: Update a doc string for darglint * CU-8698mqu96: Fix typing issue for TrfNER trainer callback * Relation extraction llama (#522) * Added files. * More additions to rel extraction. * Rel base. * Update. * Updates. * Dependency parsing. * Updates. * Added pre-training steps. * Added training & model utils. * Cleanup & fixes. * Update. * Evaluation updates for pretraining. * Removed duplicate relation storage. * Moved RE model file location. * Structure revisions. * Added custom config for RE. * Implemented custom dataset loader for RE. * More changes. * Small fix. * Latest additions to RelCAT (pipe + predictions) * Setup.py fix. * RE utils update. * rel model update. * rel dataset + tokenizer improvements. * RelCAT updates. * RelCAT saving/loading improvements. * RelCAT saving/loading improvements. * RelCAT model fixes. * Attempted gpu learning fix. Dataset label generation fixes. * Minor train dataset gen fix. * Minor train dataset gen fix No.2. * Config updates. * Gpu support fixes. Added label stats. * Evaluation stat fixes. * Cleaned stat output mode during training. * Build fix. 
* removed unused dependencies and fixed code formatting * Mypy compliance. * Fixed linting. * More Gpu mode train fixes. * Fixed model saving/loading issues when using other baes models. * More fixes to stat evaluation. Added proper CAT integration of RelCAT. * Setup.py typo fix. * RelCAT loading fix. * RelCAT Config changes. * Type fix. Minor additions to RelCAT model. * Type fixes. * Type corrections. * RelCAT update. * Type fixes. * Fixed type issue. * RelCATConfig: added seed param. * Adaptations to the new codebase + type fixes.. * Doc/type fixes. * Fixed input size issue for model. * Fixed issue(s) with model size and config. * RelCAT: updated configs to new style. * RelCAT: removed old refs to logging. * Fixed GPU training + added extra stat print for train set. * Type fixes. * Updated dev requirements. * Linting. * Fixed pin_memory issue when training on CPU. * Updated RelCAT dataset get + default config. * Updated RelDS generator + default config * Linting. * Updated RelDatset + config. * Pushing updates to model Made changes to: 1) Extracting given number of context tokens left and right of the entities 2) Extracting hidden state from bert for all the tokens of the entities and performing max pooling on them * Fixing formatting * Update rel_dataset.py * Update rel_dataset.py * Update rel_dataset.py * RelCAT: added test resource files. * RelCAT: Fixed model load/checkpointing. * RelCAT: updated to pipe spacy doc call. * RelCAT: added tests. * Fixed lint/type issues & added rel tag to test DS. * Fixed ann id to token issue. * RelCAT: updated test dataset + tests. * RelCAT: updates to requested changes + dataset improvements. * RelCAT: updated docs/logs according to commends. * RelCAT: type fix. * RelCAT: mct export dataset updates. * RelCAT: test updates + requested changes p2. * RelCAT: log for MCT export train. * Updated docs + split train_test & dataset for benchmarks. * type fixes. * RelCAT: Initial Llama integration. * RelCAT: updates to Llama impl. 
* RelCAT: model typo fix. * RelCAT: label_id /sample no. mixup fix. * Updated cleaned up Relataset, added new ways to create relations via anno types (doc/export only for now). * Added option to predict any text /w annotations via RelCAT. MCT export train fixes. * RelCAT: added sample limiter / class, more logging info. * RelCAT: test/train ds shuffle update. * RelCAT: added option to keep original text when using reldataset class. * Pushing change for stratified batching Implement stratified batching for improved class representation and balanced training * RelCAT: fixed doc processing issue + class weights. * RelCAT: class weights addtions to cfg + param. * RelCAT: added config params for Adam optimizer. * RelCAT updated default config. * RelCAT: config update + optimizer change. * RelCAT: fixed model freeze flags. * RelCAT: model optimizer save/load fix. * RelCAT: added export ent tag check. * Fixed issues when saving/loading model for class weights + inference device cast. * RelCAT: bug fix for ents that are @ EoS. * Rel Dataset updates. * Rel Dataset updates. * Pushing change for ModernBERT * Bumped transformers version. * Updated rel dataset generation from fake Spacy Docs. * ModernBert updates. * Updated RelCAT model-load/save. * Minor relCAT updates, code format. * Type check updates. * Fixed inference issue. * RelCAT: testing updates. * Type fixes. * Type fixes. * Type fixes. * Type fixes IV. * Type fixes python 3.9. * RelCAT: flake8 fixes. * RelCAT: flake8 fixes. * RelCAT: Updates (fixed model loading after save). * Fixed test. * Update RelCAT stuff for improved abstraction * Move separate model implementations to separate packages * Some minor abstraction changes * Remove accidentally copied abstract method decorator * Fix import in test * Fix RelCAT impport in pipe tests * Update base relcat model implementation to include config * Latest RelCAT module updates. * Type fixes + run issues. * Type fixes. * Fixed Llama tokenizer. * Type fixes. 
* Type fixes: Python3.10 adjustements. * Linting. * Fix base flake8 lint issues * Fix doc string in ConfigRelCAT.load * Fix base component init doc string * Fixed BaseComponent.load method doc string * Fix doc strings in rel_cat ml_utils * Fix doc strings in rel_cat models module * Fix rel-cat test time import * Fix type casting * Align pipe tests with rel cat changes * Fix property paths in rel cat tests * Updates. * Fixed tests. * Fixed relCAT config save. * Latest fixes for model saving/loading. * Lint fix. * RelCAT cfg load test fix. * Remove install requirements from gitignore --------- Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com> Co-authored-by: mart-r * CU-8698vewzp: Fix docs requirements (hopefully) (#534) * CU-8698veb6y: Use Ubuntu 24.04 for publishing to test PyPI (#533) --------- Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com> Co-authored-by: Xi Bai <82581439+baixiac@users.noreply.github.com> Co-authored-by: Tom Searle Co-authored-by: tomolopolis Co-authored-by: Vlad Dinu <62345326+vladd-bit@users.noreply.github.com> --- .github/workflows/main.yml | 4 +- .gitignore | 1 + docs/requirements.txt | 94 +-- install_requires.txt | 10 +- medcat/cat.py | 47 +- medcat/cdb.py | 8 +- medcat/config_meta_cat.py | 33 +- medcat/config_rel_cat.py | 98 ++- medcat/meta_cat.py | 36 +- medcat/ner/transformers_ner.py | 100 ++- medcat/rel_cat.py | 448 ++++++----- medcat/tokenizers/meta_cat_tokenizers.py | 9 +- medcat/utils/data_utils.py | 27 +- medcat/utils/meta_cat/data_utils.py | 65 +- medcat/utils/meta_cat/ml_utils.py | 7 +- medcat/utils/ner/deid.py | 4 + medcat/utils/ner/model.py | 32 + .../relation_extraction/base_component.py | 142 ++++ .../relation_extraction/bert/__init__.py | 0 .../utils/relation_extraction/bert/config.py | 30 + .../utils/relation_extraction/bert/model.py | 79 ++ .../relation_extraction/bert/tokenizer.py | 35 + medcat/utils/relation_extraction/config.py | 59 ++ 
.../relation_extraction/llama/__init__.py | 0 .../utils/relation_extraction/llama/config.py | 30 + .../utils/relation_extraction/llama/model.py | 218 +++++ .../relation_extraction/llama/tokenizer.py | 34 + medcat/utils/relation_extraction/ml_utils.py | 305 +++++++ medcat/utils/relation_extraction/models.py | 282 ++++--- .../modernbert/__init__.py | 0 .../relation_extraction/modernbert/config.py | 30 + .../relation_extraction/modernbert/model.py | 77 ++ .../modernbert/tokenizer.py | 34 + .../utils/relation_extraction/rel_dataset.py | 744 +++++++++--------- medcat/utils/relation_extraction/tokenizer.py | 79 +- medcat/utils/relation_extraction/utils.py | 277 ------- medcat/vocab.py | 23 +- requirements-dev.txt | 2 +- requirements.txt | 2 +- setup.py | 1 + tests/ner/test_transformers_ner.py | 17 + .../regression/creation/vocab_data.txt | 4 +- tests/resources/regression/run_regression.sh | 4 + tests/resources/regression/test_vocab.py | 31 + tests/test_cat.py | 46 +- tests/test_pipe.py | 8 +- tests/test_rel_cat.py | 57 +- tests/test_transformers_ner.py | 242 ++++++ tests/test_vocab.py | 48 +- tests/utils/ner/test_deid.py | 7 + tests/utils/test_data_utils.py | 90 +++ tests/utils/test_memory_optimiser.py | 8 +- 52 files changed, 2897 insertions(+), 1171 deletions(-) create mode 100644 medcat/utils/relation_extraction/base_component.py create mode 100644 medcat/utils/relation_extraction/bert/__init__.py create mode 100644 medcat/utils/relation_extraction/bert/config.py create mode 100644 medcat/utils/relation_extraction/bert/model.py create mode 100644 medcat/utils/relation_extraction/bert/tokenizer.py create mode 100644 medcat/utils/relation_extraction/config.py create mode 100644 medcat/utils/relation_extraction/llama/__init__.py create mode 100644 medcat/utils/relation_extraction/llama/config.py create mode 100644 medcat/utils/relation_extraction/llama/model.py create mode 100644 medcat/utils/relation_extraction/llama/tokenizer.py create mode 100644 
medcat/utils/relation_extraction/ml_utils.py create mode 100644 medcat/utils/relation_extraction/modernbert/__init__.py create mode 100644 medcat/utils/relation_extraction/modernbert/config.py create mode 100644 medcat/utils/relation_extraction/modernbert/model.py create mode 100644 medcat/utils/relation_extraction/modernbert/tokenizer.py delete mode 100644 medcat/utils/relation_extraction/utils.py create mode 100644 tests/resources/regression/test_vocab.py create mode 100644 tests/test_transformers_ner.py create mode 100644 tests/utils/test_data_utils.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1b7232bb6..e4dc454b6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ '3.9', '3.10', '3.11' ] + python-version: [ '3.9', '3.10', '3.11', '3.12' ] max-parallel: 4 steps: @@ -76,7 +76,7 @@ jobs: github.ref == 'refs/heads/master' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags') != true - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 timeout-minutes: 45 concurrency: publish-to-test-pypi needs: [build] diff --git a/.gitignore b/.gitignore index dcd2743f0..833b6e49f 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,4 @@ tests/model_creator/output/* docs/auto/ docs/_build +models/ diff --git a/docs/requirements.txt b/docs/requirements.txt index 226900abf..5e0563461 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,23 +2,23 @@ sphinx==6.2.1 sphinx-rtd-theme~=1.0 myst-parser~=0.17 sphinx-autoapi~=3.0.0 -MarkupSafe==2.1.5 -accelerate==0.34.2 +MarkupSafe==3.0.2 +accelerate==1.2.1 aiofiles==24.1.0 -aiohttp==3.10.5 -aiosignal==1.3.1 -asttokens==2.4.1 -async-timeout==4.0.3 -attrs==24.2.0 +aiohttp==3.11.11 +aiosignal==1.3.2 +asttokens==3.0.0 +async-timeout==5.0.1 +attrs==24.3.0 backcall==0.2.0 blis==0.7.11 catalogue==2.0.10 -certifi==2024.8.30 -charset-normalizer==3.3.2 -click==8.1.7 +certifi==2024.12.14 
+charset-normalizer==3.4.1 +click==8.1.8 comm==0.2.2 confection==0.1.5 -cymem==2.0.8 +cymem==2.0.10 darglint==1.8.1 datasets==2.21.0 decorator==5.1.1 @@ -27,80 +27,80 @@ exceptiongroup==1.2.2 executing==2.1.0 filelock==3.16.0 flake8==7.0.0 -frozenlist==1.4.1 +frozenlist==1.5.0 fsspec==2024.6.1 gensim==4.3.3 -huggingface-hub==0.24.7 +huggingface-hub==0.30.2 idna==3.10 -ipython==8.27.0 +ipython==8.31.0 ipywidgets==8.1.5 -jedi==0.19.1 -jinja2==3.1.4 +jedi==0.19.2 +jinja2==3.1.5 joblib==1.4.2 -jsonpickle==3.3.0 +jsonpickle==4.0.1 jupyterlab-widgets==3.0.13 -langcodes==3.4.0 +langcodes==3.5.0 matplotlib-inline==0.1.7 mccabe==0.7.0 mpmath==1.3.0 multidict==6.1.0 multiprocess==0.70.16 -murmurhash==1.0.10 +murmurhash==1.0.11 mypy==1.11.2 mypy-extensions==1.0.0 -networkx==3.3 -numpy==1.25.2 -packaging==24.1 -pandas==2.2.2 +networkx==3.4.2 +numpy==1.26.4 +packaging==24.2 +pandas==2.2.3 parso==0.8.4 pathy==0.11.0 -peft==0.12.0 +peft==0.14.0 pexpect==4.9.0 pickleshare==0.7.5 preshed==3.0.9 -prompt-toolkit==3.0.47 -psutil==6.0.0 +prompt-toolkit==3.0.48 +psutil==6.1.1 ptyprocess==0.7.0 pure-eval==0.2.3 -pyarrow==17.0.0 +pyarrow==18.1.0 pycodestyle==2.11.1 -pydantic==1.10.18 +pydantic==2.10.4 pyflakes==3.2.0 -pygments==2.18.0 +pygments==2.19.1 python-dateutil==2.9.0 pytz==2024.2 pyyaml==6.0.2 -regex==2024.9.11 +regex==2024.11.6 requests==2.32.3 safetensors==0.4.5 -scikit-learn==1.5.2 -scipy==1.9.3 -six==1.16.0 +scikit-learn==1.6.0 +scipy==1.13.1 +six==1.17.0 smart-open==6.4.0 -spacy==3.6.1 +spacy==3.7.5 spacy-legacy==3.0.12 spacy-loggers==1.0.5 -srsly==2.4.8 +srsly==2.5.0 stack-data==0.6.3 -sympy==1.13.2 -thinc==8.1.12 +sympy==1.13.1 +thinc==8.2.5 threadpoolctl==3.5.0 -tokenizers==0.19.1 -tomli==2.0.1 -torch==2.4.1 -tqdm==4.66.5 +tokenizers==0.21.0 +tomli==2.2.1 +torch==2.5.1 +tqdm==4.67.1 traitlets==5.14.3 -transformers==4.44.2 -triton==3.0.0 -typer==0.9.4 +transformers==4.51.3 +triton==3.1.0 +typer==0.15.1 types-PyYAML==6.0.3 types-aiofiles==0.8.3 types-setuptools==57.4.10 
typing-extensions==4.12.2 -tzdata==2024.1 -urllib3==2.2.3 +tzdata==2024.2 +urllib3==2.3.0 wasabi==1.1.3 wcwidth==0.2.13 widgetsnbextension==4.0.13 xxhash==3.5.0 -yarl==1.11.1 \ No newline at end of file +yarl==1.18.3 diff --git a/install_requires.txt b/install_requires.txt index 136728d89..d4f1e1609 100644 --- a/install_requires.txt +++ b/install_requires.txt @@ -1,11 +1,11 @@ -'numpy>=1.22.0,<1.26.0' # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy +'numpy>=1.26.0,<2.0.0' # 1.26 is first to support 3.12; cannod support numpy2 due to spacy 'pandas>=1.4.2' # first to support 3.11 'gensim>=4.3.0,<5.0.0' # 5.3.0 is first to support 3.11; avoid major version bump -'spacy>=3.6.0,<3.8.0' # 3.8 only supports numpy2 which we can't use due to other dependencies -'scipy~=1.9.2' # 1.9.2 is first to support 3.11 -'transformers>=4.34.0,<5.0.0' # avoid major version bump +'spacy>=3.6.0,<4.0.0' # avoid major bump +'scipy>=1.9.2,<1.14.0' # 1.9.2 is first to support 3.11; 1.14.0 does not support 3.9 +'transformers>=4.48.1,<5.0.0' # avoid major version bump 'accelerate>=0.23.0' # required by Trainer class in de-id -'torch>=1.13.0,<3.0.0' # 1.13 is first to support 3.11; 2.1.2 has been compatible, but avoid major 3.0.0 for now +'torch>=2.4.0,<3.0.0' # 2.4.0 is first to support 3.12; avoid major 3.0.0 for now 'tqdm>=4.27' 'scikit-learn>=1.1.3,<2.0.0' # 1.1.3 is first to supporrt 3.11; avoid major version bump 'dill>=0.3.6,<1.0.0' # stuff saved in 0.3.6/0.3.7 is not always compatible with 0.3.4/0.3.5; avoid major bump diff --git a/medcat/cat.py b/medcat/cat.py index 13042acd0..508d3b897 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -40,6 +40,7 @@ from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY from medcat.utils.saving.envsnapshot import get_environment_info, ENV_SNAPSHOT_FILE_NAME from medcat.stats.stats import get_stats +from medcat.stats.mctexport import count_all_annotations, iter_anns from medcat.utils.filters import 
set_project_filters from medcat.utils.usage_monitoring import UsageMonitor @@ -142,7 +143,7 @@ def _create_pipeline(self, config: Config): self.pipe.add_meta_cat(meta_cat, meta_cat.config.general.category_name) for rel_cat in self._rel_cats: - self.pipe.add_rel_cat(rel_cat, "_".join(list(rel_cat.config.general["labels2idx"].keys()))) + self.pipe.add_rel_cat(rel_cat, "_".join(list(rel_cat.component.relcat_config.general["labels2idx"].keys()))) # Set max document length self.pipe.spacy_nlp.max_length = config.preprocessing.max_document_length @@ -209,8 +210,12 @@ def get_model_card(self, as_dict: bool = False): else: return json.dumps(card, indent=2, sort_keys=False) - def _versioning(self, force_rehash: bool = False): + def _versioning(self, force_rehash: bool = False, + change_description: Optional[str] = None): # Check version info and do not allow without it + date_today = date.today().strftime("%d %B %Y") + if change_description is not None: + self.config.version.description += f"\n[{date_today}] {change_description}" if self.config.version.description == 'No description': logger.warning("Please consider populating the version information [description, performance, location, ontology] in cat.config.version") @@ -221,14 +226,17 @@ def _versioning(self, force_rehash: bool = False): if version.id is not None: version.history.append(version['id']) version.id = m - version.last_modified = date.today().strftime("%d %B %Y") + version.last_modified = date_today version.cdb_info = self.cdb.make_stats() version.meta_cats = [meta_cat.get_model_card(as_dict=True) for meta_cat in self._meta_cats] version.medcat_version = __version__ logger.warning("Please consider updating [description, performance, location, ontology] in cat.config.version") - def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_MODEL_PACK_NAME, force_rehash: bool = False, - cdb_format: str = 'dill') -> str: + def create_model_pack(self, save_dir_path: str, + model_pack_name: str = 
DEFAULT_MODEL_PACK_NAME, + force_rehash: bool = False, + change_description: Optional[str] = None, + cdb_format: str = 'dill') -> str: """Will crete a .zip file containing all the models in the current running instance of MedCAT. This is not the most efficient way, for sure, but good enough for now. @@ -239,6 +247,8 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M The model pack name. Defaults to DEFAULT_MODEL_PACK_NAME. force_rehash (bool): Force recalculation of hash. Defaults to `False`. + change_description (Optional[str]): + The description of the change due to which a save is required. Defaults to None. cdb_format (str): The format of the saved CDB in the model pack. The available formats are: @@ -253,7 +263,7 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M # Spacy model always should be just the name, but during loading it can be reset to path self.config.general.spacy_model = os.path.basename(self.config.general.spacy_model) # Versioning - self._versioning(force_rehash) + self._versioning(force_rehash, change_description) model_pack_name += "_{}".format(self.config.version.id) logger.warning("This will save all models into a zip file, can take some time and require quite a bit of disk space.") @@ -808,7 +818,8 @@ def train_supervised_from_json(self, retain_extra_cui_filter: bool = False, checkpoint: Optional[Checkpoint] = None, retain_filters: bool = False, - is_resumed: bool = False) -> Tuple: + is_resumed: bool = False, + train_meta_cats: bool = False) -> Tuple: """ Run supervised training on a dataset from MedCATtrainer in JSON format. 
@@ -825,7 +836,7 @@ def train_supervised_from_json(self, devalue_others, use_groups, never_terminate, train_from_false_positives, extra_cui_filter, retain_extra_cui_filter, checkpoint, - retain_filters, is_resumed) + retain_filters, is_resumed, train_meta_cats) def train_supervised_raw(self, data: Dict[str, List[Dict[str, dict]]], @@ -845,7 +856,8 @@ def train_supervised_raw(self, retain_extra_cui_filter: bool = False, checkpoint: Optional[Checkpoint] = None, retain_filters: bool = False, - is_resumed: bool = False) -> Tuple: + is_resumed: bool = False, + train_meta_cats: bool = False) -> Tuple: """Train supervised based on the raw data provided. The raw data is expected in the following format: @@ -922,6 +934,8 @@ def train_supervised_raw(self, a ValueError is raised. The merging is done in the first epoch. is_resumed (bool): If True resume the previous training; If False, start a fresh new training. + train_meta_cats (bool): + If True, also trains the appropriate MetaCATs. Raises: ValueError: If attempting to retain filters with while training over multiple projects. 
@@ -1081,6 +1095,21 @@ def train_supervised_raw(self, use_overlaps=use_overlaps, use_groups=use_groups, extra_cui_filter=extra_cui_filter) + if (train_meta_cats and + # NOTE if no annnotaitons, no point + count_all_annotations(data) > 0): # type: ignore + # NOTE: if there + logger.info("Training MetaCATs within train_supervised_raw") + _, _, ann0 = next(iter_anns(data)) # type: ignore + for meta_cat in self._meta_cats: + # only consider meta-cats that have been defined for the category + if 'meta_anns' in ann0: + ann_names = ann0['meta_anns'].keys() # type: ignore + # adapt to alternative names if applicable + cat_name = meta_cat.config.general.get_applicable_category_name(ann_names) + if cat_name in ann_names: + logger.debug("Training MetaCAT %s", meta_cat.config.general.category_name) + meta_cat.train_raw(data) # reset the state of filters self.config.linking.filters = orig_filters diff --git a/medcat/cdb.py b/medcat/cdb.py index 3961fc921..507e7d3b9 100644 --- a/medcat/cdb.py +++ b/medcat/cdb.py @@ -818,15 +818,17 @@ def most_similar(self, sim_data['sim_vectors_cuis'] = np.array(sim_vectors_cuis) # Select appropriate concepts - type_id_inds = np.arange(0, len(sim_data['sim_vectors_type_ids'])) + type_id_inds = np.arange(0, len(sim_data['sim_vectors_type_ids']), dtype=np.int32) if len(type_id_filter) > 0: - type_id_inds = np.array([], dtype=np.int32) + # NOTE: change in numpy 2 + type_id_inds = np.array([], dtype=np.int32) # type: ignore for type_id in type_id_filter: type_id_inds = np.union1d(np.array([ind for ind, type_ids in enumerate(sim_data['sim_vectors_type_ids']) if type_id in type_ids]), type_id_inds) cnt_inds = np.arange(0, len(sim_data['sim_vectors_counts'])) if min_cnt > 0: - cnt_inds = np.where(sim_data['sim_vectors_counts'] >= min_cnt)[0] + # NOTE: change in numpy 2 + cnt_inds = np.where(sim_data['sim_vectors_counts'] >= min_cnt)[0] # type: ignore # Intersect cnt and type_id inds = np.intersect1d(type_id_inds, cnt_inds) diff --git 
a/medcat/config_meta_cat.py b/medcat/config_meta_cat.py index 0d6eb7a64..42a9dab7a 100644 --- a/medcat/config_meta_cat.py +++ b/medcat/config_meta_cat.py @@ -1,7 +1,12 @@ -from typing import Dict, Any +import logging +from typing import Dict, Any, List +from collections.abc import Container from medcat.config import MixingConfig, BaseModel, Optional +logger = logging.getLogger(__name__) + + class General(MixingConfig, BaseModel): """The General part of the MetaCAT config""" device: str = 'cpu' @@ -27,8 +32,22 @@ class General(MixingConfig, BaseModel): """What category is this meta_cat model predicting/training. NB! For these changes to take effect, the pipe would need to be recreated.""" + alternative_category_names: List = [] + """List that stores the variations of possible category names + Example: For Experiencer, the alternate name is Subject + alternative_category_names: ['Experiencer','Subject'] + + In the case that one specified in self.general.category_name parameter does not match the data, this ensures no error is raised and it is automatically mapped + """ category_value2id: Dict = {} """Map from category values to ID, if empty it will be autocalculated during training""" + alternative_class_names: List[List] = [[]] + """List of lists that stores the variations of possible class names for each class mentioned in self.general.category_value2id + + Example: For Presence task, the class names vary across NHS sites. + To accommodate for this, alternative_class_names is populated as: [["Hypothetical (N/A)","Hypothetical"],["Not present (False)","False"],["Present (True)","True"]] + Each sub list contains the possible variations of the given class. + """ vocab_size: Optional[int] = None """Will be set automatically if the tokenizer is provided during meta_cat init""" lowercase: bool = True @@ -64,6 +83,18 @@ class General(MixingConfig, BaseModel): """If set, the spacy span group that the metacat model will assign annotations. 
Otherwise defaults to doc._.ents or doc.ents per the annotate_overlapping settings""" + def get_applicable_category_name(self, available_names: Container[str]) -> Optional[str]: + if self.category_name in available_names: + return self.category_name + matches = [cat for cat in self.alternative_category_names if cat in available_names] + if len(matches) > 0: + logger.info("The category name provided in the config - '%s' is not present in the data. " + "However, the corresponding name - '%s' from the category_name_mapping has been found. " + "Updating the category name...", self.category_name, *matches) + self.category_name = matches[0] + return self.category_name + return None + class Config: extra = 'allow' validate_assignment = True diff --git a/medcat/config_rel_cat.py b/medcat/config_rel_cat.py index c16735d66..722f6c7cb 100644 --- a/medcat/config_rel_cat.py +++ b/medcat/config_rel_cat.py @@ -1,5 +1,6 @@ +import os import logging -from typing import Dict, Any, List +from typing import Any, Dict, List, Tuple, Union, cast from medcat.config import MixingConfig, BaseModel, Optional @@ -21,10 +22,14 @@ class General(MixingConfig, BaseModel): window_size: int = 300 """Max acceptable dinstance between entities (in characters), care when using this as it can produce sentences that are over 512 tokens (limit is given by tokenizer)""" - mct_export_max_non_rel_sample_size:int = 200 + limit_samples_per_class: int = -1 + """Number of samples per class, this limit is applied for train samples, so if train samples are 100 then test would be 20.""" + addl_rels_max_sample_size:int = 200 """Limit the number of 'Other' samples selected for training/test. This is applied per encountered medcat project, sample_size/num_projects. 
""" - mct_export_create_addl_rels: bool = False - """When processing relations from a MedCAT export, relations labeled as 'Other' are created from all the annotations pairs available""" + create_addl_rels: bool = False + """When processing relations from a MedCAT export/docs, relations labeled as 'Other' are created from all the annotations pairs available""" + create_addl_rels_by_type: bool = False + """When creating the 'Other' relation class, actually split this class into subclasses based on concept types""" tokenizer_name: str = "bert" """The name of the tokenizer user. @@ -46,21 +51,47 @@ class General(MixingConfig, BaseModel): """Tokenizer. NB! For these changes to take effect, the pipe would need to be recreated.""" - annotation_schema_tag_ids: List = [] + annotation_schema_tag_ids: List = [30522, 30523, 30524, 30525] """If a foreign non-MCAT trainer dataset is used, you can insert your own Rel entity token delimiters into the tokenizer, \ - copy those token IDs here, and also resize your tokenizer embeddings and adjust the hidden_size of the model, this will depend on the number of tokens you introduce""" - labels2idx: Dict = {} - idx2labels: Dict = {} + copy those token IDs here, and also resize your tokenizer embeddings and adjust the hidden_size of the model, this will depend on the number of tokens you introduce + for example: 30522 - [s1], 30523 - [e1], 30524 - [s2], 30525 - [e2], 30526 - [BLANK], 30527 - [ENT1], 30528 - [ENT2], 30529 - [/ENT1], 30530 - [/ENT2] + Please note that the tokenizer special tokens are supposed to be in pairs of two for example [s1] and [e1], [s2] and [e2], the [BLANK] is just an example placeholder token + If you have more than four tokens here then you need to make sure they are present in the text, + otherwise the pipeline will throw an error in the get_annotation_schema_tag() function. 
+ """ + + tokenizer_relation_annotation_special_tokens_tags: List[str] = ["[s1]", "[e1]", "[s2]", "[e2]"] + + tokenizer_other_special_tokens: Dict[str, str] = {"pad_token": "[PAD]"} + """ + The special tokens used by the tokenizer. The {PAD} is for Lllama tokenizer.""" + + labels2idx: Dict[str, int] = {} + idx2labels: Dict[int, str] = {} + pin_memory: bool = True + """If True the data loader will copy the tensors to the GPU pinned memory""" + seed: int = 13 """The seed for random number generation. - NOTE: If used along MetaCAT or additional NER, only one of the seeds will take effect NB! For these changes to take effect, the pipe would need to be recreated.""" task: str = "train" - """The task for RelCAT. + """The task for RelCAT.""" - NB! For these changes to take effect, the pipe would need to be recreated.""" + language: str = "en" + """Used for Spacy lang setting""" + + @classmethod + def convert_keys_to_int(cls, value): + if isinstance(value, dict): + return {int(k): v for k, v in value.items()} + return value + + def __setattr__(self, key: str, value: Any): + if key == "idx2labels" and isinstance(value, dict): + value = self.convert_keys_to_int(value) # Ensure conversion + super().__setattr__(key, value) class Model(MixingConfig, BaseModel): @@ -82,12 +113,18 @@ class Model(MixingConfig, BaseModel): num_directions: int = 2 """2 - bidirectional model, 1 - unidirectional""" + freeze_layers: bool = True + """If we update the weights during training""" + padding_idx: int = -1 emb_grad: bool = True """If True the embeddings will also be trained""" ignore_cpos: bool = False """If set to True center positions will be ignored when calculating representation""" + llama_use_pooled_output: bool = False + """If set to True, used only in Llama model, it will add the extra tensor formed from selecting the max of the last hidden layer""" + class Config: extra = 'allow' validate_assignment = True @@ -98,9 +135,24 @@ class Train(MixingConfig, BaseModel): nclasses: int = 2 
"""Number of classes that this model will output""" batch_size: int = 25 + """batch size""" nepochs: int = 1 + """Epochs""" lr: float = 1e-4 - adam_epsilon: float = 1e-4 + """Learning rate""" + stratified_batching: bool = False + """Train the model with stratified batching""" + batching_samples_per_class: list = [] + """Number of samples per class in each batch + example for batch size 64: [6,6,6,8,8,8,6,8,8]""" + batching_minority_limit: Union[List[int], int] = 0 + """Maximum number of samples the minority class can have. + Since the minority class elements need to be repeated, this is used to facilitate that + example: batching_samples_per_class - [6,6,6,8,8,8,6,8,8] + batching_minority_limit - 6""" + adam_betas: Tuple[float, float] = (0.9, 0.999) + adam_weight_decay: float = 0 + adam_epsilon: float = 1e-8 test_size: float = 0.2 gradient_acc_steps: int = 1 multistep_milestones: List[int] = [ @@ -109,7 +161,8 @@ class Train(MixingConfig, BaseModel): max_grad_norm: float = 1.0 shuffle_data: bool = True """Used only during training, if set the dataset will be shuffled before train/test split""" - class_weights: Optional[Any] = None + class_weights: Union[List[float], None] = None + enable_class_weights: bool = False score_average: str = "weighted" """What to use for averaging F1/P/R across labels""" auto_save_model: bool = True @@ -129,3 +182,22 @@ class ConfigRelCAT(MixingConfig, BaseModel): class Config: extra = 'allow' validate_assignment = True + + @classmethod + def load(cls, load_path: str = "./") -> "ConfigRelCAT": + """Load the config from a file. + + Args: + load_path (str): Path to RelCAT config. Defaults to "./". + + Returns: + ConfigRelCAT: The loaded config. 
+ """ + config = cls() + if os.path.exists(load_path): + if "config.json" not in load_path: + load_path = os.path.join(load_path, "config.json") + config = cast(ConfigRelCAT, super().load(load_path)) + logging.info("Loaded config.json") + + return config diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py index 9182fe00e..e4e647d7c 100644 --- a/medcat/meta_cat.py +++ b/medcat/meta_cat.py @@ -95,8 +95,9 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module: if not config.model.model_freeze_layers: peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16, target_modules=["query", "value"], lora_dropout=0.2) - - model = get_peft_model(model, peft_config) + # Not sure what changed between transformers 4.50.3 and 4.50.1 that made this + # fail for mypy. But as best as I Can tell, it still works just the same + model = get_peft_model(model, peft_config) # type: ignore # model.print_trainable_parameters() logger.info("BERT model used for classification") @@ -243,10 +244,12 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data lowercase=g_config['lowercase']) # Check is the name present - category_name = g_config['category_name'] - if category_name not in data: + category_name = g_config.get_applicable_category_name(data) + if category_name is None: raise Exception( - "The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format( + "The category name does not exist in this json file. You've provided '{}', " + "while the possible options are: {}. 
Additionally, ensure the populate the " + "'alternative_category_names' attribute to accommodate for variations.".format( category_name, " | ".join(list(data.keys())))) data = data[category_name] @@ -258,27 +261,21 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data if not category_value2id: # Encode the category values full_data, data_undersampled, category_value2id = encode_category_values(data, - category_undersample=self.config.model.category_undersample) - g_config['category_value2id'] = category_value2id + category_undersample=self.config.model.category_undersample,alternative_class_names=g_config['alternative_class_names']) else: # We already have everything, just get the data full_data, data_undersampled, category_value2id = encode_category_values(data, existing_category_value2id=category_value2id, - category_undersample=self.config.model.category_undersample) - g_config['category_value2id'] = category_value2id - # Make sure the config number of classes is the same as the one found in the data - if len(category_value2id) != self.config.model['nclasses']: - logger.warning( - "The number of classes set in the config is not the same as the one found in the data: %d vs %d",self.config.model['nclasses'], len(category_value2id)) - logger.warning("Auto-setting the nclasses value in config and rebuilding the model.") - self.config.model['nclasses'] = len(category_value2id) + category_undersample=self.config.model.category_undersample,alternative_class_names=g_config['alternative_class_names']) + g_config['category_value2id'] = category_value2id + self.config.model['nclasses'] = len(category_value2id) if self.config.model.phase_number == 2 and save_dir_path is not None: model_save_path = os.path.join(save_dir_path, 'model.dat') device = torch.device(g_config['device']) try: self.model.load_state_dict(torch.load(model_save_path, map_location=device)) - logger.info("Model state loaded from dict for 2 phase learning") + 
logger.info("Training model for Phase 2, with model dict loaded from disk") except FileNotFoundError: raise FileNotFoundError(f"\nError: Model file not found at path: {model_save_path}\nPlease run phase 1 training and then run phase 2.") @@ -295,6 +292,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data if not t_config['auto_save_model']: logger.info("For phase 1, model state has to be saved. Saving model...") t_config['auto_save_model'] = True + logger.info("Training model for Phase 1 now...") report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path) @@ -342,8 +340,8 @@ def eval(self, json_path: str) -> Dict: lowercase=g_config['lowercase']) # Check is the name there - category_name = g_config['category_name'] - if category_name not in data: + category_name = g_config.get_applicable_category_name(data) + if category_name is None: raise Exception("The category name does not exist in this json file.") data = data[category_name] @@ -415,7 +413,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA tokenizer = TokenizerWrapperBPE.load(save_dir_path) elif config.general['tokenizer_name'] == 'bert-tokenizer': from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT - tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model['model_variant']) + tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model.model_variant) # Create meta_cat meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config) diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py index 392b4a94d..715771027 100644 --- a/medcat/ner/transformers_ner.py +++ b/medcat/ner/transformers_ner.py @@ -2,6 +2,7 @@ import json import logging import datasets +import torch from spacy.tokens import Doc from datetime import datetime from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable, Type @@ -69,7 +70,8 @@ def __init__(self, 
cdb, config: Optional[ConfigTransformersNER] = None, eval_accumulation_steps=1, gradient_accumulation_steps=4, # We want to get to bs=4 do_eval=True, - evaluation_strategy='epoch', # type: ignore + # eval_strategy over evaluation_strategy since trf==4.46 (apperently) + eval_strategy='epoch', # type: ignore logging_strategy='epoch', # type: ignore save_strategy='epoch', # type: ignore metric_for_best_model='eval_recall', # Can be changed if our preference is not recall but precision or f1 @@ -89,8 +91,22 @@ def create_eval_pipeline(self): self.ner_pipe.tokenizer._in_target_context_manager = False if not hasattr(self.ner_pipe.tokenizer, 'split_special_tokens'): # NOTE: this will fix the DeID model(s) created with transformers before 4.42 - # and allow them to run with later transforemrs + # and allow them to run with later transformers self.ner_pipe.tokenizer.split_special_tokens = False + if not hasattr(self.ner_pipe.tokenizer, 'pad_token') and hasattr(self.ner_pipe.tokenizer, '_pad_token'): + # NOTE: This will fix the DeID model(s) created with transformers before 4.47 + # and allow them to run with later transformmers versions + # In 4.47 the special tokens started to be used differently, yet our saved model + # is not aware of that. So we need to explicitly fix that. + special_tokens_map = self.ner_pipe.tokenizer.__dict__.get('_special_tokens_map', {}) + for name in self.ner_pipe.tokenizer.SPECIAL_TOKENS_ATTRIBUTES: + # previously saved in (e.g) _pad_token + prev_val = getattr(self.ner_pipe.tokenizer, f"_{name}") + # now saved in the special tokens map by its name + special_tokens_map[name] = prev_val + # the map is saved in __dict__ explicitly, and it is later used in __getattr__ of the base class. 
+ self.ner_pipe.tokenizer.__dict__['_special_tokens_map'] = special_tokens_map + self.ner_pipe.device = self.model.device self._consecutive_identical_failures = 0 self._last_exception: Optional[Tuple[str, Type[Exception]]] = None @@ -161,7 +177,7 @@ def train(self, ignore_extra_labels=False, dataset=None, meta_requirements=None, - trainer_callbacks: Optional[List[TrainerCallback]]=None) -> Tuple: + trainer_callbacks: Optional[List[Callable[[Trainer], TrainerCallback]]] = None) -> Tuple: """Train or continue training a model give a json_path containing a MedCATtrainer export. It will continue training if an existing model is loaded or start new training if the model is blank/new. @@ -173,10 +189,13 @@ def train(self, labels that did not exist in the old model. dataset: Defaults to None. meta_requirements: Defaults to None - trainer_callbacks (List[TrainerCallback]): + trainer_callbacks (List[Callable[[Trainer], TrainerCallback]]]): A list of trainer callbacks for collecting metrics during the training at the client side. The transformers Trainer object will be passed in when each callback is called. + Raises: + ValueError: If something went wrong with model save path. 
+ Returns: Tuple: The dataframe, examples, and the dataset """ @@ -212,7 +231,9 @@ def train(self, if self.model.num_labels != len(self.tokenizer.label_map): logger.warning("The dataset contains labels we've not seen before, model is being reinitialized") logger.warning("Model: {} vs Dataset: {}".format(self.model.num_labels, len(self.tokenizer.label_map))) - self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'], num_labels=len(self.tokenizer.label_map)) + self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'], + num_labels=len(self.tokenizer.label_map), + ignore_mismatched_sizes=True) self.tokenizer.cui2name = {k:self.cdb.get_name(k) for k in self.tokenizer.label_map.keys()} self.model.config.id2label = {v:k for k,v in self.tokenizer.label_map.items()} @@ -237,7 +258,9 @@ def train(self, tokenizer=None) if trainer_callbacks: for callback in trainer_callbacks: - trainer.add_callback(callback(trainer)) + # No idea why mypy isn't picking up the method. 
+ # It most certainly does exist + trainer.add_callback(callback(trainer)) # type: ignore trainer.train() # type: ignore @@ -245,7 +268,11 @@ def train(self, self.config.general.last_train_on = datetime.now().timestamp() # type: ignore # Save everything - self.save(save_dir_path=os.path.join(self.training_arguments.output_dir, 'final_model')) + output_dir = self.training_arguments.output_dir + if output_dir is None: + # NOTE: this shouldn't really happen, but we'll do this for type safety + raise ValueError("Output path should not be None!") + self.save(save_dir_path=os.path.join(output_dir, 'final_model')) # Run an eval step and return metrics p = trainer.predict(encoded_dataset['test']) # type: ignore @@ -289,7 +316,7 @@ def eval(self, json_path: Union[str, list, None] = None, dataset=None, ignore_ex p = trainer.predict(encoded_dataset) # type: ignore df, examples = metrics(p, return_df=True, tokenizer=self.tokenizer, dataset=encoded_dataset) - return df, examples + return df, examples, dataset def save(self, save_dir_path: str) -> None: """Save all components of this class to a file @@ -316,6 +343,63 @@ def save(self, save_dir_path: str) -> None: # This is everything we need to save from the class, we do not #save the class itself. + def expand_model_with_concepts(self, cui2preferred_name: Dict[str, str], use_avg_init: bool = True) -> None: + """Expand the model with new concepts and their preferred names, which requires subsequent retraining on the model. + + Args: + cui2preferred_name(Dict[str, str]): + Dictionary where each key is the literal ID of the concept to be added and each value is its preferred name. + use_avg_init(bool): + Whether to use the average of existing weights or biases as the initial value for the new concept. Defaults to True. 
+ """ + + avg_weight = torch.mean(self.model.classifier.weight, dim=0, keepdim=True) + avg_bias = torch.mean(self.model.classifier.bias, dim=0, keepdim=True) + + new_cuis = set() + for label, preferred_name in cui2preferred_name.items(): + if label in self.model.config.label2id.keys(): + logger.warning("Concept ID '%s' already exists in the model, skipping...", label) + continue + + sname = preferred_name.lower().replace(" ", "~") + new_names = { + sname: { + "tokens": [], + "snames": [sname], + "raw_name": preferred_name, + "is_upper": True + } + } + self.cdb.add_names(cui=label, names=new_names, name_status="P", full_build=True) + + new_label_id = sorted(self.model.config.label2id.values())[-1] + 1 + self.model.config.label2id[label] = new_label_id + self.model.config.id2label[new_label_id] = label + self.tokenizer.label_map[label] = new_label_id + self.tokenizer.cui2name = {k: self.cdb.get_name(k) for k in self.tokenizer.label_map.keys()} + + if use_avg_init: + self.model.classifier.weight = torch.nn.Parameter( + torch.cat((self.model.classifier.weight, avg_weight), 0) + ) + self.model.classifier.bias = torch.nn.Parameter( + torch.cat((self.model.classifier.bias, avg_bias), 0) + ) + else: + self.model.classifier.weight = torch.nn.Parameter( + torch.cat((self.model.classifier.weight, torch.randn(1, self.model.config.hidden_size)), 0) + ) + self.model.classifier.bias = torch.nn.Parameter( + torch.cat((self.model.classifier.bias, torch.randn(1)), 0) + ) + self.model.num_labels += 1 + self.model.classifier.out_features += 1 + + new_cuis.add(label) + + logger.info("Model expanded with the new concept(s): %s and shall be retrained before use.", str(new_cuis)) + @classmethod def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "TransformersNER": """Load a meta_cat object. 
diff --git a/medcat/rel_cat.py b/medcat/rel_cat.py index 374997927..b8d72f663 100644 --- a/medcat/rel_cat.py +++ b/medcat/rel_cat.py @@ -1,31 +1,68 @@ import json import logging import os +import random +import numpy +from sklearn.utils import compute_class_weight import torch.optim import torch import torch.nn as nn +import spacy from tqdm import tqdm from datetime import date, datetime -from transformers import BertConfig from medcat.cdb import CDB from medcat.config import Config from medcat.config_rel_cat import ConfigRelCAT from medcat.pipeline.pipe_runner import PipeRunner -from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT -from spacy.tokens import Doc -from typing import Dict, Iterable, Iterator, cast -from transformers import AutoTokenizer -from torch.utils.data import DataLoader -from torch.optim import Adam +from spacy.tokens import Doc, Span +from typing import Dict, Iterable, Iterator, List +from torch.utils.data import DataLoader, Sampler +from torch.optim import AdamW from torch.optim.lr_scheduler import MultiStepLR +from medcat.utils.relation_extraction.base_component import BaseComponent_RelationExtraction from medcat.utils.meta_cat.ml_utils import set_all_seeds -from medcat.utils.relation_extraction.models import BertModel_RelationExtraction -from medcat.utils.relation_extraction.pad_seq import Pad_Sequence -from medcat.utils.relation_extraction.utils import create_tokenizer_pretrain, load_results, load_state, save_results, save_state, split_list_train_test_by_class +from medcat.utils.relation_extraction.ml_utils import load_results, load_state, save_results, save_state, split_list_train_test_by_class from medcat.utils.relation_extraction.rel_dataset import RelData +class BalancedBatchSampler(Sampler): + def __init__(self, dataset, classes, batch_size, max_samples, max_minority): + self.dataset = dataset + self.classes = classes + self.batch_size = batch_size + self.num_classes = len(classes) + self.indices = 
list(range(len(dataset))) + + self.max_minority = max_minority + + self.max_samples_per_class = max_samples + + def __len__(self): + return (len(self.indices) + self.batch_size - 1) // self.batch_size + + def __iter__(self): + batch_counter = 0 + indices = self.indices.copy() + while batch_counter != self.__len__(): + batch = [] + + class_counts = {c: 0 for c in self.classes} + while len(batch) < self.batch_size: + + index = random.choice(indices) + label = self.dataset[index][2].numpy().tolist()[0] # Assuming label is at index 1 + if class_counts[label] < self.max_samples_per_class[label]: + batch.append(index) + class_counts[label] += 1 + if self.max_samples_per_class[label] > self.max_minority: + indices.remove(index) + + print("class_counts:", class_counts) + yield batch + batch_counter += 1 + + class RelCAT(PipeRunner): """The RelCAT class used for training 'Relation-Annotation' models, i.e., annotation of relations between clinical concepts. @@ -38,7 +75,7 @@ class RelCAT(PipeRunner): a BERT-style model. For now, only BERT models are supported. config (ConfigRelCAT): - the configuration for RelCAT. Param descriptions available in ConfigRelCAT docs. + the configuration for RelCAT. Param descriptions available in ConfigRelCAT class docs. task (str, optional): What task is this model supposed to handle. Defaults to "train" init_model (bool, optional): loads default model. Defaults to False. 
@@ -50,67 +87,30 @@ class RelCAT(PipeRunner): log = logging.getLogger(__name__) - def __init__(self, cdb: CDB, tokenizer: TokenizerWrapperBERT, config: ConfigRelCAT = ConfigRelCAT(), task="train", init_model=False): - self.config = config - self.tokenizer: TokenizerWrapperBERT = tokenizer - self.cdb = cdb - - logging.basicConfig(level=self.config.general.log_level) - self.log.setLevel(self.config.general.log_level) - - self.is_cuda_available = torch.cuda.is_available() - self.device = torch.device( - "cuda" if self.is_cuda_available and self.config.general.device != "cpu" else "cpu") + def __init__(self, cdb: CDB, config: ConfigRelCAT = ConfigRelCAT(), task="train", init_model=False): - self.model_config = BertConfig() - self.model: BertModel_RelationExtraction + self.component: BaseComponent_RelationExtraction = BaseComponent_RelationExtraction() # type: ignore self.task: str = task self.checkpoint_path: str = "./" - self.optimizer: Adam = None # type: ignore - self.scheduler: MultiStepLR = None # type: ignore - self.best_f1: float = 0.0 - self.epoch: int = 0 - - self.pad_id = self.tokenizer.hf_tokenizers.pad_token_id - self.padding_seq = Pad_Sequence(seq_pad_value=self.pad_id, - label_pad_value=self.pad_id) set_all_seeds(config.general.seed) if init_model: - self._get_model() + self.component = BaseComponent_RelationExtraction(config=config, + task=task, + init_model=True) - def save(self, save_path: str) -> None: - """ Saves model and its dependencies to specified save_path folder. - The CDB is obviously not saved, it is however necessary to save the tokenizer used. + self.cdb = cdb + logging.basicConfig(level=self.component.relcat_config.general.log_level) + self.log.setLevel(self.component.relcat_config.general.log_level) - Args: - save_path (str): folder path in which to save the model & deps. 
- """ + self.is_cuda_available = torch.cuda.is_available() + self.device = torch.device( + "cuda" if self.is_cuda_available and self.component.relcat_config.general.device != "cpu" else "cpu") - assert self.config is not None - self.config.save(os.path.join(save_path, "config.json")) - - assert self.model_config is not None - self.model_config.vocab_size = self.tokenizer.hf_tokenizers.vocab_size - self.model_config.to_json_file( - os.path.join(save_path, "model_config.json")) - assert self.tokenizer is not None - self.tokenizer.save(os.path.join(save_path)) - - assert self.model is not None - self.model.bert_model.resize_token_embeddings( - self.tokenizer.hf_tokenizers.vocab_size) - save_state(self.model, optimizer=self.optimizer, scheduler=self.scheduler, epoch=self.epoch, best_f1=self.best_f1, - path=save_path, model_name=self.config.general.model_name, - task=self.task, is_checkpoint=False, final_export=True) - - def _get_model(self): - """ Used only for model initialisation. - """ - self.model = BertModel_RelationExtraction(pretrained_model_name_or_path="bert-base-uncased", - relcat_config=self.config, - model_config=self.model_config) + + def save(self, save_path: str = "./") -> None: + self.component.save(save_path=save_path) @classmethod def load(cls, load_path: str = "./") -> "RelCAT": @@ -122,116 +122,33 @@ def load(cls, load_path: str = "./") -> "RelCAT": cls.log.info("The default CDB file name 'cdb.dat' doesn't exist in the specified path, you will need to load & set \ a CDB manually via rel_cat.cdb = CDB.load('path') ") - config_path = os.path.join(load_path, "config.json") - config = ConfigRelCAT() - if os.path.exists(config_path): - config = cast(ConfigRelCAT, ConfigRelCAT.load( - os.path.join(load_path, "config.json"))) - cls.log.info("Loaded config.json") - - tokenizer = None - tokenizer_path = os.path.join(load_path, config.general.tokenizer_name) - - if "bert" in config.general.tokenizer_name: - tokenizer_path = load_path - - if 
os.path.exists(tokenizer_path): - tokenizer = TokenizerWrapperBERT.load(tokenizer_path) - - cls.log.info("Tokenizer loaded from:" + tokenizer_path) - elif config.general.model_name: - cls.log.info("Attempted to load Tokenizer from path:" + tokenizer_path + - ", but it doesn't exist, loading default toknizer from model_name config.general.model_name:" + config.general.model_name) - tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained(pretrained_model_name_or_path=config.general.model_name), - max_seq_length=config.general.max_seq_length, - add_special_tokens=config.general.tokenizer_special_tokens - ) - create_tokenizer_pretrain(tokenizer, tokenizer_path) - else: - cls.log.info("Attempted to load Tokenizer from path:" + tokenizer_path + - ", but it doesn't exist, loading default toknizer from model_name config.general.model_name:bert-base-uncased") - tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased"), - max_seq_length=config.general.max_seq_length, - add_special_tokens=config.general.tokenizer_special_tokens - ) - - model_config = BertConfig() - model_config_path = os.path.join(load_path, "model_config.json") - - if os.path.exists(model_config_path): - cls.log.info("Loaded config from : " + model_config_path) - model_config = BertConfig.from_json_file(model_config_path) # type: ignore - else: - try: - model_config = BertConfig.from_pretrained( - pretrained_model_name_or_path=config.general.model_name, num_hidden_layers=config.model.hidden_layers) # type: ignore - except Exception as e: - cls.log.error("%s", str(e)) - cls.log.info("Config for HF model not found: " + - config.general.model_name + ". 
Using bert-base-uncased.") - model_config = BertConfig.from_pretrained( - pretrained_model_name_or_path="bert-base-uncased") # type: ignore - - model_config.vocab_size = tokenizer.hf_tokenizers.vocab_size - - rel_cat = cls(cdb=cdb, config=config, - tokenizer=tokenizer, - task=config.general.task) - - rel_cat.model_config = model_config - - device = torch.device("cuda" if torch.cuda.is_available( - ) and config.general.device != "cpu" else "cpu") - - try: - model_path = os.path.join(load_path, "model.dat") - - if os.path.exists(os.path.join(load_path, config.general.model_name)): - rel_cat.model = BertModel_RelationExtraction(pretrained_model_name_or_path=config.general.model_name, - relcat_config=config, - model_config=model_config) - else: - rel_cat.model = BertModel_RelationExtraction( - pretrained_model_name_or_path="", - relcat_config=config, - model_config=model_config) - rel_cat.model.load_state_dict( - torch.load(model_path, map_location=device)) - - cls.log.info("Loaded HF model : " + config.general.model_name) - except Exception as e: - cls.log.error("%s", str(e)) - cls.log.error("Failed to load specified HF model, defaulting to 'bert-base-uncased', loading...") - rel_cat.model = BertModel_RelationExtraction( - pretrained_model_name_or_path="bert-base-uncased", - relcat_config=config, - model_config=model_config) - - rel_cat.model.bert_model.resize_token_embeddings((len(tokenizer.hf_tokenizers))) - - rel_cat.optimizer = None # type: ignore - rel_cat.scheduler = None # type: ignore - - rel_cat.epoch, rel_cat.best_f1 = load_state(rel_cat.model, rel_cat.optimizer, rel_cat.scheduler, path=load_path, - model_name=config.general.model_name, - file_prefix=config.general.task, - device=device, - config=config) + component = BaseComponent_RelationExtraction.load(pretrained_model_name_or_path=load_path) + + device = torch.device("cuda" if torch.cuda.is_available() and component.relcat_config.general.device != "cpu" else "cpu") + + rel_cat = RelCAT(cdb=cdb, 
config=component.relcat_config, task=component.task) + rel_cat.device = device + rel_cat.component = component return rel_cat + def __call__(self, doc: Doc) -> Doc: + doc = next(self.pipe(iter([doc]))) + return doc + def _create_test_train_datasets(self, data: Dict, split_sets:bool = False): train_data: Dict = {} test_data: Dict = {} if split_sets: train_data["output_relations"], test_data["output_relations"] = split_list_train_test_by_class(data["output_relations"], - test_size=self.config.train.test_size) + test_size=self.component.relcat_config.train.test_size, shuffle=self.component.relcat_config.train.shuffle_data, + sample_limit=self.component.relcat_config.general.limit_samples_per_class) test_data_label_names = [rec[4] for rec in test_data["output_relations"]] test_data["nclasses"], test_data["labels2idx"], test_data["idx2label"] = RelData.get_labels( - test_data_label_names, self.config) + test_data_label_names, self.component.relcat_config) for idx in range(len(test_data["output_relations"])): test_data["output_relations"][idx][5] = test_data["labels2idx"][test_data["output_relations"][idx][4]] @@ -247,7 +164,7 @@ def _create_test_train_datasets(self, data: Dict, split_sets:bool = False): for rec in train_data["output_relations"]] train_data["nclasses"], train_data["labels2idx"], train_data["idx2label"] = RelData.get_labels( - train_data_label_names, self.config) + train_data_label_names, self.component.relcat_config) for idx in range(len(train_data["output_relations"])): train_data["output_relations"][idx][5] = train_data["labels2idx"][train_data["output_relations"][idx][4]] @@ -257,18 +174,18 @@ def _create_test_train_datasets(self, data: Dict, split_sets:bool = False): def train(self, export_data_path:str = "", train_csv_path:str = "", test_csv_path:str = "", checkpoint_path: str = "./"): if self.is_cuda_available: - self.log.info("Training on device:", - torch.cuda.get_device_name(0), self.device) + self.log.info("Training on device:" + + 
str(torch.cuda.get_device_name(0)) + str(self.device)) - self.model = self.model.to(self.device) + self.component.model = self.component.model.to(self.device) # resize vocab just in case more tokens have been added - self.model_config.vocab_size = len(self.tokenizer.hf_tokenizers) + self.component.model_config.vocab_size = self.component.tokenizer.get_size() train_rel_data = RelData( - cdb=self.cdb, config=self.config, tokenizer=self.tokenizer) + cdb=self.cdb, config=self.component.relcat_config, tokenizer=self.component.tokenizer) test_rel_data = RelData( - cdb=self.cdb, config=self.config, tokenizer=self.tokenizer) + cdb=self.cdb, config=self.component.relcat_config, tokenizer=self.component.tokenizer) if train_csv_path != "": if test_csv_path != "": @@ -287,51 +204,77 @@ def train(self, export_data_path:str = "", train_csv_path:str = "", test_csv_pat train_rel_data.dataset, test_rel_data.dataset = self._create_test_train_datasets( train_rel_data.create_relations_from_export(export_data), split_sets=True) else: - raise ValueError("NO DATA HAS BEEN PROVIDED (JSON/CSV/spacy_DOCS)") + raise ValueError("NO DATA HAS BEEN PROVIDED (MedCAT Trainer export JSON/CSV/spacy_DOCS)") train_dataset_size = len(train_rel_data) - batch_size = train_dataset_size if train_dataset_size < self.config.train.batch_size else self.config.train.batch_size - train_dataloader = DataLoader(train_rel_data, batch_size=batch_size, shuffle=self.config.train.shuffle_data, - num_workers=0, collate_fn=self.padding_seq, - pin_memory=self.config.general.pin_memory) - test_dataset_size = len(test_rel_data) - test_batch_size = test_dataset_size if test_dataset_size < self.config.train.batch_size else self.config.train.batch_size - test_dataloader = DataLoader(test_rel_data, batch_size=test_batch_size, shuffle=self.config.train.shuffle_data, - num_workers=0, collate_fn=self.padding_seq, - pin_memory=self.config.general.pin_memory) + batch_size = train_dataset_size if train_dataset_size < 
self.component.relcat_config.train.batch_size else self.component.relcat_config.train.batch_size + + # to use stratified batching + if self.component.relcat_config.train['stratified_batching']: + sampler = BalancedBatchSampler(train_rel_data, [i for i in range(self.component.relcat_config.train.nclasses)], + batch_size, + self.component.relcat_config.train['batching_samples_per_class'], + self.component.relcat_config.train['batching_minority_limit']) - criterion = nn.CrossEntropyLoss(ignore_index=-1) + train_dataloader = DataLoader(train_rel_data,num_workers=0, collate_fn=self.component.padding_seq, + batch_sampler=sampler,pin_memory=self.component.relcat_config.general.pin_memory) + else: + train_dataloader = DataLoader(train_rel_data, batch_size=batch_size, + shuffle=self.component.relcat_config.train.shuffle_data, + num_workers=0, + collate_fn=self.component.padding_seq, + pin_memory=self.component.relcat_config.general.pin_memory) + + test_dataset_size = len(test_rel_data) + test_batch_size = test_dataset_size if test_dataset_size < self.component.relcat_config.train.batch_size else self.component.relcat_config.train.batch_size + test_dataloader = DataLoader(test_rel_data, + batch_size=test_batch_size, + shuffle=self.component.relcat_config.train.shuffle_data, + num_workers=0, + collate_fn=self.component.padding_seq, + pin_memory=self.component.relcat_config.general.pin_memory) + + if self.component.relcat_config.train.class_weights is not None and self.component.relcat_config.train.enable_class_weights: + criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(numpy.asarray(self.component.relcat_config.train.class_weights)).to(self.device)) + elif self.component.relcat_config.train.enable_class_weights: + all_class_lbl_ids = [rec[5] for rec in train_rel_data.dataset["output_relations"]] + self.component.relcat_config.train.class_weights = compute_class_weight(class_weight="balanced", + classes=numpy.unique(all_class_lbl_ids), + y=all_class_lbl_ids).tolist() + 
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(self.component.relcat_config.train.class_weights).to(self.device)) + else: + criterion = nn.CrossEntropyLoss() - if self.optimizer is None: - parameters = filter(lambda p: p.requires_grad, self.model.parameters()) - self.optimizer = Adam(parameters, lr=self.config.train.lr) + if self.component.optimizer is None: + parameters = filter(lambda p: p.requires_grad, self.component.model.parameters()) + self.component.optimizer = AdamW(parameters, lr=self.component.relcat_config.train.lr, weight_decay=self.component.relcat_config.train.adam_weight_decay, + betas=self.component.relcat_config.train.adam_betas, eps=self.component.relcat_config.train.adam_epsilon) - if self.scheduler is None: - self.scheduler = MultiStepLR( - self.optimizer, milestones=self.config.train.multistep_milestones, - gamma=self.config.train.multistep_lr_gamma) # type: ignore + if self.component.scheduler is None: + self.component.scheduler = MultiStepLR( + self.component.optimizer, milestones=self.component.relcat_config.train.multistep_milestones, + gamma=self.component.relcat_config.train.multistep_lr_gamma) # type: ignore self.epoch, self.best_f1 = load_state( - self.model, self.optimizer, self.scheduler, load_best=False, path=checkpoint_path, device=self.device) + self.component.model, self.component.optimizer, self.component.scheduler, load_best=False, path=checkpoint_path, relcat_config=self.component.relcat_config) self.log.info("Starting training process...") losses_per_epoch, accuracy_per_epoch, f1_per_epoch = load_results( path=checkpoint_path) - if train_rel_data.dataset["nclasses"] > self.config.train.nclasses: - self.config.train.nclasses = train_rel_data.dataset["nclasses"] - self.model.relcat_config.train.nclasses = self.config.train.nclasses + if train_rel_data.dataset["nclasses"] > self.component.relcat_config.train.nclasses: + self.component.relcat_config.train.nclasses = train_rel_data.dataset["nclasses"] + 
self.component.model.relcat_config.train.nclasses = self.component.relcat_config.train.nclasses - self.config.general.labels2idx.update( - train_rel_data.dataset["labels2idx"]) - self.config.general.idx2labels = { - int(v): k for k, v in self.config.general["labels2idx"].items()} + self.component.relcat_config.general.labels2idx.update(train_rel_data.dataset["labels2idx"]) + self.component.relcat_config.general.idx2labels = { + int(v): k for k, v in self.component.relcat_config.general["labels2idx"].items()} - gradient_acc_steps = self.config.train.gradient_acc_steps - max_grad_norm = self.config.train.max_grad_norm + gradient_acc_steps = self.component.relcat_config.train.gradient_acc_steps + max_grad_norm = self.component.relcat_config.train.max_grad_norm - _epochs = self.epoch + self.config.train.nepochs + _epochs = self.epoch + self.component.relcat_config.train.nepochs for epoch in range(0, _epochs): start_time = datetime.now().time() @@ -346,21 +289,21 @@ def train(self, export_data_path:str = "", train_csv_path:str = "", test_csv_pat pbar = tqdm(total=train_dataset_size) for i, data in enumerate(train_dataloader, 0): - self.model.train() - self.model.zero_grad() + self.component.model.train() + self.component.model.zero_grad() current_batch_size = len(data[0]) token_ids, e1_e2_start, labels, _, _ = data attention_mask = ( - token_ids != self.pad_id).float().to(self.device) + token_ids != self.component.pad_id).float().to(self.device) token_type_ids = torch.zeros( (token_ids.shape[0], token_ids.shape[1])).long().to(self.device) labels = labels.to(self.device) - model_output, classification_logits = self.model( + model_output, classification_logits = self.component.model( input_ids=token_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, @@ -368,7 +311,7 @@ def train(self, export_data_path:str = "", train_csv_path:str = "", test_csv_pat ) batch_loss = criterion( - classification_logits.view(-1, self.config.train.nclasses).to(self.device), 
labels.squeeze(1)) + classification_logits.view(-1, self.component.relcat_config.train.nclasses).to(self.device), labels.squeeze(1)) batch_loss.backward() batch_loss = batch_loss / gradient_acc_steps @@ -382,12 +325,11 @@ def train(self, export_data_path:str = "", train_csv_path:str = "", test_csv_pat accuracy_per_batch.append(batch_acc) torch.nn.utils.clip_grad_norm_( - self.model.parameters(), max_grad_norm) + self.component.model.parameters(), max_grad_norm) if (i % gradient_acc_steps) == 0: - self.optimizer.step() - self.scheduler.step() - + self.component.optimizer.step() + self.component.scheduler.step() if ((i + 1) % current_batch_size == 0): self.log.debug( "[Epoch: %d, loss per batch, accuracy per batch: %.3f, %.3f, average total loss %.3f , total loss %.3f]" % @@ -415,11 +357,11 @@ def train(self, export_data_path:str = "", train_csv_path:str = "", test_csv_pat self.log.info( "======================== TRAIN SET TEST RESULTS ========================") - _ = self.evaluate_results(train_dataloader, self.pad_id) + _ = self.evaluate_results(train_dataloader, self.component.pad_id) self.log.info( "======================== TEST SET TEST RESULTS ========================") - results = self.evaluate_results(test_dataloader, self.pad_id) + results = self.evaluate_results(test_dataloader, self.component.pad_id) f1_per_epoch.append(results['f1']) @@ -431,14 +373,14 @@ def train(self, export_data_path:str = "", train_csv_path:str = "", test_csv_pat if len(f1_per_epoch) > 0 and f1_per_epoch[-1] > self.best_f1: self.best_f1 = f1_per_epoch[-1] - save_state(self.model, self.optimizer, self.scheduler, self.epoch, self.best_f1, checkpoint_path, - model_name=self.config.general.model_name, task=self.task, is_checkpoint=False) + save_state(self.component.model, self.component.optimizer, self.component.scheduler, self.epoch, self.best_f1, checkpoint_path, + model_name=self.component.relcat_config.general.model_name, task=self.task, is_checkpoint=False) if (epoch % 1) == 0: 
save_results({"losses_per_epoch": losses_per_epoch, "accuracy_per_epoch": accuracy_per_epoch, "f1_per_epoch": f1_per_epoch, "epoch": epoch}, file_prefix="train", path=checkpoint_path) - save_state(self.model, self.optimizer, self.scheduler, self.epoch, self.best_f1, checkpoint_path, - model_name=self.config.general.model_name, task=self.task, is_checkpoint=True) + save_state(self.component.model, self.component.optimizer, self.component.scheduler, self.epoch, self.best_f1, checkpoint_path, + model_name=self.component.relcat_config.general.model_name, task=self.task, is_checkpoint=True) def evaluate_(self, output_logits, labels, ignore_idx): # ignore index (padding) when calculating accuracy @@ -464,6 +406,7 @@ def evaluate_(self, output_logits, labels, ignore_idx): for label in unique_labels: stat_per_label[label] = { "tp": 0, "fp": 0, "tn": 0, "fn": 0, "f1": 0.0, "acc": 0.0, "prec": 0.0, "recall": 0.0} + for true_label_idx in range(len(true_labels)): if true_labels[true_label_idx] == label: if pred_labels[true_label_idx] == label: @@ -520,11 +463,15 @@ def evaluate_(self, output_logits, labels, ignore_idx): def evaluate_results(self, data_loader, pad_id): self.log.info("Evaluating test samples...") - criterion = nn.CrossEntropyLoss(ignore_index=-1) + if self.component.relcat_config.train.class_weights is not None and self.component.relcat_config.train.enable_class_weights: + criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(self.component.relcat_config.train.class_weights).to(self.device)) + else: + criterion = nn.CrossEntropyLoss() + total_loss, total_acc, total_f1, total_recall, total_precision = 0.0, 0.0, 0.0, 0.0, 0.0 all_batch_stats_per_label = [] - self.model.eval() + self.component.model.eval() for i, data in enumerate(data_loader): with torch.no_grad(): @@ -535,12 +482,12 @@ def evaluate_results(self, data_loader, pad_id): labels = labels.to(self.device) - model_output, pred_classification_logits = self.model(token_ids, 
token_type_ids=token_type_ids, + model_output, pred_classification_logits = self.component.model(token_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, Q=None, e1_e2_start=e1_e2_start) batch_loss = criterion(pred_classification_logits.view( - -1, self.config.train.nclasses).to(self.device), labels.squeeze(1)) + -1, self.component.relcat_config.train.nclasses).to(self.device), labels.squeeze(1)) total_loss += batch_loss.item() batch_accuracy, batch_recall, batch_precision, batch_f1, pred_labels, true_labels, batch_stats_per_label = \ @@ -590,7 +537,7 @@ def evaluate_results(self, data_loader, pad_id): self.log.info("----------------------- class stats -----------------------") for label_id, stat_dict in final_stats_per_label.items(): self.log.info("label: %s | f1: %0.3f | prec : %0.3f | acc: %0.3f | recall: %0.3f " % ( - self.config.general.idx2labels[label_id], + self.component.relcat_config.general.idx2labels[label_id], stat_dict["f1"], stat_dict["prec"], stat_dict["acc"], @@ -604,17 +551,19 @@ def evaluate_results(self, data_loader, pad_id): def pipe(self, stream: Iterable[Doc], *args, **kwargs) -> Iterator[Doc]: predict_rel_dataset = RelData( - cdb=self.cdb, config=self.config, tokenizer=self.tokenizer) + cdb=self.cdb, config=self.component.relcat_config, tokenizer=self.component.tokenizer) - self.model = self.model.to(self.device) # type: ignore + self.component.model = self.component.model.to(self.device) # type: ignore for doc_id, doc in enumerate(stream, 0): predict_rel_dataset.dataset, _ = self._create_test_train_datasets( - predict_rel_dataset.create_base_relations_from_doc(doc, str(doc_id)), False) + data=predict_rel_dataset.create_base_relations_from_doc(doc, doc_id=str(doc_id)), + split_sets=False) - predict_dataloader = DataLoader(predict_rel_dataset, shuffle=False, batch_size=self.config.train.batch_size, - num_workers=0, collate_fn=self.padding_seq, - pin_memory=self.config.general.pin_memory) + predict_dataloader = 
DataLoader(dataset=predict_rel_dataset, shuffle=False, + batch_size=self.component.relcat_config.train.batch_size, + num_workers=0, collate_fn=self.component.padding_seq, + pin_memory=self.component.relcat_config.general.pin_memory) total_rel_found = len(predict_rel_dataset.dataset["output_relations"]) rel_idx = -1 @@ -628,11 +577,11 @@ def pipe(self, stream: Iterable[Doc], *args, **kwargs) -> Iterator[Doc]: with torch.no_grad(): token_ids, e1_e2_start, labels, _, _ = data - attention_mask = (token_ids != self.pad_id).float() + attention_mask = (token_ids != self.component.pad_id).float().to(self.device) token_type_ids = torch.zeros( - token_ids.shape[0], token_ids.shape[1]).long() + token_ids.shape[0], token_ids.shape[1]).long().to(self.device) - model_output, pred_classification_logits = self.model( + model_output, pred_classification_logits = self.component.model( token_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, e1_e2_start=e1_e2_start) # type: ignore @@ -641,9 +590,9 @@ def pipe(self, stream: Iterable[Doc], *args, **kwargs) -> Iterator[Doc]: confidence = torch.softmax( pred_rel_logits, dim=0).max(0) - predicted_label_id = confidence[1].item() + predicted_label_id = int(confidence[1].item()) - doc._.relations.append({"relation": self.config.general.idx2labels[predicted_label_id], + doc._.relations.append({"relation": self.component.relcat_config.general.idx2labels[predicted_label_id], "label_id": predicted_label_id, "ent1_text": predict_rel_dataset.dataset["output_relations"][rel_idx][ 2], @@ -661,6 +610,49 @@ def pipe(self, stream: Iterable[Doc], *args, **kwargs) -> Iterator[Doc]: yield doc - def __call__(self, doc: Doc) -> Doc: - doc = next(self.pipe(iter([doc]))) + def predict_text_with_anns(self, text: str, annotations: List[Dict]) -> Doc: + """ Creates spacy doc from text and annotation input. 
Predicts using self.__call__ + + Args: + text (str): text + annotations (Dict): dict containing the entities from NER (of your choosing), the format + must be the following format: + [ + { + "cui": "202099003", -this is optional + "value": "discoid lateral meniscus", + "start": 294, + "end": 318 + }, + { + "cui": "202099003", + "value": "Discoid lateral meniscus", + "start": 1905, + "end": 1929, + } + ] + + Returns: + Doc: spacy doc with the relations. + """ + + Span.set_extension('id', default=0, force=True) + Span.set_extension('cui', default=None, force=True) + Doc.set_extension('ents', default=[], force=True) + Doc.set_extension('relations', default=[], force=True) + nlp = spacy.blank(self.component.relcat_config.general.language) + doc = nlp(text) + + for ann in annotations: + tkn_idx = [] + for ind, word in enumerate(doc): + end_char = word.idx + len(word.text) + if end_char <= ann['end'] and end_char > ann['start']: + tkn_idx.append(ind) + entity = Span(doc, min(tkn_idx), max(tkn_idx) + 1, label=ann["value"]) + entity._.cui = ann["cui"] + doc._.ents.append(entity) + + doc = self(doc) + return doc diff --git a/medcat/tokenizers/meta_cat_tokenizers.py b/medcat/tokenizers/meta_cat_tokenizers.py index 4c4daf200..53ac64791 100644 --- a/medcat/tokenizers/meta_cat_tokenizers.py +++ b/medcat/tokenizers/meta_cat_tokenizers.py @@ -193,8 +193,13 @@ def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "To try: tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs) except Exception as e: - logging.warning("Could not load tokenizer from path due to error: {}. 
Loading from library for model variant: {}".format(e,model_variant)) - tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(model_variant) + # So that this is a string - it should be as it's only used in MetaCAT.load method + # with `config.model.model_variant` which is a `str` rathern than None + # NOTE: The reason the type in method signature is Optional[str] is because supertype defines it as such + variant = str(model_variant) + logging.warning("Could not load tokenizer from path due to error: %s. Loading from library for model variant: %s", + e, variant) + tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(variant) return tokenizer diff --git a/medcat/utils/data_utils.py b/medcat/utils/data_utils.py index 90ef1d345..967ba3c5f 100644 --- a/medcat/utils/data_utils.py +++ b/medcat/utils/data_utils.py @@ -814,7 +814,9 @@ def prepare_from_json_chars(data: Dict, return out_data -def make_mc_train_test(data: Dict, cdb: CDB, test_size: float = 0.2) -> Tuple: +def make_mc_train_test(data: Dict, cdb: CDB, test_size: float = 0.2, + min_test_count: int = 10, + max_test_fraction: float = 0.3) -> Tuple: """Make train set. This is a disaster. @@ -823,6 +825,10 @@ def make_mc_train_test(data: Dict, cdb: CDB, test_size: float = 0.2) -> Tuple: data (Dict): The data. cdb (CDB): The concept database. test_size (float): The test size. Defaults to 0.2. + min_test_count (int): The minimum numer of examples of a concepts + for it to be considered for the test set. Defaults to 10. + max_test_fraction (float): The maximum fraction of a concept + in the test set. 
Defaults to 0.3 Returns: Tuple: The train set, the test set, the test annotations, and the total annotations @@ -912,14 +918,17 @@ def make_mc_train_test(data: Dict, cdb: CDB, test_size: float = 0.2) -> Tuple: # Did we get more than 30% of concepts for any CUI with >=10 cnt - is_test = True - for cui, v in _cnts.items(): - if (v + test_cnts.get(cui, 0)) / cnts[cui] > 0.3: - if cnts[cui] >= 10: - # We only care for concepts if count >= 10, else they will be ignored - #during the test phase (for all metrics and similar) - is_test = False - break + # NOTE: This implementation is true to the INTENT of the previous one + # but the previous one would act quite a bit differently since + # the logic was flawed. The previous implementation guaranteed + # any document with only rare concepts (i.e ones with fewer than 10 + # examples across the entire dataset) would get a chance to be + # included in the test set (as long as the test size wasn't met) + is_test = any( + cnts[cui] >= min_test_count and + (v + test_cnts.get(cui, 0)) / cnts[cui] < max_test_fraction + for cui, v in _cnts.items() + ) # Add to test set if is_test and np.random.rand() < test_prob: diff --git a/medcat/utils/meta_cat/data_utils.py b/medcat/utils/meta_cat/data_utils.py index 3d0431308..3fff06514 100644 --- a/medcat/utils/meta_cat/data_utils.py +++ b/medcat/utils/meta_cat/data_utils.py @@ -1,5 +1,6 @@ -from typing import Dict, Optional, Tuple, Iterable, List +from typing import Any, Dict, Optional, Tuple, Iterable, List, Union from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase +import copy import logging logger = logging.getLogger(__name__) @@ -100,10 +101,10 @@ def prepare_from_json(data: Dict, ln = e_ind - s_ind tkns = tkns[:cpos] + tokenizer(replace_center)['input_ids'] + tkns[cpos + ln + 1:] - # Backward compatibility if meta_anns is a list vs dict in the new approach - meta_anns = [] + meta_anns: Union[Dict[Any, Any], List, Any] = [] + if 'meta_anns' in ann: - meta_anns = 
ann['meta_anns'].values() if isinstance(ann['meta_anns'],dict) else ann['meta_anns'] + meta_anns = ann['meta_anns'].values() if isinstance(ann['meta_anns'], dict) else ann['meta_anns'] # If the annotation is validated for meta_ann in meta_anns: @@ -153,7 +154,7 @@ def prepare_for_oversampled_data(data: List, def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None, - category_undersample=None) -> Tuple: + category_undersample=None, alternative_class_names: List[List] = []) -> Tuple: """Converts the category values in the data outputted by `prepare_from_json` into integer values. @@ -164,6 +165,8 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict Map from category_value to id (old/existing). category_undersample: Name of class that should be used to undersample the data (for 2 phase learning) + alternative_class_names: + Map that stores the variations of possible class names for the given category (task) Returns: dict: @@ -172,6 +175,9 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict New undersampled data (for 2 phase learning) with integers inplace of strings for category values dict: Map from category value to ID for all categories in the data. 
+ + Raises: + Exception: If categoryvalue2id is pre-defined and its labels do not match the labels found in the data """ data = list(data) if existing_category_value2id is not None: @@ -180,9 +186,48 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict category_value2id = {} category_values = set([x[2] for x in data]) - for c in category_values: - if c not in category_value2id: - category_value2id[c] = len(category_value2id) + + # If categoryvalue2id is pre-defined, then making sure it is same as the labels found in the data + if len(category_value2id) != 0 and set(category_value2id.keys()) != category_values: + # if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations + if len(alternative_class_names) == 0: + # Raise an exception since the labels don't match + raise Exception( + "The classes set in the config are not the same as the one found in the data. " + "The classes present in the config vs the ones found in the data - " + f"{set(category_value2id.keys())}, {category_values}. 
Additionally, ensure the populate the " + "'alternative_class_names' attribute to accommodate for variations.") + updated_category_value2id = {} + for _class in category_value2id.keys(): + if _class in category_values: + updated_category_value2id[_class] = category_value2id[_class] + else: + found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map] + failed_to_find = False + if len(found_in) != 0: + class_name_matched = [label for label in found_in[0] if label in category_values] + if len(class_name_matched) != 0: + updated_category_value2id[class_name_matched[0]] = category_value2id[_class] + logger.info("Class name '%s' does not exist in the data; however a variation of it " + "'%s' is present; updating it...", _class, class_name_matched[0]) + else: + failed_to_find = True + else: + failed_to_find = True + if failed_to_find: + raise Exception("The classes set in the config are not the same as the one found in the data. " + "The classes present in the config vs the ones found in the data - " + f"{set(category_value2id.keys())}, {category_values}. 
Additionally, ensure the " + "populate the 'alternative_class_names' attribute to accommodate for variations.") + category_value2id = copy.deepcopy(updated_category_value2id) + logger.info("Updated categoryvalue2id mapping - %s", category_value2id) + + # Else create the mapping from the labels found in the data + else: + for c in category_values: + if c not in category_value2id: + category_value2id[c] = len(category_value2id) + logger.info("Categoryvalue2id mapping created with labels found in the data - %s", category_value2id) # Map values to numbers for i in range(len(data)): @@ -194,7 +239,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict if data[i][2] in category_value2id.values(): label_data_[data[i][2]] = label_data_[data[i][2]] + 1 - logger.info("Original label_data: %s",label_data_) + logger.info("Original number of samples per label: %s",label_data_) # Undersampling data if category_undersample is None or category_undersample == '': min_label = min(label_data_.values()) @@ -217,7 +262,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict for i in range(len(data_undersampled)): if data_undersampled[i][2] in category_value2id.values(): label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1 - logger.info("Updated label_data: %s",label_data) + logger.info("Updated number of samples per label (for 2-phase learning): %s",label_data) return data, data_undersampled, category_value2id diff --git a/medcat/utils/meta_cat/ml_utils.py b/medcat/utils/meta_cat/ml_utils.py index a7acf34c9..0d75eabe8 100644 --- a/medcat/utils/meta_cat/ml_utils.py +++ b/medcat/utils/meta_cat/ml_utils.py @@ -14,7 +14,8 @@ from sklearn.metrics import classification_report, precision_recall_fscore_support, confusion_matrix from sklearn.model_selection import train_test_split from sklearn.utils.class_weight import compute_class_weight -from transformers import AdamW, get_linear_schedule_with_warmup +from 
transformers import get_linear_schedule_with_warmup +from torch.optim import AdamW import logging @@ -329,12 +330,12 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4): print_report(epoch, running_loss_test, all_logits_test, y=y_test, name='Test') _report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), - output_dict=True) + output_dict=True,zero_division=0) if not winner_report or _report[config.train['metric']['base']][config.train['metric']['score']] > \ winner_report['report'][config.train['metric']['base']][config.train['metric']['score']]: report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), - output_dict=True) + output_dict=True,zero_division=0) cm = confusion_matrix(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), normalize='true') report_train = classification_report(y_train, np.argmax(np.concatenate(all_logits, axis=0), axis=1), output_dict=True) diff --git a/medcat/utils/ner/deid.py b/medcat/utils/ner/deid.py index 8a4310a2f..522924503 100644 --- a/medcat/utils/ner/deid.py +++ b/medcat/utils/ner/deid.py @@ -66,6 +66,10 @@ def train(self, json_path: Union[str, list, None], *args, **kwargs) -> Tuple[Any, Any, Any]: return super().train(json_path, *args, train_nr=0, **kwargs) # type: ignore + def eval(self, json_path: Union[str, list, None], + *args, **kwargs) -> Tuple[Any, Any, Any]: + return super().eval(json_path, *args, train_nr=0, **kwargs) # type: ignore + def deid_text(self, text: str, redact: bool = False) -> str: """Deidentify text and potentially redact information. 
diff --git a/medcat/utils/ner/model.py b/medcat/utils/ner/model.py index 23d32c658..d7f8b01e3 100644 --- a/medcat/utils/ner/model.py +++ b/medcat/utils/ner/model.py @@ -42,6 +42,23 @@ def train(self, json_path: Union[str, list, None], train_nr: int = 0, """ return self.cat._addl_ner[train_nr].train(json_path, *args, **kwargs) + def eval(self, json_path: Union[str, list, None], train_nr: int = 0, + *args, **kwargs) -> Tuple[Any, Any, Any]: + """Evaluate the underlying transformers NER model. + + All the extra arguments are passed to the TransformersNER eval method. + + Args: + json_path (Union[str, list, None]): The JSON file path to read the training data from. + train_nr (int): The number of the NER object in cat._addl_train to train. Defaults to 0. + *args: Additional arguments for TransformersNER.eval . + **kwargs: Additional keyword arguments for TransformersNER.eval . + + Returns: + Tuple[Any, Any, Any]: df, examples, dataset + """ + return self.cat._addl_ner[train_nr].eval(json_path, *args, **kwargs) + def __call__(self, text: Optional[str], *args, **kwargs) -> Optional[Doc]: """Get the annotated document for text. @@ -76,6 +93,21 @@ def get_entities(self, text: str, *args, **kwargs) -> dict: """ return self.cat.get_entities(text, *args, **kwargs) + def add_new_concepts(self, + cui2preferred_name: Dict[str, str], + train_nr: int = 0, + with_random_init: bool = False) -> None: + """Add new concepts to the model and the concept database. + + Invoking this requires subsequent retraining on the model. + + Args: + cui2preferred_name(Dict[str, str]): Dictionary where each key is the literal ID of the concept to be added and each value is its preferred name. + train_nr (int): The number of the NER object in cat._addl_train to which new concepts will be added. Defaults to 0. + with_random_init (bool): Whether to use the random init strategy for the new concepts. Defaults to False. 
+ """ + self.cat._addl_ner[train_nr].expand_model_with_concepts(cui2preferred_name, use_avg_init=not with_random_init) + @property def config(self) -> Config: return self.cat.config diff --git a/medcat/utils/relation_extraction/base_component.py b/medcat/utils/relation_extraction/base_component.py new file mode 100644 index 000000000..38e0c6c07 --- /dev/null +++ b/medcat/utils/relation_extraction/base_component.py @@ -0,0 +1,142 @@ +import logging +import os + +from medcat.config_rel_cat import ConfigRelCAT +from medcat.utils.relation_extraction.models import BaseModel_RelationExtraction +from medcat.utils.relation_extraction.pad_seq import Pad_Sequence +from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper_RelationExtraction +from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction + +from torch.optim import AdamW +from torch.optim.lr_scheduler import MultiStepLR + +from medcat.utils.relation_extraction.ml_utils import load_state, save_state + +logger = logging.getLogger(__name__) + + +class BaseComponent_RelationExtraction(): + + name = "base_component_rel" + + def __init__(self, tokenizer: BaseTokenizerWrapper_RelationExtraction = BaseTokenizerWrapper_RelationExtraction(), + model: BaseModel_RelationExtraction = None, # type: ignore + model_config: BaseConfig_RelationExtraction = None, # type: ignore + config: ConfigRelCAT = ConfigRelCAT(), + task: str = "train", + init_model: bool = False): + """ Component that holds the model and everything for RelCAT. + + Args: + tokenizer (BaseTokenizerWrapper_RelationExtraction): The base tokenizer for RelCAT. + model (BaseModel_RelationExtraction): The model wrapper. + model_config (BaseConfig_RelationExtraction): The model-specific config. + config (ConfigRelCAT): The RelCAT config. + task (str): The task - used for checkpointing. + init_model (bool): Loads default BERT base model, tokenizer, model config. Defaults to False. 
+ """ + + self.model: BaseModel_RelationExtraction = model # type: ignore + self.tokenizer: BaseTokenizerWrapper_RelationExtraction = tokenizer # type: ignore + self.relcat_config: ConfigRelCAT = config + self.model_config: BaseConfig_RelationExtraction = model_config + self.optimizer: AdamW = None # type: ignore + self.scheduler: MultiStepLR = None # type: ignore + self.task: str = task + self.epoch: int = 0 + self.best_f1: float = 0.0 + + if init_model: + self.model_config = BaseConfig_RelationExtraction.load(pretrained_model_name_or_path=self.relcat_config.general.model_name, + relcat_config=self.relcat_config) + + self.tokenizer = BaseTokenizerWrapper_RelationExtraction.load(tokenizer_path=self.relcat_config.general.model_name, + relcat_config=self.relcat_config) + + self.tokenizer.hf_tokenizers.add_tokens(self.relcat_config.general.tokenizer_relation_annotation_special_tokens_tags, special_tokens=True) + + # used in llama tokenizer, may produce issues with other tokenizers + self.tokenizer.hf_tokenizers.add_special_tokens(self.relcat_config.general.tokenizer_other_special_tokens) + self.relcat_config.general.annotation_schema_tag_ids = self.tokenizer.hf_tokenizers.convert_tokens_to_ids(self.relcat_config.general.tokenizer_relation_annotation_special_tokens_tags) + self.relcat_config.model.padding_idx = self.model_config.pad_token_id = self.tokenizer.get_pad_id() + self.model_config.hf_model_config.vocab_size = self.tokenizer.get_size() + + self.model = BaseModel_RelationExtraction.load(pretrained_model_name_or_path=self.relcat_config.general.model_name, + model_config=self.model_config, + relcat_config=self.relcat_config) + + self.model.hf_model.resize_token_embeddings(self.tokenizer.get_size()) # type: ignore + + self.pad_id = self.relcat_config.model.padding_idx + self.padding_seq = Pad_Sequence(seq_pad_value=self.pad_id, + label_pad_value=self.pad_id) + + self.log = logging.getLogger(__name__) + 
logging.basicConfig(level=self.relcat_config.general.log_level) + self.log.setLevel(self.relcat_config.general.log_level) + + self.log.info("BaseComponent_RelationExtraction initialized") + + def save(self, save_path: str) -> None: + """ Saves model and its dependencies to specified save_path folder. + The CDB is obviously not saved, it is however necessary to save the tokenizer used. + + Args: + save_path (str): folder path in which to save the model & deps. + """ + + assert self.relcat_config is not None + self.relcat_config.save(os.path.join(save_path, "config.json")) + + assert self.tokenizer is not None + self.tokenizer.save(os.path.join(save_path)) + + assert self.model is not None and self.model.hf_model is not None + self.model.hf_model.resize_token_embeddings(self.tokenizer.get_size()) # type: ignore + + assert self.model_config is not None + self.model_config.hf_model_config.vocab_size = self.tokenizer.get_size() + self.model_config.hf_model_config.pad_token_id = self.pad_id + self.model_config.save(save_path) + + save_state(self.model, optimizer=self.optimizer, scheduler=self.scheduler, epoch=self.epoch, best_f1=self.best_f1, + path=save_path, model_name=self.relcat_config.general.model_name, + task=self.task, is_checkpoint=False, final_export=True) + + @classmethod + def load(cls, pretrained_model_name_or_path: str = "./") -> "BaseComponent_RelationExtraction": + """ + Args: + pretrained_model_name_or_path (str): Path to RelCAT model. Defaults to "./". + + Returns: + BaseComponent_RelationExtraction: component. 
+ """ + + relcat_config = ConfigRelCAT.load(load_path=pretrained_model_name_or_path) + + model_config = BaseConfig_RelationExtraction.load(pretrained_model_name_or_path=pretrained_model_name_or_path, + relcat_config=relcat_config) + + tokenizer = BaseTokenizerWrapper_RelationExtraction.load(tokenizer_path=pretrained_model_name_or_path, + relcat_config=relcat_config) + + model = BaseModel_RelationExtraction.load(pretrained_model_name_or_path=pretrained_model_name_or_path, + model_config=model_config, + relcat_config=relcat_config) + + model.hf_model.resize_token_embeddings(len(tokenizer.hf_tokenizers)) # type: ignore + + optimizer = None # type: ignore + scheduler = None # type: ignore + + epoch, best_f1 = load_state(model, optimizer, scheduler, path=pretrained_model_name_or_path, + model_name=relcat_config.general.model_name, + file_prefix=relcat_config.general.task, + relcat_config=relcat_config) + + component = cls(model=model, tokenizer=tokenizer, model_config=model_config, config=relcat_config) + cls.epoch = epoch + cls.best_f1 = best_f1 + + return component diff --git a/medcat/utils/relation_extraction/bert/__init__.py b/medcat/utils/relation_extraction/bert/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/medcat/utils/relation_extraction/bert/config.py b/medcat/utils/relation_extraction/bert/config.py new file mode 100644 index 000000000..047903643 --- /dev/null +++ b/medcat/utils/relation_extraction/bert/config.py @@ -0,0 +1,30 @@ +import logging +import os +from medcat.config_rel_cat import ConfigRelCAT +from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction +from transformers import BertConfig + +logger = logging.getLogger(__name__) + + +class BertConfig_RelationExtraction(BaseConfig_RelationExtraction): + """ Class for BertConfig + """ + + name = 'bert-config' + pretrained_model_name_or_path = "bert-base-uncased" + + @classmethod + def load(cls, pretrained_model_name_or_path: str, relcat_config: 
ConfigRelCAT, **kwargs) -> "BertConfig_RelationExtraction": + model_config = cls(pretrained_model_name_or_path, **kwargs) + + if pretrained_model_name_or_path and os.path.exists(pretrained_model_name_or_path): + model_config.hf_model_config = BertConfig.from_json_file(pretrained_model_name_or_path) + logger.info("Loaded config from file: " + pretrained_model_name_or_path) + else: + relcat_config.general.model_name = cls.pretrained_model_name_or_path + model_config.hf_model_config = BertConfig.from_pretrained( + pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs) + logger.info("Loaded config from pretrained: " + relcat_config.general.model_name) + + return model_config diff --git a/medcat/utils/relation_extraction/bert/model.py b/medcat/utils/relation_extraction/bert/model.py new file mode 100644 index 000000000..6a9c9cd5f --- /dev/null +++ b/medcat/utils/relation_extraction/bert/model.py @@ -0,0 +1,79 @@ +import logging +import os +import torch + + +from typing import Union + +from torch import nn +from transformers import PreTrainedModel +from transformers.models.bert.modeling_bert import BertModel + +from medcat.config_rel_cat import ConfigRelCAT +from medcat.utils.relation_extraction.ml_utils import create_dense_layers +from medcat.utils.relation_extraction.models import BaseModel_RelationExtraction +from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction +from medcat.utils.relation_extraction.bert.config import BertConfig_RelationExtraction + + +class BertModel_RelationExtraction(BaseModel_RelationExtraction): + """ BertModel class for RelCAT + """ + + name = "bertmodel_relcat" + + log = logging.getLogger(__name__) + + def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, BertConfig_RelationExtraction]): + """ Class to hold the BERT model + model_config + + Args: + pretrained_model_name_or_path (str): path to load the model from, + 
this can be a HF model i.e: "bert-base-uncased", if left empty, it is normally assumed that a model is loaded from 'model.dat' + using the RelCAT.load() method. So if you are initializing/training a model from scratch be sure to base it on some model. + relcat_config (ConfigRelCAT): relcat config. + model_config (Union[BaseConfig_RelationExtraction | BertConfig_RelationExtraction]): HF bert config for model. + """ + super(BertModel_RelationExtraction, self).__init__(pretrained_model_name_or_path=pretrained_model_name_or_path, + relcat_config=relcat_config, + model_config=model_config) + + self.relcat_config: ConfigRelCAT = relcat_config + self.model_config: Union[BaseConfig_RelationExtraction, BertConfig_RelationExtraction] = model_config + self.pretrained_model_name_or_path: str = pretrained_model_name_or_path + + self.hf_model: Union[BertModel, PreTrainedModel] = BertModel(model_config.hf_model_config) # type: ignore + + for param in self.hf_model.parameters(): # type: ignore + if self.relcat_config.model.freeze_layers: + param.requires_grad = False + else: + param.requires_grad = True + + self.drop_out = nn.Dropout(self.relcat_config.model.dropout) + + # dense layers + self.fc1, self.fc2, self.fc3 = create_dense_layers(self.relcat_config) + + @classmethod + def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, BertConfig_RelationExtraction], **kwargs) -> "BertModel_RelationExtraction": + + model = BertModel_RelationExtraction(pretrained_model_name_or_path=pretrained_model_name_or_path, + relcat_config=relcat_config, + model_config=model_config) + + model_path = os.path.join(pretrained_model_name_or_path, "model.dat") + + if os.path.exists(model_path): + model.load_state_dict(torch.load(model_path, map_location=relcat_config.general.device)) + cls.log.info("Loaded model from file: " + model_path) + elif pretrained_model_name_or_path: + model.hf_model = BertModel.from_pretrained( + 
pretrained_model_name_or_path=pretrained_model_name_or_path, config=model_config.hf_model_config, ignore_mismatched_sizes=True, **kwargs) + cls.log.info("Loaded model from pretrained: " + pretrained_model_name_or_path) + else: + model.hf_model = BertModel.from_pretrained( + pretrained_model_name_or_path=cls.pretrained_model_name_or_path, config=model_config.hf_model_config, ignore_mismatched_sizes=True, **kwargs) + cls.log.info("Loaded model from pretrained: " + cls.pretrained_model_name_or_path) + + return model diff --git a/medcat/utils/relation_extraction/bert/tokenizer.py b/medcat/utils/relation_extraction/bert/tokenizer.py new file mode 100644 index 000000000..8b9fd8a76 --- /dev/null +++ b/medcat/utils/relation_extraction/bert/tokenizer.py @@ -0,0 +1,35 @@ +import os +from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast +import logging + +from medcat.config_rel_cat import ConfigRelCAT +from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper_RelationExtraction + +logger = logging.getLogger(__name__) + + +class TokenizerWrapperBERT_RelationExtraction(BaseTokenizerWrapper_RelationExtraction): + + name = "tokenizer_wrapper_bert_rel" + + ''' Wrapper around a huggingface BERT tokenizer so that it works with the + RelCAT models. + + Args: + hf_tokenizers (`transformers.models.bert.tokenization_bert_fast.PreTrainedTokenizerFast`): + A huggingface Fast BERT. 
+ ''' + name = 'bert-tokenizer' + pretrained_model_name_or_path = "bert-base-uncased" + + @classmethod + def load(cls, tokenizer_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "TokenizerWrapperBERT_RelationExtraction": + tokenizer = cls() + path = os.path.join(tokenizer_path, cls.name) + + if tokenizer_path: + tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path=path, **kwargs) + else: + relcat_config.general.model_name = cls.pretrained_model_name_or_path + tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path=relcat_config.general.model_name) + return tokenizer diff --git a/medcat/utils/relation_extraction/config.py b/medcat/utils/relation_extraction/config.py new file mode 100644 index 000000000..70ec591c0 --- /dev/null +++ b/medcat/utils/relation_extraction/config.py @@ -0,0 +1,59 @@ +import os +import logging +from transformers import PretrainedConfig + +from medcat.config_rel_cat import ConfigRelCAT + + +logger = logging.getLogger(__name__) + + +class BaseConfig_RelationExtraction(PretrainedConfig): + """ Base class for the RelCAT models + """ + name = "base-config-relcat" + + def __init__(self, pretrained_model_name_or_path, **kwargs): + super().__init__(**kwargs) + self.model_type = "relcat" + self.pretrained_model_name_or_path = pretrained_model_name_or_path + self.hf_model_config: PretrainedConfig = kwargs.get("model_config", PretrainedConfig()) + + def to_dict(self): + output = super().to_dict() + output["model_type"] = self.model_type + output["pretrained_model_name_or_path"] = self.pretrained_model_name_or_path + output["model_config"] = self.hf_model_config + return output + + def save(self, save_path: str): + self.hf_model_config.to_json_file( + os.path.join(save_path, "model_config.json")) + + @classmethod + def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "BaseConfig_RelationExtraction": + + model_config_path = 
os.path.join(pretrained_model_name_or_path, "model_config.json") + model_config = BaseConfig_RelationExtraction(pretrained_model_name_or_path=pretrained_model_name_or_path, relcat_config=relcat_config, **kwargs) + + if os.path.exists(model_config_path): + if "modern-bert" in relcat_config.general.tokenizer_name or \ + "modern-bert" in relcat_config.general.model_name: + from medcat.utils.relation_extraction.modernbert.config import ModernBertConfig_RelationExtraction + model_config = ModernBertConfig_RelationExtraction.load(model_config_path, relcat_config=relcat_config, **kwargs) + elif "bert" in relcat_config.general.tokenizer_name or \ + "bert" in relcat_config.general.model_name: + from medcat.utils.relation_extraction.bert.config import BertConfig_RelationExtraction + model_config = BertConfig_RelationExtraction.load(model_config_path, relcat_config=relcat_config, **kwargs) + elif "llama" in relcat_config.general.tokenizer_name or \ + "llama" in relcat_config.general.model_name: + from medcat.utils.relation_extraction.llama.config import LlamaConfig_RelationExtraction + model_config = LlamaConfig_RelationExtraction.load(model_config_path, relcat_config=relcat_config, **kwargs) + else: + if pretrained_model_name_or_path: + model_config.hf_model_config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + else: + model_config.hf_model_config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=relcat_config.general.model_name, **kwargs) + logger.info("Loaded config from : " + model_config_path) + + return model_config diff --git a/medcat/utils/relation_extraction/llama/__init__.py b/medcat/utils/relation_extraction/llama/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/medcat/utils/relation_extraction/llama/config.py b/medcat/utils/relation_extraction/llama/config.py new file mode 100644 index 000000000..18bc0f322 --- /dev/null +++ 
b/medcat/utils/relation_extraction/llama/config.py @@ -0,0 +1,30 @@ +import logging +import os +from medcat.config_rel_cat import ConfigRelCAT +from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction +from transformers import LlamaConfig + +logger = logging.getLogger(__name__) + + +class LlamaConfig_RelationExtraction(BaseConfig_RelationExtraction): + """ Class for LlamaConfig + """ + + name = 'llama-config' + pretrained_model_name_or_path = "meta-llama/Llama-3.1-8B" + + @classmethod + def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "LlamaConfig_RelationExtraction": + model_config = cls(pretrained_model_name_or_path, **kwargs) + + if pretrained_model_name_or_path and os.path.exists(pretrained_model_name_or_path): + model_config.hf_model_config = LlamaConfig.from_json_file(pretrained_model_name_or_path) + logger.info("Loaded config from file: " + pretrained_model_name_or_path) + else: + relcat_config.general.model_name = cls.pretrained_model_name_or_path + model_config.hf_model_config = LlamaConfig.from_pretrained( + pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs) + logger.info("Loaded config from pretrained: " + relcat_config.general.model_name) + + return model_config diff --git a/medcat/utils/relation_extraction/llama/model.py b/medcat/utils/relation_extraction/llama/model.py new file mode 100644 index 000000000..6dc85dd76 --- /dev/null +++ b/medcat/utils/relation_extraction/llama/model.py @@ -0,0 +1,218 @@ +import logging +from typing import Any, Optional, Tuple, Union +import torch +from torch import nn +import os +from transformers.models.llama import LlamaModel + +from medcat.config_rel_cat import ConfigRelCAT +from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction +from medcat.utils.relation_extraction.llama.config import LlamaConfig_RelationExtraction +from medcat.utils.relation_extraction.models import BaseModel_RelationExtraction 
+from medcat.utils.relation_extraction.ml_utils import create_dense_layers, get_annotation_schema_tag + + +class LlamaModel_RelationExtraction(BaseModel_RelationExtraction): + """ LlamaModel class for RelCAT + """ + + name = "llamamodel_relcat" + + log = logging.getLogger(__name__) + + def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, LlamaConfig_RelationExtraction]): + """ Class to hold the Llama model + model_config + + Args: + pretrained_model_name_or_path (str): path to load the model from, + this can be a HF model i.e: "bert-base-uncased", if left empty, it is normally assumed that a model is loaded from 'model.dat' + using the RelCAT.load() method. So if you are initializing/training a model from scratch be sure to base it on some model. + relcat_config (ConfigRelCAT): relcat config. + model_config (Union[BaseConfig_RelationExtraction | LlamaConfig_RelationExtraction]): HF bert config for model. + """ + + super(LlamaModel_RelationExtraction, self).__init__(pretrained_model_name_or_path=pretrained_model_name_or_path, + relcat_config=relcat_config, + model_config=model_config) + + self.relcat_config: ConfigRelCAT = relcat_config + self.model_config: Union[BaseConfig_RelationExtraction, LlamaConfig_RelationExtraction] = model_config + + self.hf_model: LlamaModel = LlamaModel(config=model_config) # type: ignore + + if pretrained_model_name_or_path != "": + self.hf_model = LlamaModel.from_pretrained(pretrained_model_name_or_path, config=model_config, ignore_mismatched_sizes=True) + + for param in self.hf_model.parameters(): + param.requires_grad = False + + self.drop_out = nn.Dropout(self.relcat_config.model.dropout) + + # dense layers + self.fc1, self.fc2, self.fc3 = create_dense_layers(self.relcat_config) + + # for pooled output + self.llama_pooler = LlamaPooler(self.model_config.hidden_size) + + self.log.info("RelCAT LlamaConfig: " + str(self.model_config)) + + def 
output2logits(self, pooled_output: torch.Tensor, sequence_output: torch.Tensor, input_ids: torch.Tensor, e1_e2_start: torch.Tensor) -> torch.Tensor: + """ + + Args: + pooled_output (torch.Tensor): embedding of the CLS token + sequence_output (torch.Tensor): hidden states/embeddings for each token in the input text + input_ids (torch.Tensor): input token ids. + e1_e2_start (torch.Tensor): annotation tags token position + + Returns: + torch.Tensor: classification probabilities for each token. + """ + + new_pooled_output = pooled_output + + if self.relcat_config.general.annotation_schema_tag_ids: + annotation_schema_tag_ids_ = [self.relcat_config.general.annotation_schema_tag_ids[i:i + 2] for i in + range(0, len(self.relcat_config.general.annotation_schema_tag_ids), 2)] + seq_tags = [] + + # for each pair of tags (e1,s1) and (e2,s2) + for each_tags in annotation_schema_tag_ids_: + seq_tags.append(get_annotation_schema_tag( + sequence_output, input_ids, each_tags)) + + seq_tags = torch.stack(seq_tags, dim=0) + + if self.relcat_config.model.llama_use_pooled_output: + new_pooled_output = torch.cat((pooled_output, *seq_tags), dim=1) + else: + new_pooled_output = torch.cat((seq_tags[0], seq_tags[1]), dim=1) + else: + e1e2_output = [] + temp_e1 = [] + temp_e2 = [] + + for i, seq in enumerate(sequence_output): + # e1e2 token sequences + temp_e1.append(seq[e1_e2_start[i][0]]) + temp_e2.append(seq[e1_e2_start[i][1]]) + + e1e2_output.append(torch.stack(temp_e1, dim=0)) + e1e2_output.append(torch.stack(temp_e2, dim=0)) + + new_pooled_output = torch.cat((pooled_output, *e1e2_output), dim=1) + + del e1e2_output + del temp_e2 + del temp_e1 + + x = self.drop_out(new_pooled_output) + x = self.fc1(x) + x = self.drop_out(x) + x = self.fc2(x) + classification_logits = self.fc3(x) + return classification_logits.to(self.relcat_config.general.device) + + + def forward(self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: 
Optional[torch.Tensor] = None, + position_ids: Any = None, + head_mask: Any = None, + encoder_hidden_states: Any = None, + encoder_attention_mask: Any = None, + Q: Any = None, + e1_e2_start: Any = None, + pooled_output: Any = None) -> Tuple[torch.Tensor, torch.Tensor]: + + if input_ids is not None: + input_shape = input_ids.size() + else: + raise ValueError("You have to specify input_ids") + + if attention_mask is None: + attention_mask = torch.ones( + input_shape, device=self.relcat_config.general.device) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + input_shape, device=self.relcat_config.general.device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.relcat_config.general.device) + + input_ids = input_ids.to(self.relcat_config.general.device) + attention_mask = attention_mask.to(self.relcat_config.general.device) + encoder_attention_mask = encoder_attention_mask.to( + self.relcat_config.general.device) + + self.hf_model = self.hf_model.to(self.relcat_config.general.device) # type: ignore + + model_output = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + + # (batch_size, sequence_length, hidden_size) + sequence_output = model_output.last_hidden_state + + if self.relcat_config.model.llama_use_pooled_output: + pooled_output = self.llama_pooler(model_output) + pooled_output = pooled_output.to(self.relcat_config.general.device) + + classification_logits = self.output2logits( + pooled_output, sequence_output, input_ids, e1_e2_start) + + return model_output, classification_logits.to(self.relcat_config.general.device) + + @classmethod + def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, LlamaConfig_RelationExtraction], **kwargs) -> "LlamaModel_RelationExtraction": + + model = 
LlamaModel_RelationExtraction(pretrained_model_name_or_path=pretrained_model_name_or_path, + relcat_config=relcat_config, + model_config=model_config) + + model_path = os.path.join(pretrained_model_name_or_path, "model.dat") + + if os.path.exists(model_path): + model.load_state_dict(torch.load(model_path, map_location=relcat_config.general.device)) + cls.log.info("Loaded model from file: " + model_path) + elif pretrained_model_name_or_path: + model.hf_model = LlamaModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, config=model_config.hf_model_config, ignore_mismatched_sizes=True, **kwargs) + cls.log.info("Loaded model from pretrained: " + pretrained_model_name_or_path) + else: + model.hf_model = LlamaModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, config=model_config.hf_model_config, ignore_mismatched_sizes=True, **kwargs) + cls.log.info("Loaded model from pretrained: " + cls.pretrained_model_name_or_path) + + return model + + +class LlamaPooler(nn.Module): + """ An attempt to copy the BERT pooling technique for an increase in performance. + + Args: + nn (nn.Module): . + """ + def __init__(self, hidden_size: int): + """ Initialises the pooler with a linear layer of size: + self.model_config.hidden_size x self.model_config.hidden_size + + Args: + hidden_size (int): size of tensor + """ + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. (original BERT) + # We can do the same here but the [CLS] token equivalent is not the same as + # for bert as there is not much learning contained in it. + # e.g: first_token_tensor = hidden_states[:, 0] # original + # so instead we pool across all the tokens from the last hidden layer. 
+ + pooled_output, _ = torch.max(hidden_states[-1], dim=1) + pooled_output = self.dense(pooled_output) + pooled_output = self.activation(pooled_output) + + return pooled_output diff --git a/medcat/utils/relation_extraction/llama/tokenizer.py b/medcat/utils/relation_extraction/llama/tokenizer.py new file mode 100644 index 000000000..e82ab6046 --- /dev/null +++ b/medcat/utils/relation_extraction/llama/tokenizer.py @@ -0,0 +1,34 @@ +import os +from transformers import LlamaTokenizerFast +import logging + +from medcat.config_rel_cat import ConfigRelCAT +from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper_RelationExtraction + +logger = logging.getLogger(__name__) + + +class TokenizerWrapperLlama_RelationExtraction(BaseTokenizerWrapper_RelationExtraction): + ''' Wrapper around a huggingface Llama tokenizer so that it works with the + RelCAT models. + + Args: + hf_tokenizers (`transformers.LlamaTokenizerFast`): + A huggingface Fast Llama. + ''' + name = "tokenizer_wrapper_llama_rel" + pretrained_model_name_or_path = "meta-llama/Llama-3.1-8B" + + @classmethod + def load(cls, tokenizer_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "TokenizerWrapperLlama_RelationExtraction": + tokenizer = cls() + path = os.path.join(tokenizer_path, cls.name) + + if tokenizer_path: + tokenizer.hf_tokenizers = LlamaTokenizerFast.from_pretrained( + path, **kwargs) + else: + relcat_config.general.model_name = cls.pretrained_model_name_or_path + tokenizer.hf_tokenizers = LlamaTokenizerFast.from_pretrained( + path=relcat_config.general.model_name) + return tokenizer diff --git a/medcat/utils/relation_extraction/ml_utils.py b/medcat/utils/relation_extraction/ml_utils.py new file mode 100644 index 000000000..ed9d01b2d --- /dev/null +++ b/medcat/utils/relation_extraction/ml_utils.py @@ -0,0 +1,305 @@ +import torch +import logging +import os +import pickle +from typing import Any, Dict, List, Tuple +import random + +from medcat.utils.relation_extraction.tokenizer import 
BaseTokenizerWrapper_RelationExtraction +from medcat.config_rel_cat import ConfigRelCAT + +from torch import nn + + +logger = logging.getLogger(__name__) + + +def split_list_train_test_by_class(data: List, sample_limit: int = -1, test_size: float = 0.2, shuffle: bool = True) -> Tuple[List, List]: + """ + + Args: + data (List): "output_relations": relation_instances, <-- see create_base_relations_from_doc/csv + for data columns + sample_limit (int): limit the number of samples per class, useful for dataset balancing . Defaults to -1. + test_size (float): Defaults to 0.2. + shuffle (bool): shuffle data randomly. Defaults to True. + + Returns: + Tuple[List, List]: train and test datasets + """ + + train_data = [] + test_data = [] + + row_id_labels = {row_idx: data[row_idx][5] for row_idx in range(len(data))} + lbl_id_to_name = {data[row_idx][5]: data[row_idx][4] for row_idx in range((len(data)))} + + count_per_label = {lbl: list(row_id_labels.values()).count( + lbl) for lbl in set(row_id_labels.values())} + + new_label_count_train = {} + new_label_count_test = {} + + for lbl_id, count in count_per_label.items(): + if sample_limit != -1 and count > sample_limit: + count = sample_limit + + _test_records_size = int(count * test_size) + + test_sample_count = 0 + train_sample_count = 0 + + if _test_records_size not in [0, 1]: + for row_idx, _lbl_id in row_id_labels.items(): + if _lbl_id == lbl_id: + if test_sample_count < _test_records_size: + test_data.append(data[row_idx]) + test_sample_count += 1 + else: + if sample_limit != -1: + if train_sample_count < sample_limit: + train_data.append(data[row_idx]) + train_sample_count += 1 + else: + train_data.append(data[row_idx]) + train_sample_count += 1 + + else: + for row_idx, _lbl_id in row_id_labels.items(): + if _lbl_id == lbl_id: + train_data.append(data[row_idx]) + test_data.append(data[row_idx]) + train_sample_count += 1 + test_sample_count += 1 + + new_label_count_test[lbl_id] = test_sample_count + 
new_label_count_train[lbl_id] = train_sample_count + + logging.info("Relations after train, test split : train - " + str(sum(new_label_count_train.values())) + " | test - " + str(sum(new_label_count_test.values()))) + + for label_id in list(lbl_id_to_name.keys()): + logging.info(" label: " + lbl_id_to_name[label_id] + " samples | train " + str(new_label_count_train[label_id]) + " | test " + str(new_label_count_test[label_id])) + + if shuffle: + random.shuffle(train_data) + random.shuffle(test_data) + + return train_data, test_data + + +def load_bin_file(file_name, path="./") -> Any: + with open(os.path.join(path, file_name), 'rb') as f: + data = pickle.load(f) + return data + + +def save_bin_file(file_name, data, path="./"): + with open(os.path.join(path, file_name), "wb") as f: + pickle.dump(data, f) + + +def save_state(model, optimizer: torch.optim.AdamW, scheduler: torch.optim.lr_scheduler.MultiStepLR, epoch:int = 1, best_f1:float = 0.0, path:str = "./", model_name: str = "BERT", task:str = "train", is_checkpoint=False, final_export=False) -> None: + """ Used by RelCAT.save() and RelCAT.train() + Saves the RelCAT model state. + For checkpointing multiple files are created, best_f1, loss etc. score. + If you want to export the model after training set final_export=True and leave is_checkpoint=False. + + Args: + model (BaseModel_RelationExtraction): BertModel_RelationExtraction | LlamaModel_RelationExtraction etc. + optimizer (torch.optim.AdamW, optional): Defaults to None. + scheduler (torch.optim.lr_scheduler.MultiStepLR, optional): Defaults to None. + epoch (int): Defaults to None. + best_f1 (float): Defaults to None. + path (str):Defaults to "./". + model_name (str): . Defaults to "BERT". This is used to checkpointing only. + task (str): Defaults to "train". This is used to checkpointing only. + is_checkpoint (bool): Defaults to False. + final_export (bool): Defaults to False, if True then is_checkpoint must be False also. 
Exports model.state_dict(), out into"model.dat". + """ + + model_name = model_name.replace("/", "_") + file_name = "%s_checkpoint_%s.dat" % (task, model_name) + + if not is_checkpoint: + file_name = "%s_best_%s.dat" % (task, model_name) + if final_export: + file_name = "model.dat" + torch.save(model.state_dict(), os.path.join(path, file_name)) + + if is_checkpoint: + torch.save({ + 'epoch': epoch, + 'state_dict': model.state_dict(), + 'best_f1': best_f1, + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict() + }, os.path.join(path, file_name)) + + +def load_state(model, optimizer, scheduler, path="./", model_name="BERT", file_prefix="train", load_best=False, relcat_config: ConfigRelCAT = ConfigRelCAT()) -> Tuple[int, int]: + """ Used by RelCAT.load() and RelCAT.train() + + Args: + model (BaseModel_RelationExtraction): BaseModel_RelationExtraction, it has to be initialized before calling this method via (Bert/Llama)Model_RelationExtraction(...) + optimizer (_type_): optimizer + scheduler (_type_): scheduler + path (str, optional): Defaults to "./". + model_name (str, optional): Defaults to "BERT". + file_prefix (str, optional): Defaults to "train". + load_best (bool, optional): Defaults to False. + relcat_config (ConfigRelCAT): Defaults to ConfigRelCAT(). + + Returns: + Tuple (int, int): last epoch and f1 score. 
+ """ + + device: torch.device =torch.device(relcat_config.general.device) + + model_name = model_name.replace("/", "_") + logging.info("Attempting to load RelCAT model on device: " + str(device)) + checkpoint_path = os.path.join( + path, file_prefix + "_checkpoint_%s.dat" % model_name) + best_path = os.path.join( + path, file_prefix + "_best_%s.dat" % model_name) + start_epoch, best_f1, checkpoint = 0, 0, None + + if load_best is True and os.path.isfile(best_path): + checkpoint = torch.load(best_path, map_location=device) + logging.info("Loaded best model.") + elif os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location=device) + logging.info("Loaded checkpoint model.") + + if checkpoint is not None: + start_epoch = checkpoint['epoch'] + best_f1 = checkpoint['best_f1'] + model.load_state_dict(checkpoint['state_dict']) + model.to(device) + + if optimizer is None: + parameters = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = torch.optim.AdamW(params=parameters, lr=relcat_config.train.lr, weight_decay=relcat_config.train.adam_weight_decay, + betas=relcat_config.train.adam_betas, eps=relcat_config.train.adam_epsilon) + + if scheduler is None: + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, + milestones=relcat_config.train.multistep_milestones, + gamma=relcat_config.train.multistep_lr_gamma) + optimizer.load_state_dict(checkpoint['optimizer']) + scheduler.load_state_dict(checkpoint['scheduler']) + logging.info("Loaded model and optimizer.") + + return start_epoch, best_f1 + + +def save_results(data, model_name: str = "BERT", path: str = "./", file_prefix: str = "train"): + save_bin_file(file_prefix + "_losses_accuracy_f1_per_epoch_%s.dat" % + model_name, data, path) + + +def load_results(path, model_name: str = "BERT", file_prefix: str = "train") -> Tuple[List, List, List]: + data_dict_path = os.path.join( + path, file_prefix + "_losses_accuracy_f1_per_epoch_%s.dat" % model_name) + + data_dict: Dict = 
{"losses_per_epoch": [], + "accuracy_per_epoch": [], "f1_per_epoch": []} + if os.path.isfile(data_dict_path): + data_dict = load_bin_file(data_dict_path) + + return data_dict["losses_per_epoch"], data_dict["accuracy_per_epoch"], data_dict["f1_per_epoch"] + + +def create_tokenizer_pretrain(tokenizer: BaseTokenizerWrapper_RelationExtraction, relcat_config: ConfigRelCAT + ) -> BaseTokenizerWrapper_RelationExtraction: + """ + This method simply adds the default special tokens that we encounter. + + Args: + tokenizer (BaseTokenizerWrapper_RelationExtraction): BERT/Llama tokenizer. + relcat_config (ConfigRelCAT): The RelCAT config. + + Returns: + BaseTokenizerWrapper_RelationExtraction: The same tokenizer. + """ + + tokenizer.hf_tokenizers.add_tokens(relcat_config.general.tokenizer_relation_annotation_special_tokens_tags, special_tokens=True) + + # used in llama tokenizer, may produce issues with other tokenizers + tokenizer.hf_tokenizers.add_special_tokens(relcat_config.general.tokenizer_other_special_tokens) + + return tokenizer + + +def create_dense_layers(relcat_config: ConfigRelCAT): + + # dense layers + fc1 = nn.Linear(relcat_config.model.model_size, relcat_config.model.hidden_size) + fc2 = nn.Linear(relcat_config.model.hidden_size, int(relcat_config.model.hidden_size / 2)) + fc3 = nn.Linear(int(relcat_config.model.hidden_size / 2), relcat_config.train.nclasses) + + return fc1, fc2, fc3 + + +def get_annotation_schema_tag(sequence_output: torch.Tensor, input_ids: torch.Tensor, special_tag: List) -> torch.Tensor: + """ Gets the token sequences from the sequence_output for the specific token + tag ids in self.relcat_config.general.annotation_schema_tag_ids.
+ + Args: + sequence_output (torch.Tensor): hidden states/embeddings for each token in the input text + input_ids (torch.Tensor): input token ids + special_tag (List): special annotation token id pairs + + Returns: + torch.Tensor: new seq_tags + """ + + idx_start = torch.where(input_ids == special_tag[0]) # returns: row ids, idx of token[0]/star token in row + idx_end = torch.where(input_ids == special_tag[1]) # returns: row ids, idx of token[1]/end token in row + + seen = [] # List to store seen elements and their indices + duplicate_indices = [] + + for i in range(len(idx_start[0])): + if idx_start[0][i] in seen: + duplicate_indices.append(i) + else: + seen.append(idx_start[0][i]) + + if len(duplicate_indices) > 0: + logger.info("Duplicate entities found, removing them...") + for idx_remove in duplicate_indices: + idx_start_0 = torch.cat((idx_start[0][:idx_remove], idx_start[0][idx_remove + 1:])) + idx_start_1 = torch.cat((idx_start[1][:idx_remove], idx_start[1][idx_remove + 1:])) + idx_start = (idx_start_0, idx_start_1) # type: ignore + + seen = [] + duplicate_indices = [] + + for i in range(len(idx_end[0])): + if idx_end[0][i] in seen: + duplicate_indices.append(i) + else: + seen.append(idx_end[0][i]) + + if len(duplicate_indices) > 0: + logger.info("Duplicate entities found, removing them...") + for idx_remove in duplicate_indices: + idx_end_0 = torch.cat((idx_end[0][:idx_remove], idx_end[0][idx_remove + 1:])) + idx_end_1 = torch.cat((idx_end[1][:idx_remove], idx_end[1][idx_remove + 1:])) + idx_end = (idx_end_0, idx_end_1) # type: ignore + + assert len(idx_start[0]) == input_ids.shape[0] + assert len(idx_start[0]) == len(idx_end[0]) + + sequence_output_entities = [] + + for i in range(len(idx_start[0])): + to_append = sequence_output[i, idx_start[1][i] + 1:idx_end[1][i], ] + + # to_append = torch.sum(to_append, dim=0) + to_append, _ = torch.max(to_append, axis=0) # type: ignore + + sequence_output_entities.append(to_append) + sequence_output_entities = 
torch.stack(sequence_output_entities) + + return sequence_output_entities diff --git a/medcat/utils/relation_extraction/models.py b/medcat/utils/relation_extraction/models.py index d0003a1c5..7bbf5c55f 100644 --- a/medcat/utils/relation_extraction/models.py +++ b/medcat/utils/relation_extraction/models.py @@ -1,120 +1,171 @@ import logging -from typing import Any, List, Optional, Tuple import torch +from typing import Any, Optional, Tuple, Union from torch import nn -from transformers.models.bert.modeling_bert import BertPreTrainingHeads, BertModel -from transformers.models.bert.configuration_bert import BertConfig +from transformers import PretrainedConfig, PreTrainedModel + from medcat.config_rel_cat import ConfigRelCAT +from transformers.models.llama import LlamaModel +from transformers import BertModel +from transformers import ModernBertModel +from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction +from medcat.utils.relation_extraction.ml_utils import create_dense_layers, get_annotation_schema_tag -class BertModel_RelationExtraction(nn.Module): - """ BertModel class for RelCAT +class BaseModelBluePrint_RelationExtraction(nn.Module): + """ Base class for the RelCAT models """ - name = "bertmodel_relcat" - - log = logging.getLogger(__name__) + hf_model: PreTrainedModel + relcat_config: ConfigRelCAT + model_config: PretrainedConfig + drop_out: nn.Dropout + fc1: nn.Linear + fc2: nn.Linear + fc3: nn.Linear - def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: BertConfig): - """ Class to hold the BERT model + model_config + def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[PretrainedConfig, BaseConfig_RelationExtraction]): + """ Class to hold the HF model + model_config Args: pretrained_model_name_or_path (str): path to load the model from, this can be a HF model i.e: "bert-base-uncased", if left empty, it is normally assumed that a model is 
loaded from 'model.dat' - using the RelCAT.load() method. So if you are initializing/training a model from scratch be sure to base it on some model. + using the RelCAT.load() method. So if you are initializing/training a model from scratch be sure to base it on some model. relcat_config (ConfigRelCAT): relcat config. - model_config (BertConfig): HF bert config for model. + model_config (PretrainedConfig): HF bert config for model. """ - super(BertModel_RelationExtraction, self).__init__() + super(BaseModelBluePrint_RelationExtraction, self).__init__() - self.relcat_config: ConfigRelCAT = relcat_config - self.model_config: BertConfig = model_config + def forward(self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Any = None, + head_mask: Any = None, + encoder_hidden_states: Any = None, + encoder_attention_mask: Any = None, + Q: Any = None, + e1_e2_start: Any = None, + pooled_output: Any = None) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Forward pass for the model - self.bert_model:BertModel = BertModel(config=model_config) + Args: + input_ids (torch.Tensor): input token ids. Defaults to None. + attention_mask (torch.Tensor): attention mask for the input ids. Defaults to None. + token_type_ids (torch.Tensor): token type ids for the input ids. Defaults to None. + position_ids (Any): The position IDs. Defaults to None. + head_mask (Any): The head mask. Defaults to None. + encoder_hidden_states (Any): Encoder hidden states. Defaults to None. + encoder_attention_mask (Any): Encoder attention mask. Defaults to None. + Q (Any): Q. Defaults to None. + e1_e2_start (Any): start and end indices for the entities in the input ids. Defaults to None. + pooled_output (Any): The pooled output. Defaults to None. 
- if pretrained_model_name_or_path != "": - self.bert_model = BertModel.from_pretrained(pretrained_model_name_or_path, config=model_config) + Returns: + Optional[Tuple[torch.Tensor, torch.Tensor]]: logits for the relation classification task. + """ + return None - for param in self.bert_model.parameters(): - param.requires_grad = False + def output2logits(self, pooled_output: torch.Tensor, sequence_output: torch.Tensor, input_ids: torch.Tensor, e1_e2_start: torch.Tensor) -> Optional[torch.Tensor]: + """ Convert the output of the model to logits - self.drop_out = nn.Dropout(self.model_config.hidden_dropout_prob) + Args: + pooled_output (torch.Tensor): output of the pooled layer. + sequence_output (torch.Tensor): output of the sequence layer. + input_ids (torch.Tensor): input token ids. + e1_e2_start (torch.Tensor): start and end indices for the entities in the input ids. - if self.relcat_config.general.task == "pretrain": - self.activation = nn.Tanh() - self.cls = BertPreTrainingHeads(self.model_config) + Returns: + logits (torch.Tensor): logits for the relation classification task. + """ + return None - self.relu = nn.ReLU() - # dense layers - self.fc1 = nn.Linear(self.relcat_config.model.model_size, self.relcat_config.model.hidden_size) - self.fc2 = nn.Linear(self.relcat_config.model.hidden_size, int(self.relcat_config.model.hidden_size / 2)) - self.fc3 = nn.Linear(int(self.relcat_config.model.hidden_size / 2), self.relcat_config.train.nclasses) +class BaseModel_RelationExtraction(BaseModelBluePrint_RelationExtraction): - self.log.info("RelCAT BertConfig: " + str(self.model_config)) + name = "basemodel_relcat" + log = logging.getLogger(__name__) - def get_annotation_schema_tag(self, sequence_output: torch.Tensor, input_ids: torch.Tensor, special_tag: List) -> torch.Tensor: - """ Gets to token sequences from the sequence_ouput for the specific token - tag ids in self.relcat_config.general.annotation_schema_tag_ids. 
+ def __init__(self, relcat_config: ConfigRelCAT, + model_config: BaseConfig_RelationExtraction, + pretrained_model_name_or_path): + super(BaseModel_RelationExtraction, self).__init__(pretrained_model_name_or_path=pretrained_model_name_or_path, + relcat_config=relcat_config, + model_config=model_config) - Args: - sequence_output (torch.Tensor): hidden states/embeddings for each token in the input text - input_ids (torch.Tensor): input token ids - special_tag (List): special annotation token id pairs + self.relcat_config: ConfigRelCAT = relcat_config + self.model_config: BaseConfig_RelationExtraction = model_config + self.hf_model: Union[ModernBertModel, BertModel, LlamaModel, PreTrainedModel] = PreTrainedModel(config=model_config.hf_model_config) # type: ignore + self.pretrained_model_name_or_path: str = pretrained_model_name_or_path - Returns: - torch.Tensor: new seq_tags - """ + self._reinitialize_dense_and_frozen_layers(relcat_config=relcat_config) - idx_start = torch.where(input_ids == special_tag[0]) # returns: row ids, idx of token[0]/star token in row - idx_end = torch.where(input_ids == special_tag[1]) # returns: row ids, idx of token[1]/end token in row + self.log.info("RelCAT model config: " + str(self.model_config.hf_model_config)) - seen = [] # List to store seen elements and their indices - duplicate_indices = [] + def _reinitialize_dense_and_frozen_layers(self, relcat_config: ConfigRelCAT) -> None: + """ Reinitialize the dense layers of the model - for i in range(len(idx_start[0])): - if idx_start[0][i] in seen: - duplicate_indices.append(i) + Args: + relcat_config (ConfigRelCAT): relcat config. 
+ """ + + self.drop_out = nn.Dropout(relcat_config.model.dropout) + self.fc1, self.fc2, self.fc3 = create_dense_layers(relcat_config) + + for param in self.hf_model.parameters(): # type: ignore + if self.relcat_config.model.freeze_layers: + param.requires_grad = False else: - seen.append(idx_start[0][i]) + param.requires_grad = True - if len(duplicate_indices) > 0: - self.log.info("Duplicate entities found, removing them...") - for idx_remove in duplicate_indices: - idx_start_0 = torch.cat((idx_start[0][:idx_remove], idx_start[0][idx_remove + 1:])) - idx_start_1 = torch.cat((idx_start[1][:idx_remove], idx_start[1][idx_remove + 1:])) - idx_start = (idx_start_0, idx_start_1) # type: ignore + def forward(self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Any = None, + head_mask: Any = None, + encoder_hidden_states: Any = None, + encoder_attention_mask: Any = None, + Q: Any = None, + e1_e2_start: Any = None, + pooled_output: Any = None) -> Tuple[torch.Tensor, torch.Tensor]: - seen = [] - duplicate_indices = [] + if input_ids is not None: + input_shape = input_ids.size() + else: + raise ValueError("You have to specify input_ids") - for i in range(len(idx_end[0])): - if idx_end[0][i] in seen: - duplicate_indices.append(i) - else: - seen.append(idx_end[0][i]) + if attention_mask is None: + attention_mask = torch.ones( + input_shape, device=self.relcat_config.general.device) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + input_shape, device=self.relcat_config.general.device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.relcat_config.general.device) - if len(duplicate_indices) > 0: - self.log.info("Duplicate entities found, removing them...") - for idx_remove in duplicate_indices: - idx_end_0 = torch.cat((idx_end[0][:idx_remove], idx_end[0][idx_remove + 1:])) - idx_end_1 = 
torch.cat((idx_end[1][:idx_remove], idx_end[1][idx_remove + 1:])) - idx_end = (idx_end_0, idx_end_1) # type: ignore + input_ids = input_ids.to(self.relcat_config.general.device) + attention_mask = attention_mask.to(self.relcat_config.general.device) + encoder_attention_mask = encoder_attention_mask.to( + self.relcat_config.general.device) - assert len(idx_start[0]) == input_ids.shape[0] - assert len(idx_start[0]) == len(idx_end[0]) - sequence_output_entities = [] + self.hf_model = self.hf_model.to(self.relcat_config.general.device) # type: ignore - for i in range(len(idx_start[0])): - to_append = sequence_output[i, idx_start[1][i] + 1:idx_end[1][i], ] + model_output = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask) # type: ignore - # to_append = torch.sum(to_append, dim=0) - to_append, _ = torch.max(to_append, axis=0) # type: ignore + # (batch_size, sequence_length, hidden_size) + sequence_output = model_output[0] + pooled_output = model_output[1] - sequence_output_entities.append(to_append) - sequence_output_entities = torch.stack(sequence_output_entities) + classification_logits = self.output2logits( + pooled_output, sequence_output, input_ids, e1_e2_start) + + return model_output, classification_logits.to(self.relcat_config.general.device) - return sequence_output_entities def output2logits(self, pooled_output: torch.Tensor, sequence_output: torch.Tensor, input_ids: torch.Tensor, e1_e2_start: torch.Tensor) -> torch.Tensor: """ @@ -138,7 +189,7 @@ def output2logits(self, pooled_output: torch.Tensor, sequence_output: torch.Tens # for each pair of tags (e1,s1) and (e2,s2) for each_tags in annotation_schema_tag_ids_: - seq_tags.append(self.get_annotation_schema_tag( + seq_tags.append(get_annotation_schema_tag( sequence_output, input_ids, each_tags)) seq_tags = torch.stack(seq_tags, dim=0) @@ -168,52 +219,45 @@ def 
output2logits(self, pooled_output: torch.Tensor, sequence_output: torch.Tens x = self.drop_out(x) x = self.fc2(x) classification_logits = self.fc3(x) - return classification_logits.to(self.relcat_config.general.device) - - def forward(self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Any = None, - head_mask: Any = None, - encoder_hidden_states: Any = None, - encoder_attention_mask: Any = None, - Q: Any = None, - e1_e2_start: Any = None, - pooled_output: Any = None) -> Tuple[torch.Tensor, torch.Tensor]: - if input_ids is not None: - input_shape = input_ids.size() - else: - raise ValueError("You have to specify input_ids") + return classification_logits.to(self.relcat_config.general.device) - if attention_mask is None: - attention_mask = torch.ones( - input_shape, device=self.relcat_config.general.device) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - input_shape, device=self.relcat_config.general.device) - if token_type_ids is None: - token_type_ids = torch.zeros( - input_shape, dtype=torch.long, device=self.relcat_config.general.device) + @classmethod + def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: BaseConfig_RelationExtraction) -> "BaseModel_RelationExtraction": + """ Load the model from the given path - input_ids = input_ids.to(self.relcat_config.general.device) - attention_mask = attention_mask.to(self.relcat_config.general.device) - encoder_attention_mask = encoder_attention_mask.to( - self.relcat_config.general.device) + Args: + pretrained_model_name_or_path (str): path to load the model from. + relcat_config (ConfigRelCAT): relcat config. + model_config (BaseConfig_RelationExtraction): The model-specific config. - self.bert_model = self.bert_model.to(self.relcat_config.general.device) + returns: + BaseModel_RelationExtraction: The loaded model. 
+ """ - model_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask, - token_type_ids=token_type_ids, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask) + model = BaseModel_RelationExtraction(relcat_config=relcat_config, model_config=model_config, pretrained_model_name_or_path=pretrained_model_name_or_path) + + if "modern-bert" in relcat_config.general.tokenizer_name or \ + "modern-bert" in relcat_config.general.model_name: + from medcat.utils.relation_extraction.modernbert.model import ModernBertModel_RelationExtraction + model = ModernBertModel_RelationExtraction.load(pretrained_model_name_or_path, relcat_config=relcat_config, model_config=model_config) + elif "bert" in relcat_config.general.tokenizer_name or \ + "bert" in relcat_config.general.model_name: + from medcat.utils.relation_extraction.bert.model import BertModel_RelationExtraction + model = BertModel_RelationExtraction.load(pretrained_model_name_or_path, relcat_config=relcat_config, model_config=model_config) + elif "llama" in relcat_config.general.tokenizer_name or \ + "llama" in relcat_config.general.model_name: + from medcat.utils.relation_extraction.llama.model import LlamaModel_RelationExtraction + model = LlamaModel_RelationExtraction.load(pretrained_model_name_or_path, relcat_config=relcat_config, model_config=model_config) + else: + if pretrained_model_name_or_path: + model.hf_model = PreTrainedModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path, config=model_config) + else: + model.hf_model = PreTrainedModel.from_pretrained(pretrained_model_name_or_path=relcat_config.general.model_name, config=model_config) + cls.log.info("Loaded model from relcat_config: " + relcat_config.general.model_name) - # (batch_size, sequence_length, hidden_size) - sequence_output = model_output[0] - pooled_output = model_output[1] + cls.log.info("Loaded " + str(model.__class__.__name__) + " from 
pretrained_model_name_or_path: " + pretrained_model_name_or_path) - classification_logits = self.output2logits( - pooled_output, sequence_output, input_ids, e1_e2_start) + model._reinitialize_dense_and_frozen_layers(relcat_config=relcat_config) - return model_output, classification_logits.to(self.relcat_config.general.device) + return model diff --git a/medcat/utils/relation_extraction/modernbert/__init__.py b/medcat/utils/relation_extraction/modernbert/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/medcat/utils/relation_extraction/modernbert/config.py b/medcat/utils/relation_extraction/modernbert/config.py new file mode 100644 index 000000000..454b07f5e --- /dev/null +++ b/medcat/utils/relation_extraction/modernbert/config.py @@ -0,0 +1,30 @@ +import logging +import os +from medcat.config_rel_cat import ConfigRelCAT +from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction +from transformers import ModernBertConfig + +logger = logging.getLogger(__name__) + + +class ModernBertConfig_RelationExtraction(BaseConfig_RelationExtraction): + """ Class for ModernBertConfig + """ + + name = 'modern-bert-config' + pretrained_model_name_or_path = "answerdotai/ModernBERT-base" + + @classmethod + def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "ModernBertConfig_RelationExtraction": + model_config = cls(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + + if pretrained_model_name_or_path and os.path.exists(pretrained_model_name_or_path): + model_config.hf_model_config = ModernBertConfig.from_json_file(pretrained_model_name_or_path) + logger.info("Loaded config from file: " + pretrained_model_name_or_path) + else: + relcat_config.general.model_name = cls.pretrained_model_name_or_path + model_config.hf_model_config = ModernBertConfig.from_pretrained( + pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs) + logger.info("Loaded config from 
pretrained: " + relcat_config.general.model_name) + + return model_config diff --git a/medcat/utils/relation_extraction/modernbert/model.py b/medcat/utils/relation_extraction/modernbert/model.py new file mode 100644 index 000000000..a78c8af98 --- /dev/null +++ b/medcat/utils/relation_extraction/modernbert/model.py @@ -0,0 +1,77 @@ +import logging +from typing import Union +import torch +import os +from torch import nn +from transformers import ModernBertModel + +from medcat.config_rel_cat import ConfigRelCAT +from transformers import PreTrainedModel +from medcat.utils.relation_extraction.ml_utils import create_dense_layers +from medcat.utils.relation_extraction.models import BaseModel_RelationExtraction +from medcat.utils.relation_extraction.config import BaseConfig_RelationExtraction +from medcat.utils.relation_extraction.modernbert.config import ModernBertConfig_RelationExtraction + + + +class ModernBertModel_RelationExtraction(BaseModel_RelationExtraction): + """ ModernBertModel class for RelCAT + """ + + name = "modernbertmodel_relcat" + + log = logging.getLogger(__name__) + + def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, ModernBertConfig_RelationExtraction]): + """ Class to hold the ModernBERT model + model_config + + Args: + pretrained_model_name_or_path (str): path to load the model from, + this can be a HF model i.e: "bert-base-uncased", if left empty, it is normally assumed that a model is loaded from 'model.dat' + using the RelCAT.load() method. So if you are initializing/training a model from scratch be sure to base it on some model. + relcat_config (ConfigRelCAT): relcat config. + model_config (Union[BaseConfig_RelationExtraction | ModernBertConfig_RelationExtraction]): HF bert config for model. 
+ """ + super(ModernBertModel_RelationExtraction, self).__init__(pretrained_model_name_or_path=pretrained_model_name_or_path, + relcat_config=relcat_config, + model_config=model_config) + + self.relcat_config: ConfigRelCAT = relcat_config + self.model_config: Union[BaseConfig_RelationExtraction, ModernBertConfig_RelationExtraction] = model_config + self.pretrained_model_name_or_path: str = pretrained_model_name_or_path + + self.hf_model: Union[ModernBertModel, PreTrainedModel] = ModernBertModel(config=model_config.hf_model_config) + + for param in self.hf_model.parameters(): # type: ignore + if self.relcat_config.model.freeze_layers: + param.requires_grad = False + else: + param.requires_grad = True + + self.drop_out = nn.Dropout(self.relcat_config.model.dropout) + + # dense layers + self.fc1, self.fc2, self.fc3 = create_dense_layers(self.relcat_config) + + @classmethod + def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: Union[BaseConfig_RelationExtraction, ModernBertConfig_RelationExtraction], **kwargs) -> "ModernBertModel_RelationExtraction": + + model = ModernBertModel_RelationExtraction(pretrained_model_name_or_path=pretrained_model_name_or_path, + relcat_config=relcat_config, + model_config=model_config) + + model_path = os.path.join(pretrained_model_name_or_path, "model.dat") + + if os.path.exists(model_path): + model.load_state_dict(torch.load(model_path, map_location=relcat_config.general.device)) + cls.log.info("Loaded model from file: " + model_path) + elif pretrained_model_name_or_path: + model.hf_model = ModernBertModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, config=model_config.hf_model_config, ignore_mismatched_sizes=True, **kwargs) + cls.log.info("Loaded model from pretrained: " + pretrained_model_name_or_path) + else: + model.hf_model = ModernBertModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, config=model_config.hf_model_config, 
ignore_mismatched_sizes=True, **kwargs) + cls.log.info("Loaded model from pretrained: " + cls.pretrained_model_name_or_path) + + return model diff --git a/medcat/utils/relation_extraction/modernbert/tokenizer.py b/medcat/utils/relation_extraction/modernbert/tokenizer.py new file mode 100644 index 000000000..b48789021 --- /dev/null +++ b/medcat/utils/relation_extraction/modernbert/tokenizer.py @@ -0,0 +1,34 @@ +import os +from transformers import PreTrainedTokenizerFast +import logging + +from medcat.config_rel_cat import ConfigRelCAT +from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper_RelationExtraction + +logger = logging.getLogger(__name__) + + +class TokenizerWrapperModernBERT_RelationExtraction(BaseTokenizerWrapper_RelationExtraction): + ''' Wrapper around a huggingface ModernBERT tokenizer so that it works with the + RelCAT models. + + Args: + hf_tokenizers (`transformers.PreTrainedTokenizerFast`): + A huggingface Fast tokenizer. + ''' + name = "tokenizer_wrapper_modern_bert_rel" + pretrained_model_name_or_path = "answerdotai/ModernBERT-base" + + @classmethod + def load(cls, tokenizer_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "TokenizerWrapperModernBERT_RelationExtraction": + tokenizer = cls() + path = os.path.join(tokenizer_path, cls.name) + + if tokenizer_path: + tokenizer.hf_tokenizers = PreTrainedTokenizerFast.from_pretrained( + path, **kwargs) + else: + relcat_config.general.model_name = cls.pretrained_model_name_or_path + tokenizer.hf_tokenizers = PreTrainedTokenizerFast.from_pretrained( + path=relcat_config.general.model_name) + return tokenizer diff --git a/medcat/utils/relation_extraction/rel_dataset.py b/medcat/utils/relation_extraction/rel_dataset.py index ce2f26f6b..c9b134f12 100644 --- a/medcat/utils/relation_extraction/rel_dataset.py +++ b/medcat/utils/relation_extraction/rel_dataset.py @@ -1,15 +1,14 @@ from ast import literal_eval from typing import Any, Iterable, List, Dict, Tuple, Union from 
torch.utils.data import Dataset -from spacy.tokens import Doc +from spacy.tokens import Doc, Span import logging import pandas import random import torch from medcat.cdb import CDB from medcat.config_rel_cat import ConfigRelCAT -from medcat.utils.meta_cat.data_utils import Span -from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT +from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper_RelationExtraction class RelData(Dataset): @@ -18,7 +17,7 @@ class RelData(Dataset): log = logging.getLogger(__name__) - def __init__(self, tokenizer: TokenizerWrapperBERT, config: ConfigRelCAT, cdb: CDB = CDB()): + def __init__(self, tokenizer: BaseTokenizerWrapper_RelationExtraction, config: ConfigRelCAT, cdb: CDB = CDB()): """ Use this class to create a dataset for relation annotations from CSV exports, MedCAT exports or Spacy Documents (assuming the documents got generated by MedCAT, if they did not then please set the required parameters manually to match MedCAT output, @@ -28,7 +27,8 @@ def __init__(self, tokenizer: TokenizerWrapperBERT, config: ConfigRelCAT, cdb: C interest are surrounded by the special tokens, see create_base_relations_from_csv doc. Args: - tokenizer (TokenizerWrapperBERT): okenizer used to generate token ids from input text + tokenizer (BaseTokenizerWrapper_RelationExtraction): + tokenizer used to generate token ids from input text config (ConfigRelCAT): same config used in RelCAT cdb (CDB): Optional, used to add concept ids and types to detected ents, useful when creating datasets from MedCAT output. Defaults to CDB(). 
@@ -36,7 +36,7 @@ def __init__(self, tokenizer: TokenizerWrapperBERT, config: ConfigRelCAT, cdb: C self.cdb: CDB = cdb self.config: ConfigRelCAT = config - self.tokenizer: TokenizerWrapperBERT = tokenizer + self.tokenizer: BaseTokenizerWrapper_RelationExtraction = tokenizer self.dataset: Dict[Any, Any] = {} self.log.setLevel(self.config.general.log_level) @@ -64,7 +64,7 @@ def generate_base_relations(self, docs: Iterable[Doc]) -> List[Dict]: return output_relations - def create_base_relations_from_csv(self, csv_path: str): + def create_base_relations_from_csv(self, csv_path: str, keep_source_text: bool = False): """ Assumes the columns are as follows ["relation_token_span_ids", "ent1_ent2_start", "ent1", "ent2", "label", "label_id", "ent1_type", "ent2_type", "ent1_id", "ent2_id", "ent1_cui", "ent2_cui", "doc_id", "sents"], @@ -77,6 +77,8 @@ def create_base_relations_from_csv(self, csv_path: str): Args: csv_path (str): path to csv file, must have specific columns, tab separated, + keep_source_text (bool): if the text clumn should be retained in the 'sents' df column, + used for debugging or creating custom datasets. 
Returns: Dict : { @@ -134,7 +136,10 @@ def create_base_relations_from_csv(self, csv_path: str): df["ent1_ent2_start"] = out_ent1_ent2_starts df = df.drop(index=rows_to_remove) - df = df.drop(columns=col) + text_col = df.pop(col) + df = df.assign(col=text_col) + if keep_source_text: + df = df.assign(col=text_col) break nclasses, labels2idx, idx2label = RelData.get_labels( @@ -149,17 +154,163 @@ def create_base_relations_from_csv(self, csv_path: str): for label_num in list(idx2label.keys()): sample_count = 0 for output_relation in output_relations: - if label_num == output_relation[5]: + if idx2label[label_num] == output_relation[4]: sample_count += 1 self.log.info( " label: " + idx2label[label_num] + " | samples: " + str(sample_count)) # replace/update label_id with actual detected label number for idx in range(len(output_relations)): - output_relations[idx][5] = labels2idx[output_relations[idx][4]] + output_relations[idx][5] = int(labels2idx[output_relations[idx][4]]) return {"output_relations": output_relations, "nclasses": nclasses, "labels2idx": labels2idx, "idx2label": idx2label} + def _create_relation_validation(self, + text: Union[str, Doc], + doc_id: str, + tokenized_text_data: Dict[str, Any], + ent1_start_char_pos: int, + ent2_start_char_pos: int, + ent1_end_char_pos: int, + ent2_end_char_pos: int, + ent1_token_start_pos: int = -1, + ent2_token_start_pos: int = -1, + ent1_token_end_pos: int = -1, + ent2_token_end_pos: int = -1, + is_spacy_doc: bool = False, + is_mct_export: bool = False, + ) -> List: + """ + This function checks if the relation is actually valid by distance criteria, TUIs and so on. + Has diffierent handling cases for text, spacy docs and MCT exports. 
+ + Args: + text (str): doc text + doc_id (str): doc id + tokenized_text_data (Dict[str, Any]): tokenized text + ent1_start_char_pos (int): ent1 start char pos + ent2_start_char_pos (int): ent2 start char pos + ent1_end_char_pos (int): ent1 end char pos + ent2_end_char_pos (int): ent2 end char pos + ent1_token_start_pos (int): ent1_token_start_pos. Defaults to -1. + ent2_token_start_pos (int): ent2_token_start_pos. Defaults to -1. + ent1_token_end_pos (int): ent1_token_end_pos. Defaults to -1. + ent2_token_end_pos (int): ent2_token_end_pos. Defaults to -1. + is_spacy_doc (bool): checks if doc is spacy docs. Defaults to False. + is_mct_export (bool): chekcs if doc is a mct export. Defaults to False. + + Returns: + List: row containing rel data ["relation_token_span_ids", "ent1_ent2_start", "ent1", "ent2", "label", + "label_id", "ent1_type", "ent2_type", "ent1_id", "ent2_id", "ent1_cui", "ent2_cui", "doc_id", "sents"] + """ + + text_length:int = len(text) + + doc_token_length: int = len(tokenized_text_data["tokens"]) + + tmp_doc_text = text + + ent1_token: Union[str, Span] = tmp_doc_text[ent1_start_char_pos: ent1_end_char_pos] + ent2_token: Union[str, Span] = tmp_doc_text[ent2_start_char_pos: ent2_end_char_pos] + + if abs(ent2_start_char_pos - ent1_start_char_pos) <= self.config.general.window_size and \ + ent1_token != ent2_token: + + ent1_left_ent_context_token_pos_end = ent1_token_start_pos - self.config.general.cntx_left + left_context_start_char_pos = 0 + + if ent1_left_ent_context_token_pos_end < 0: + ent1_left_ent_context_token_pos_end = 0 + else: + left_context_start_char_pos = tokenized_text_data["offset_mapping"][ent1_left_ent_context_token_pos_end][0] + + ent2_right_ent_context_token_pos_end = ent2_token_end_pos + self.config.general.cntx_right + right_context_end_char_pos = text_length + + # get correct position, don't get last token as it can be the [SEP] or [EOS] token. 
+ if ent2_right_ent_context_token_pos_end >= (doc_token_length - 1): + ent2_right_ent_context_token_pos_end = doc_token_length - 2 + else: + right_context_end_char_pos = tokenized_text_data["offset_mapping"][ent2_right_ent_context_token_pos_end][1] + + if left_context_start_char_pos > right_context_end_char_pos: + tmp = right_context_end_char_pos + right_context_end_char_pos = left_context_start_char_pos + left_context_start_char_pos = tmp + + if is_spacy_doc or is_mct_export: + tmp_doc_text = text + _pre_e1 = tmp_doc_text[0: (ent1_start_char_pos)] + _e1_s2 = tmp_doc_text[ent1_end_char_pos: ent2_start_char_pos - 1] + _e2_end = tmp_doc_text[ent2_end_char_pos + 1: text_length] + ent2_token_end_pos = (ent2_token_end_pos + 2) + + annotation_token_text = self.tokenizer.hf_tokenizers.convert_ids_to_tokens( + self.config.general.annotation_schema_tag_ids) + + tmp_doc_text = str(_pre_e1) + " " + \ + annotation_token_text[0] + " " + \ + str(ent1_token) + " " + \ + annotation_token_text[1] + " " + str(_e1_s2) + " " + \ + annotation_token_text[2] + " " + str(ent2_token) + " " + \ + annotation_token_text[3] + " " + str(_e2_end) + + ann_tag_token_len = len(annotation_token_text[0]) + + _left_context_start_char_pos = left_context_start_char_pos - ann_tag_token_len - 2 # - 2 spaces + left_context_start_char_pos = 0 if _left_context_start_char_pos <= 0 \ + else _left_context_start_char_pos + + _right_context_start_end_pos = right_context_end_char_pos + (ann_tag_token_len * 4) + 8 # 8 for spces + right_context_end_char_pos = len(tmp_doc_text) + 1 if right_context_end_char_pos >= len(tmp_doc_text) or \ + _right_context_start_end_pos >= len(tmp_doc_text) else _right_context_start_end_pos + + # reassign the new text with added tags + text_length = len(tmp_doc_text) + + # may lead to problems down the line if truncation=False, if it is True we enforce the max 512 token length sentence + # take care when using window_size > 300, we want to make sure both entities are included at 
least... otherwise the relation + # is considered invalid + window_tokenizer_data = self.tokenizer(tmp_doc_text[left_context_start_char_pos:right_context_end_char_pos], truncation=True) + + if self.config.general.annotation_schema_tag_ids: + try: + ent1_token_start_pos = \ + window_tokenizer_data["input_ids"].index( + self.config.general.annotation_schema_tag_ids[0]) + ent2_token_start_pos = \ + window_tokenizer_data["input_ids"].index( + self.config.general.annotation_schema_tag_ids[2]) + _ent1_token_end_pos = \ + window_tokenizer_data["input_ids"].index( + self.config.general.annotation_schema_tag_ids[1]) + _ent2_token_end_pos = \ + window_tokenizer_data["input_ids"].index( + self.config.general.annotation_schema_tag_ids[3]) + assert ent1_token_start_pos + assert ent2_token_start_pos + assert _ent1_token_end_pos + assert _ent2_token_end_pos + except Exception as exception: + self.log.error("document id : " + str(doc_id) + " failed to process relation") + self.log.info(exception) + return [] + + if not self.config.general.annotation_schema_tag_ids: + # update token loc to match new selection + ent2_token_start_pos = ent2_token_start_pos - ent1_token_start_pos + ent1_token_start_pos = self.config.general.cntx_left if ent1_token_start_pos - self.config.general.cntx_left > 0 else ent1_token_start_pos + ent2_token_start_pos += ent1_token_start_pos + + ent1_ent2_new_start = (ent1_token_start_pos, ent2_token_start_pos) + en1_start, en1_end = window_tokenizer_data["offset_mapping"][ent1_token_start_pos] + en2_start, en2_end = window_tokenizer_data["offset_mapping"][ent2_token_start_pos] + + return [window_tokenizer_data["input_ids"], ent1_ent2_new_start, ent1_token, ent2_token, "UNK", self.config.model.padding_idx, + None, None, None, None, None, None, doc_id, "", + en1_start, en1_end, en2_start, en2_end] + return [] + def create_base_relations_from_doc(self, doc: Union[Doc, str], doc_id: str, ent1_ent2_tokens_start_pos: Union[List, Tuple] = (-1, -1)) -> Dict: """ 
Creates a list of tuples based on pairs of entities detected (relation, ent1, ent2) for one spacy document or text string. @@ -178,219 +329,147 @@ def create_base_relations_from_doc(self, doc: Union[Doc, str], doc_id: str, ent1 "idx2label": {}} } """ - relation_instances = [] + + _ent1_start_tkn_id, _ent1_end_tkn_id, _ent2_start_tkn_id, _ent2_end_tkn_id = 0, 0, 0, 0 chars_to_exclude = ":!@#$%^&*()-+?_=.,;<>/[]{}" - tokenizer_data = None + + if self.config.general.annotation_schema_tag_ids: + # we assume that ent1 start token is pos 0 and ent2 start token is pos 2 + # e.g: [s1], [e1], [s2], [e2] + _ent1_start_tkn_id = self.config.general.annotation_schema_tag_ids[0] + _ent1_end_tkn_id = self.config.general.annotation_schema_tag_ids[1] + _ent2_start_tkn_id = self.config.general.annotation_schema_tag_ids[2] + _ent2_end_tkn_id = self.config.general.annotation_schema_tag_ids[3] + + relation_instances = [] + + tokenized_text_data = None if isinstance(doc, str): - tokenizer_data = self.tokenizer(doc, truncation=False) doc_text = doc elif isinstance(doc, Doc): - tokenizer_data = self.tokenizer(doc.text, truncation=False) doc_text = doc.text - doc_length = len(tokenizer_data["tokens"]) + tokenized_text_data = self.tokenizer(doc_text, truncation=False) - if ent1_ent2_tokens_start_pos != (-1, -1): - ent1_token_start_pos, ent2_token_start_pos = ent1_ent2_tokens_start_pos[0],\ - ent1_ent2_tokens_start_pos[1] - # add + 1 to the pos cause of [CLS] - if self.config.general.annotation_schema_tag_ids: - ent1_token_start_pos, ent2_token_start_pos = ent1_ent2_tokens_start_pos[0] + 1,\ - ent1_ent2_tokens_start_pos[1] + 1 + doc_length_tokens = len(tokenized_text_data["tokens"]) - ent1_start_char_pos, _ = tokenizer_data["offset_mapping"][ent1_token_start_pos] - ent2_start_char_pos, _ = tokenizer_data["offset_mapping"][ent2_token_start_pos] + if ent1_ent2_tokens_start_pos != (-1, -1) and isinstance(doc, str): + ent1_token_start_pos = 
tokenized_text_data["input_ids"].index(_ent1_start_tkn_id) + ent2_token_start_pos = tokenized_text_data["input_ids"].index(_ent2_start_tkn_id) + ent1_token_end_pos = tokenized_text_data["input_ids"].index(_ent1_end_tkn_id) + ent2_token_end_pos = tokenized_text_data["input_ids"].index(_ent2_end_tkn_id) - if abs(ent2_start_char_pos - ent1_start_char_pos) <= self.config.general.window_size: + ent1_start_char_pos, ent1_end_char_pos = tokenized_text_data["offset_mapping"][ent1_token_start_pos] + ent2_start_char_pos, ent2_end_char_pos = tokenized_text_data["offset_mapping"][ent2_token_start_pos] - ent1_left_ent_context_token_pos_end = ent1_token_start_pos - \ - self.config.general.cntx_left - - left_context_start_char_pos = 0 - right_context_start_end_pos = len(doc_text) - 1 - - if ent1_left_ent_context_token_pos_end < 0: - ent1_left_ent_context_token_pos_end = 0 - else: - left_context_start_char_pos = tokenizer_data[ - "offset_mapping"][ent1_left_ent_context_token_pos_end][0] - - ent2_right_ent_context_token_pos_end = ent2_token_start_pos + \ - self.config.general.cntx_right - - # get end of 2nd ent token (if using tags) - if self.config.general.annotation_schema_tag_ids: - far_pos = -1 - for tkn_id in self.config.general.annotation_schema_tag_ids: - pos = [i for i in range( - 0, doc_length) if tokenizer_data["input_ids"][i] == tkn_id][0] - far_pos = pos if far_pos < pos else far_pos - ent2_right_ent_context_token_pos_end = far_pos - - if ent2_right_ent_context_token_pos_end >= doc_length - 1: - ent2_right_ent_context_token_pos_end = doc_length - 2 - else: - right_context_start_end_pos = tokenizer_data[ - "offset_mapping"][ent2_right_ent_context_token_pos_end][1] - - ent1_token = tokenizer_data["tokens"][ent1_token_start_pos] - ent2_token = tokenizer_data["tokens"][ent2_token_start_pos] - - window_tokenizer_data = self.tokenizer( - doc_text[left_context_start_char_pos:right_context_start_end_pos]) - - # update token loc to match new selection - if 
self.config.general.annotation_schema_tag_ids: - ent1_token_start_pos = \ - window_tokenizer_data["input_ids"].index( - self.config.general.annotation_schema_tag_ids[0]) - ent2_token_start_pos = \ - window_tokenizer_data["input_ids"].index( - self.config.general.annotation_schema_tag_ids[2]) - else: - ent2_token_start_pos = ent2_token_start_pos - ent1_token_start_pos - ent1_token_start_pos = self.config.general.cntx_left if ent1_token_start_pos - \ - self.config.general.cntx_left > 0 else ent1_token_start_pos - ent2_token_start_pos += ent1_token_start_pos - - ent1_ent2_new_start = ( - ent1_token_start_pos, ent2_token_start_pos) - - en1_start, en1_end = window_tokenizer_data["offset_mapping"][ent1_token_start_pos] - en2_start, en2_end = window_tokenizer_data["offset_mapping"][ent2_token_start_pos] - - relation_instances.append([window_tokenizer_data["input_ids"], ent1_ent2_new_start, ent1_token, ent2_token, "UNK", self.config.model.padding_idx, - None, None, None, None, None, None, doc_id, "", - en1_start, en1_end, en2_start, en2_end]) + relation_instances.append(self._create_relation_validation(text=doc_text, + doc_id=doc_id, + tokenized_text_data=tokenized_text_data, + ent1_start_char_pos=ent1_start_char_pos, + ent2_start_char_pos=ent2_start_char_pos, + ent1_end_char_pos=ent1_end_char_pos, + ent2_end_char_pos=ent2_end_char_pos, + ent1_token_start_pos=ent1_token_start_pos, + ent2_token_start_pos=ent2_token_start_pos, + ent1_token_end_pos=ent1_token_end_pos, + ent2_token_end_pos=ent2_token_end_pos + )) elif isinstance(doc, Doc): - _ents = doc.ents if len(doc.ents) > 0 else doc._.ents - for ent1_idx in range(0, len(_ents) - 1): + # last two can be a pair + for ent1_idx in range(0, len(_ents) - 2): ent1_token: Span = _ents[ent1_idx] # type: ignore - if str(ent1_token) not in chars_to_exclude: - ent1_type_id = list( - self.cdb.cui2type_ids.get(ent1_token._.cui, '')) - ent1_types = [self.cdb.addl_info['type_id2name'].get( - tui, '') for tui in ent1_type_id] - - ent2pos 
= ent1_idx + 1 - - ent1_start = ent1_token.start - ent1_end = ent1_token.end - - # get actual token index from the text - _ent1_token_idx = [i for i in range(len(tokenizer_data["offset_mapping"])) if ent1_start in - range( - tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1) - or ent1_end in range(tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1) - ][0] - - left_context_start_char_pos = 0 - ent1_left_ent_context_token_pos_end = _ent1_token_idx - self.config.general.cntx_left + if str(ent1_token) not in chars_to_exclude and str(ent1_token) not in self.tokenizer.hf_tokenizers.all_special_tokens: + ent1_type_id = list(self.cdb.cui2type_ids.get(ent1_token._.cui, '')) + ent1_types = [self. cdb.addl_info["type_id2name"].get(tui, '') for tui in ent1_type_id] + + ent1_start_char_pos = ent1_token.start_char + ent1_end_char_pos = ent1_token.end_char + + ent1_token_start_pos = [i for i in range(0, doc_length_tokens) if ent1_start_char_pos + in range(tokenized_text_data["offset_mapping"][i][0], tokenized_text_data["offset_mapping"][i][1] + 1)][0] + ent1_token_end_pos = [i for i in range(0, doc_length_tokens) if ent1_end_char_pos + in range(tokenized_text_data["offset_mapping"][i][0], tokenized_text_data["offset_mapping"][i][1] + 1)][0] + + for ent2_idx in range((ent1_idx + 1), len(_ents) - 1): + ent2_token: Span = _ents[ent2_idx] + + tmp_ent1 = ent1_token + + if ent1_token.start_char > ent2_token.start_char: + tmp_ent1 = ent1_token + ent1_token = ent2_token + ent2_token = tmp_ent1 + + if str(ent2_token) not in chars_to_exclude and \ + str(ent2_token) not in self.tokenizer.hf_tokenizers.all_special_tokens and \ + str(ent1_token).strip() != str(ent2_token).strip(): + + ent2_type_id = list(self.cdb.cui2type_ids.get(ent2_token._.cui, '')) + ent2_types = [self.cdb.addl_info['type_id2name'].get(tui, '') for tui in ent2_type_id] + + ent2_start_char_pos = ent2_token.start_char + ent2_end_char_pos = ent2_token.end_char 
+ + ent2_token_start_pos = [i for i in range(0, doc_length_tokens) if ent2_start_char_pos + in range(tokenized_text_data["offset_mapping"][i][0], tokenized_text_data["offset_mapping"][i][1] + 1)][0] + + ent2_token_end_pos = [i for i in range(0, doc_length_tokens) if ent2_end_char_pos + in range(tokenized_text_data["offset_mapping"][i][0], tokenized_text_data["offset_mapping"][i][1] + 1)][0] + + if self.config.general.relation_type_filter_pairs: + for rel_pair in self.config.general.relation_type_filter_pairs: + if rel_pair[0] in ent1_types and rel_pair[1] in ent2_types: + relation_instances.append(self._create_relation_validation(text=doc_text, + doc_id=doc_id, + tokenized_text_data=tokenized_text_data, + ent1_start_char_pos=ent1_start_char_pos, + ent2_start_char_pos=ent2_start_char_pos, + ent1_end_char_pos=ent1_end_char_pos, + ent2_end_char_pos=ent2_end_char_pos, + ent1_token_start_pos=ent1_token_start_pos, + ent2_token_start_pos=ent2_token_start_pos, + ent1_token_end_pos=ent1_token_end_pos, + ent2_token_end_pos=ent2_token_end_pos, + is_spacy_doc=True + )) + else: + relation_instances.append(self._create_relation_validation(text=doc_text, + doc_id=doc_id, + tokenized_text_data=tokenized_text_data, + ent1_start_char_pos=ent1_start_char_pos, + ent2_start_char_pos=ent2_start_char_pos, + ent1_end_char_pos=ent1_end_char_pos, + ent2_end_char_pos=ent2_end_char_pos, + ent1_token_start_pos=ent1_token_start_pos, + ent2_token_start_pos=ent2_token_start_pos, + ent1_token_end_pos=ent1_token_end_pos, + ent2_token_end_pos=ent2_token_end_pos, + is_spacy_doc=True + )) + + # restore ent1 + ent1_token = tmp_ent1 + + # remove duplicates by using ent1_ent2_start_pos + dupe_ent1_ent2_start = [] + + _new_rel_instances = [] + for rel in relation_instances: + if rel != []: + if rel[1] not in dupe_ent1_ent2_start: + dupe_ent1_ent2_start.append(rel[1]) + _new_rel_instances.append(rel) + else: + self.log.debug("removing duplicate relation" + str(rel[1])) - if 
ent1_left_ent_context_token_pos_end < 0: - ent1_left_ent_context_token_pos_end = 0 - else: - left_context_start_char_pos = tokenizer_data[ - "offset_mapping"][ent1_left_ent_context_token_pos_end][0] - - for ent2_idx in range(ent2pos, len(_ents)): - ent2_token: Span = _ents[ent2_idx] # type: ignore - - if ent2_token in _ents: - if str(ent2_token) not in chars_to_exclude and str(ent1_token) != str(ent2_token): - ent2_type_id = list( - self.cdb.cui2type_ids.get(ent2_token._.cui, '')) - ent2_types = [self.cdb.addl_info['type_id2name'].get( - tui, '') for tui in ent2_type_id] - - ent2_start = ent2_token.start - ent2_end = ent2_token.end - if ent2_start - ent1_start <= self.config.general.window_size and ent2_start - ent1_start > 0: - _ent2_token_idx = [i for i in range(len(tokenizer_data["offset_mapping"])) if ent2_start in - range( - tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1) - or ent2_end in - range( - tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1) - ][0] - - right_context_start_end_pos = len( - doc_text) - 1 - ent2_right_ent_context_token_pos_end = _ent2_token_idx + \ - self.config.general.cntx_right - - if ent2_right_ent_context_token_pos_end >= doc_length - 1: - ent2_right_ent_context_token_pos_end = doc_length - 2 - else: - right_context_start_end_pos = tokenizer_data[ - "offset_mapping"][ent2_right_ent_context_token_pos_end][1] - - tmp_doc_text = doc_text - - # check if a tag is present, and if not so then insert the custom annotation tags in - if self.config.general.annotation_schema_tag_ids[0] not in tokenizer_data["input_ids"]: - _pre_e1 = tmp_doc_text[0: (ent1_start)] - _e1_s2 = tmp_doc_text[( - ent1_end): (ent2_start)] - _e2_end = tmp_doc_text[( - ent2_end): len(doc_text)] - _ent2_token_idx = (_ent2_token_idx + 2) - - annotation_token_text = self.tokenizer.hf_tokenizers.convert_ids_to_tokens( - self.config.general.annotation_schema_tag_ids) - - tmp_doc_text = _pre_e1 + " " + \ - 
annotation_token_text[0] + " " + \ - str(ent1_token) + " " + \ - annotation_token_text[1] + " " + _e1_s2 + " " + \ - annotation_token_text[2] + " " + str(ent2_token) + " " + \ - annotation_token_text[3] + \ - " " + _e2_end - - ann_tag_token_len = len( - annotation_token_text[0]) - - _left_context_start_char_pos = left_context_start_char_pos - ann_tag_token_len - left_context_start_char_pos = 0 if _left_context_start_char_pos <= 0 \ - else _left_context_start_char_pos - - right_context_start_end_pos = right_context_start_end_pos if right_context_start_end_pos >= len(tmp_doc_text) \ - else right_context_start_end_pos + (ann_tag_token_len * 4) - - window_tokenizer_data = self.tokenizer( - tmp_doc_text[left_context_start_char_pos:right_context_start_end_pos]) - - if self.config.general.annotation_schema_tag_ids: - ent1_token_start_pos = \ - window_tokenizer_data["input_ids"].index( - self.config.general.annotation_schema_tag_ids[0]) - ent2_token_start_pos = \ - window_tokenizer_data["input_ids"].index( - self.config.general.annotation_schema_tag_ids[2]) - else: - ent2_token_start_pos = _ent2_token_idx - _ent1_token_idx if _ent1_token_idx - \ - self.config.general.cntx_left > 0 else _ent2_token_idx - ent1_token_start_pos = self.config.general.cntx_left if _ent1_token_idx - \ - self.config.general.cntx_left > 0 else _ent1_token_idx - ent2_token_start_pos += ent1_token_start_pos - - ent1_ent2_new_start = ( - ent1_token_start_pos, ent2_token_start_pos) - - en1_start, en1_end = window_tokenizer_data[ - "offset_mapping"][ent1_token_start_pos] - en2_start, en2_end = window_tokenizer_data[ - "offset_mapping"][ent2_token_start_pos] - - relation_instances.append([window_tokenizer_data["input_ids"], ent1_ent2_new_start, ent1_token, ent2_token, "UNK", self.config.model.padding_idx, - ent1_types, ent2_types, ent1_token._.id, ent2_token._.id, ent1_token._.cui, ent2_token._.cui, doc_id, "", - en1_start, en1_end, en2_start, en2_end]) + # cleanup + relation_instances = 
_new_rel_instances return {"output_relations": relation_instances, "nclasses": self.config.model.padding_idx, "labels2idx": {}, "idx2label": {}} @@ -412,69 +491,72 @@ def create_relations_from_export(self, data: Dict): output_relations = [] - relation_type_filter_pairs = self.config.general.relation_type_filter_pairs + for project in data["projects"]: + for _doc_id, document in enumerate(project["documents"]): + doc_text: str = str(document["text"]) + doc_id: str = str(document["id"]) - annotation_token_text = self.tokenizer.hf_tokenizers.convert_ids_to_tokens( - self.config.general.annotation_schema_tag_ids) - - for project in data['projects']: - for doc_id, document in enumerate(project['documents']): - text = str(document['text']) - if len(text) > 0: - annotations = document['annotations'] - relations = document['relations'] + if len(doc_text) > 0: + annotations = document["annotations"] + relations = document["relations"] if self.config.general.lowercase: - text = text.lower() + doc_text = doc_text.lower() - tokenizer_data = self.tokenizer(text, truncation=False) + tokenizer_text_data = self.tokenizer(doc_text, truncation=False) - doc_length_tokens = len(tokenizer_data["tokens"]) + doc_token_length = len(tokenizer_text_data["tokens"]) relation_instances = [] - ann_ids_from_reliations = [] + ann_ids_from_relations = [] ann_ids_ents: Dict[Any, Any] = {} - _other_rel_subset = [] + _other_relations_subset = [] + # this section creates 'Other' class relations based on validated annotations for ent1_idx, ent1_ann in enumerate(annotations): - ann_id = ent1_ann['id'] + ann_id = ent1_ann["id"] ann_ids_ents[ann_id] = {} - ann_ids_ents[ann_id]['cui'] = ent1_ann['cui'] - ann_ids_ents[ann_id]['type_ids'] = list( - self.cdb.cui2type_ids.get(ent1_ann['cui'], '')) - ann_ids_ents[ann_id]['types'] = [self.cdb.addl_info['type_id2name'].get( - tui, '') for tui in ann_ids_ents[ann_id]['type_ids']] + ann_ids_ents[ann_id]["cui"] = ent1_ann["cui"] + ann_ids_ents[ann_id]["type_ids"] = 
list(self.cdb.cui2type_ids.get(ent1_ann["cui"], "")) + ann_ids_ents[ann_id]["types"] = [self.cdb.addl_info['type_id2name'].get(tui, '') for tui in ann_ids_ents[ann_id]['type_ids']] - if self.config.general.mct_export_create_addl_rels: + ent1_types = ann_ids_ents[ann_id]["types"] + if self.config.general.create_addl_rels: for _, ent2_ann in enumerate(annotations[ent1_idx + 1:]): - if abs(ent1_ann["start"] - ent2_ann["start"]) <= self.config.general.window_size: - if ent1_ann["validated"] and ent2_ann["validated"]: - _other_rel_subset.append({ - "start_entity": ent1_ann["id"], - "start_entity_cui": ent1_ann["cui"], - "start_entity_value": ent1_ann["value"], - "start_entity_start_idx": ent1_ann["start"], - "start_entity_end_idx": ent1_ann["end"], - "end_entity": ent2_ann["id"], - "end_entity_cui": ent2_ann["cui"], - "end_entity_value": ent2_ann["value"], - "end_entity_start_idx": ent2_ann["start"], - "end_entity_end_idx": ent2_ann["end"], - "relation": "Other", - "validated": True - }) - - non_rel_sample_size_limit = int(int( - self.config.general.mct_export_max_non_rel_sample_size) / len(data['projects'])) - - if non_rel_sample_size_limit > 0 and len(_other_rel_subset) > 0: - random.shuffle(_other_rel_subset) - _other_rel_subset = _other_rel_subset[0:non_rel_sample_size_limit] - - relations.extend(_other_rel_subset) + ent2_types = list(self.cdb.cui2type_ids.get(ent2_ann["cui"], "")) + + if ent1_ann["validated"] and ent2_ann["validated"]: + _relation_type = "Other" + + # create new Other subclass class if enabled + if self.config.general.create_addl_rels_by_type: + _relation_type = "Other" + ent1_types[0] + "-" + ent2_types[0] + + _other_relations_subset.append({ + "start_entity": ent1_ann["id"], + "start_entity_cui": ent1_ann["cui"], + "start_entity_value": ent1_ann["value"], + "start_entity_start_idx": ent1_ann["start"], + "start_entity_end_idx": ent1_ann["end"], + "end_entity": ent2_ann["id"], + "end_entity_cui": ent2_ann["cui"], + "end_entity_value": 
ent2_ann["value"], + "end_entity_start_idx": ent2_ann["start"], + "end_entity_end_idx": ent2_ann["end"], + "relation": _relation_type, + "validated": True + }) + + non_rel_sample_size_limit = int(int(self.config.general.addl_rels_max_sample_size) / len(data['projects'])) + + if non_rel_sample_size_limit > 0 and len(_other_relations_subset) > 0: + random.shuffle(_other_relations_subset) + _other_relations_subset = _other_relations_subset[0:non_rel_sample_size_limit] + + relations.extend(_other_relations_subset) for relation in relations: ann_start_start_pos = relation['start_entity_start_idx'] @@ -515,103 +597,60 @@ def create_relations_from_export(self, data: Dict): start_entity_id = relation['end_entity'] end_entity_id = relation['start_entity'] - for ent1type, ent2type in enumerate(relation_type_filter_pairs): + for ent1type, ent2type in enumerate(self.config.general.relation_type_filter_pairs): if ent1type not in start_entity_types and ent2type not in end_entity_types: continue - ann_ids_from_reliations.extend( - [start_entity_id, end_entity_id]) - + ann_ids_from_relations.extend([start_entity_id, end_entity_id]) relation_label = relation['relation'].strip() - if start_entity_id != end_entity_id and relation.get('validated', True): - if abs(ann_start_start_pos - ann_end_start_pos) <= self.config.general.window_size: - - ent1_token_start_pos = [i for i in range(0, doc_length_tokens) if ann_start_start_pos - in range(tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1)][0] - - ent2_token_start_pos = [i for i in range(0, doc_length_tokens) if ann_end_start_pos - in range(tokenizer_data["offset_mapping"][i][0], tokenizer_data["offset_mapping"][i][1] + 1)][0] - - ent1_left_ent_context_token_pos_end = ent1_token_start_pos - \ - self.config.general.cntx_left - - left_context_start_char_pos = 0 - right_context_start_end_pos = len(text) - 1 - - if ent1_left_ent_context_token_pos_end < 0: - ent1_left_ent_context_token_pos_end = 0 - else: - 
left_context_start_char_pos = tokenizer_data[ - "offset_mapping"][ent1_left_ent_context_token_pos_end][0] - - ent2_right_ent_context_token_pos_end = ent2_token_start_pos + \ - self.config.general.cntx_right - if ent2_right_ent_context_token_pos_end >= doc_length_tokens - 1: - ent2_right_ent_context_token_pos_end = doc_length_tokens - 2 - else: - right_context_start_end_pos = tokenizer_data[ - "offset_mapping"][ent2_right_ent_context_token_pos_end][1] - - tmp_text = text - # check if a tag is present, and if not so then insert the custom annotation tags in - if self.config.general.annotation_schema_tag_ids[0] not in tokenizer_data["input_ids"]: - _pre_e1 = text[0: (ann_start_start_pos)] - _e1_s2 = text[(ann_start_end_pos): ( - ann_end_start_pos)] - _e2_end = text[( - ann_end_end_pos): len(text)] - - tmp_text = _pre_e1 + " " + \ - annotation_token_text[0] + " " + \ - text[ann_start_start_pos:ann_start_end_pos] + " " + \ - annotation_token_text[1] + " " + \ - _e1_s2 + " " + \ - annotation_token_text[2] + " " + text[ann_end_start_pos:ann_end_end_pos] + \ - " " + \ - annotation_token_text[3] + \ - " " + _e2_end - - ann_tag_token_len = len( - annotation_token_text[0]) - - _left_context_start_char_pos = left_context_start_char_pos - ann_tag_token_len - 2 - left_context_start_char_pos = 0 if _left_context_start_char_pos <= 0 \ - else _left_context_start_char_pos - - _right_context_start_end_pos = right_context_start_end_pos + \ - (ann_tag_token_len * 4) + \ - 8 # 8 for spces - right_context_start_end_pos = len(tmp_text) if right_context_start_end_pos >= len(tmp_text) or _right_context_start_end_pos >= len(tmp_text) \ - else _right_context_start_end_pos - - window_tokenizer_data = self.tokenizer( - tmp_text[left_context_start_char_pos:right_context_start_end_pos]) - - if self.config.general.annotation_schema_tag_ids: - ent1_token_start_pos = \ - window_tokenizer_data["input_ids"].index( - self.config.general.annotation_schema_tag_ids[0]) - ent2_token_start_pos = \ - 
window_tokenizer_data["input_ids"].index( - self.config.general.annotation_schema_tag_ids[2]) - else: - # update token loc to match new selection - ent2_token_start_pos = ent2_token_start_pos - ent1_token_start_pos - ent1_token_start_pos = self.config.general.cntx_left if ent1_token_start_pos - \ - self.config.general.cntx_left > 0 else ent1_token_start_pos - ent2_token_start_pos += ent1_token_start_pos - - ent1_ent2_new_start = ( - ent1_token_start_pos, ent2_token_start_pos) - en1_start, en1_end = window_tokenizer_data[ - "offset_mapping"][ent1_token_start_pos] - en2_start, en2_end = window_tokenizer_data[ - "offset_mapping"][ent2_token_start_pos] - - relation_instances.append([window_tokenizer_data["input_ids"], ent1_ent2_new_start, start_entity_value, end_entity_value, relation_label, self.config.model.padding_idx, - start_entity_types, end_entity_types, start_entity_id, end_entity_id, start_entity_cui, end_entity_cui, doc_id, "", - en1_start, en1_end, en2_start, en2_end]) + try: + + ent1_token_start_pos = [i for i in range(0, doc_token_length) if ann_start_start_pos + in range(tokenizer_text_data["offset_mapping"][i][0], tokenizer_text_data["offset_mapping"][i][1] + 1)][0] + + ent2_token_start_pos = [i for i in range(0, doc_token_length) if ann_end_start_pos + in range(tokenizer_text_data["offset_mapping"][i][0], tokenizer_text_data["offset_mapping"][i][1] + 1)][0] + + ent1_token_end_pos = [i for i in range(0, doc_token_length) if ann_start_end_pos + in range(tokenizer_text_data["offset_mapping"][i][0], tokenizer_text_data["offset_mapping"][i][1] + 1)][0] + + ent2_token_end_pos = [i for i in range(0, doc_token_length) if ann_end_end_pos + in range(tokenizer_text_data["offset_mapping"][i][0], tokenizer_text_data["offset_mapping"][i][1] + 1)][0] + assert ent1_token_start_pos + assert ent2_token_start_pos + assert ent1_token_end_pos + assert ent2_token_end_pos + except Exception: + self.log.info("document id: " + str(doc_id) + " failed to process relation") + 
continue + + if start_entity_id != end_entity_id and relation.get("validated", True) and start_entity_value not in self.tokenizer.hf_tokenizers.all_special_tokens \ + and end_entity_value not in self.tokenizer.hf_tokenizers.all_special_tokens: + final_relation = self._create_relation_validation(text=doc_text, + doc_id=doc_id, + tokenized_text_data=tokenizer_text_data, + ent1_start_char_pos=ann_start_start_pos, + ent2_start_char_pos=ann_end_start_pos, + ent1_end_char_pos=ann_start_end_pos, + ent2_end_char_pos=ann_end_end_pos, + ent1_token_start_pos=ent1_token_start_pos, + ent2_token_start_pos=ent2_token_start_pos, + ent1_token_end_pos=ent1_token_end_pos, + ent2_token_end_pos=ent2_token_end_pos, + is_mct_export=True + ) + + if len(final_relation) > 0: + final_relation[4] = relation_label + final_relation[6] = start_entity_types + final_relation[7] = end_entity_types + final_relation[8] = start_entity_id + final_relation[9] = end_entity_id + final_relation[10] = start_entity_cui + final_relation[11] = end_entity_cui + + relation_instances.append(final_relation) output_relations.extend(relation_instances) @@ -622,15 +661,18 @@ def create_relations_from_export(self, data: Dict): # replace label_id with actual detected label number for idx in range(len(output_relations)): - output_relations[idx][5] = labels2idx[output_relations[idx][4]] + output_relations[idx][5] = int(labels2idx[output_relations[idx][4]]) self.log.info("MCT export dataset | nclasses: " + str(nclasses) + " | idx2label: " + str(idx2label)) self.log.info("Samples per class: ") + + self.log.error(str(idx2label)) + for label_num in list(idx2label.keys()): sample_count = 0 for output_relation in output_relations: - if int(label_num) == int(output_relation[5]): + if idx2label[label_num] == output_relation[4]: sample_count += 1 self.log.info( " label: " + idx2label[label_num] + " | samples: " + str(sample_count)) @@ -638,7 +680,7 @@ def create_relations_from_export(self, data: Dict): return {"output_relations": 
output_relations, "nclasses": nclasses, "labels2idx": labels2idx, "idx2label": idx2label} @classmethod - def get_labels(cls, relation_labels: List[str], config: ConfigRelCAT) -> Tuple[int, Dict[str, Any], Dict[int, Any]]: + def get_labels(cls, relation_labels: List[str], config: ConfigRelCAT) -> Tuple[int, Dict[str, int], Dict[int, str]]: """ This is used to update labels in config with unencountered classes/labels ( if any are encountered during training). Args: @@ -646,12 +688,12 @@ def get_labels(cls, relation_labels: List[str], config: ConfigRelCAT) -> Tuple[i config (ConfigRelCAT): config Returns: - Any: _description_ + Tuple[int, Dict[str, int], Dict[int, str]]: label count, labels2idx mapping, idx2labels mapping """ curr_class_id = 0 - config_labels2idx: Dict = config.general.labels2idx - config_idx2labels: Dict = config.general.idx2labels + config_labels2idx: Dict[str, int] = config.general.labels2idx + config_idx2labels: Dict[int, str] = config.general.idx2labels relation_labels = [relation_label.strip() for relation_label in relation_labels] @@ -660,8 +702,8 @@ def get_labels(cls, relation_labels: List[str], config: ConfigRelCAT) -> Tuple[i if relation_label not in config_labels2idx.keys(): while curr_class_id in [int(label_idx) for label_idx in config_idx2labels.keys()]: curr_class_id += 1 - config_labels2idx[relation_label] = curr_class_id - config_idx2labels[curr_class_id] = relation_label + config_labels2idx[relation_label] = int(curr_class_id) + config_idx2labels[int(curr_class_id)] = relation_label return len(config_labels2idx.keys()), config_labels2idx, config_idx2labels, diff --git a/medcat/utils/relation_extraction/tokenizer.py b/medcat/utils/relation_extraction/tokenizer.py index 2256993e4..0c03cc490 100644 --- a/medcat/utils/relation_extraction/tokenizer.py +++ b/medcat/utils/relation_extraction/tokenizer.py @@ -1,27 +1,37 @@ import os from typing import Optional from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast 
+from transformers import PreTrainedTokenizerFast +import logging +from medcat.config_rel_cat import ConfigRelCAT -class TokenizerWrapperBERT(BertTokenizerFast): - ''' Wrapper around a huggingface BERT tokenizer so that it works with the - RelCAT models. - Args: - hf_tokenizers (`transformers.models.bert.tokenization_bert_fast.BertTokenizerFast`): - A huggingface Fast BERT. - ''' - name = 'bert-tokenizer' +logger = logging.getLogger(__name__) + + +class BaseTokenizerWrapper_RelationExtraction(PreTrainedTokenizerFast): + + name = "base_tokenizer_wrapper_rel" def __init__(self, hf_tokenizers=None, max_seq_length: Optional[int] = None, add_special_tokens: Optional[bool] = False): self.hf_tokenizers = hf_tokenizers self.max_seq_length = max_seq_length - self.add_special_tokens = add_special_tokens + self._add_special_tokens = add_special_tokens + + def get_size(self): + return len(self.hf_tokenizers.vocab) + + def token_to_id(self, token): + return self.hf_tokenizers.convert_tokens_to_ids(token) + + def get_pad_id(self): + return self.hf_tokenizers.pad_token_id def __call__(self, text, truncation: Optional[bool] = True): if isinstance(text, str): result = self.hf_tokenizers.encode_plus(text, return_offsets_mapping=True, return_length=True, return_token_type_ids=True, return_attention_mask=True, - add_special_tokens=self.add_special_tokens, max_length=self.max_seq_length, padding="longest", truncation=truncation) + add_special_tokens=self._add_special_tokens, max_length=self.max_seq_length, padding="longest", truncation=truncation) return {'offset_mapping': result['offset_mapping'], 'input_ids': result['input_ids'], @@ -32,7 +42,7 @@ def __call__(self, text, truncation: Optional[bool] = True): } elif isinstance(text, list): results = self.hf_tokenizers._batch_encode_plus(text, return_offsets_mapping=True, return_length=True, return_token_type_ids=True, - add_special_tokens=self.add_special_tokens, max_length=self.max_seq_length,truncation=truncation) + 
add_special_tokens=self._add_special_tokens, max_length=self.max_seq_length,truncation=truncation) output = [] for ind in range(len(results['input_ids'])): output.append({ @@ -48,24 +58,41 @@ def __call__(self, text, truncation: Optional[bool] = True): raise Exception( "Unsupported input type, supported: text/list, but got: {}".format(type(text))) - def save(self, dir_path): + def save(self, dir_path: str): path = os.path.join(dir_path, self.name) self.hf_tokenizers.save_pretrained(path) @classmethod - def load(cls, dir_path, **kwargs): - tokenizer = cls() - path = os.path.join(dir_path, cls.name) - tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained( - path, **kwargs) + def load(cls, tokenizer_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "BaseTokenizerWrapper_RelationExtraction": - return tokenizer - - def get_size(self): - return len(self.hf_tokenizers.vocab) + tokenizer = BaseTokenizerWrapper_RelationExtraction() - def token_to_id(self, token): - return self.hf_tokenizers.convert_tokens_to_ids(token) - - def get_pad_id(self): - return self.hf_tokenizers.pad_token_id + if os.path.exists(tokenizer_path): + if "modern-bert" in relcat_config.general.tokenizer_name: + from medcat.utils.relation_extraction.modernbert.tokenizer import TokenizerWrapperModernBERT_RelationExtraction + tokenizer = TokenizerWrapperModernBERT_RelationExtraction.load(tokenizer_path, relcat_config=relcat_config, **kwargs) + elif "bert" in relcat_config.general.tokenizer_name: + from medcat.utils.relation_extraction.bert.tokenizer import TokenizerWrapperBERT_RelationExtraction + tokenizer = TokenizerWrapperBERT_RelationExtraction.load(tokenizer_path, relcat_config=relcat_config, **kwargs) + elif "llama" in relcat_config.general.tokenizer_name: + from medcat.utils.relation_extraction.llama.tokenizer import TokenizerWrapperLlama_RelationExtraction + tokenizer = TokenizerWrapperLlama_RelationExtraction.load(tokenizer_path, relcat_config=relcat_config, **kwargs) + 
logger.info("Tokenizer loaded " + str(tokenizer.__class__.__name__) + " from:" + tokenizer_path) + elif relcat_config.general.model_name: + logger.info("Attempted to load Tokenizer from path:" + tokenizer_path + + ", but it doesn't exist, loading default tokenizer from model_name relcat_config.general.model_name:" + + relcat_config.general.model_name) + from medcat.utils.relation_extraction.bert.tokenizer import TokenizerWrapperBERT_RelationExtraction + from medcat.utils.relation_extraction.ml_utils import create_tokenizer_pretrain + logger.info("Adding special tokens to tokenizer:" + str(relcat_config.general.tokenizer_relation_annotation_special_tokens_tags) + " " + + str(relcat_config.general.tokenizer_other_special_tokens)) + tokenizer = TokenizerWrapperBERT_RelationExtraction(BertTokenizerFast.from_pretrained(relcat_config.general.model_name), add_special_tokens=True) + tokenizer = create_tokenizer_pretrain(tokenizer, relcat_config=relcat_config) + else: + logger.info("Attempted to load Tokenizer from path:" + tokenizer_path + + ", but it doesn't exist, loading default tokenizer from model_name config.general.model_name:bert-base-uncased") + from medcat.utils.relation_extraction.bert.tokenizer import TokenizerWrapperBERT_RelationExtraction + tokenizer = TokenizerWrapperBERT_RelationExtraction(BertTokenizerFast.from_pretrained(relcat_config.general.model_name), + max_seq_length=relcat_config.general.max_seq_length, + add_special_tokens=relcat_config.general.tokenizer_special_tokens) + return tokenizer diff --git a/medcat/utils/relation_extraction/utils.py b/medcat/utils/relation_extraction/utils.py deleted file mode 100644 index 544905ca4..000000000 --- a/medcat/utils/relation_extraction/utils.py +++ /dev/null @@ -1,277 +0,0 @@ -import os -import pickle -from typing import Any, Dict, List, Tuple -import numpy as np -import torch -import logging -import random - -from pandas.core.series import Series -from medcat.config_rel_cat import ConfigRelCAT - -from 
medcat.preprocessing.tokenizers import TokenizerWrapperBERT -from medcat.utils.relation_extraction.models import BertModel_RelationExtraction - - -def split_list_train_test_by_class(data: List, test_size: float = 0.2, shuffle: bool = True) -> Tuple[List, List]: - """ - - Args: - data (List): "output_relations": relation_instances, <-- see create_base_relations_from_doc/csv - for data columns - test_size (float): Defaults to 0.2. - shuffle (bool): shuffle data randomly. Defaults to True. - - Returns: - Tuple[List, List]: train and test datasets - """ - - if shuffle: - random.shuffle(data) - - train_data = [] - test_data = [] - - row_id_labels = {row_idx: data[row_idx][5] for row_idx in range(len(data))} - count_per_label = {lbl: list(row_id_labels.values()).count( - lbl) for lbl in set(row_id_labels.values())} - - for lbl_id, count in count_per_label.items(): - _test_records_size = int(count * test_size) - tmp_count = 0 - if _test_records_size not in [0, 1]: - for row_idx, _lbl_id in row_id_labels.items(): - if _lbl_id == lbl_id: - if tmp_count < _test_records_size: - test_data.append(data[row_idx]) - tmp_count += 1 - else: - train_data.append(data[row_idx]) - else: - for row_idx, _lbl_id in row_id_labels.items(): - if _lbl_id == lbl_id: - train_data.append(data[row_idx]) - test_data.append(data[row_idx]) - - return train_data, test_data - - -def load_bin_file(file_name, path="./") -> Any: - with open(os.path.join(path, file_name), 'rb') as f: - data = pickle.load(f) - return data - - -def save_bin_file(file_name, data, path="./"): - with open(os.path.join(path, file_name), "wb") as f: - pickle.dump(data, f) - - -def save_state(model: BertModel_RelationExtraction, optimizer: torch.optim.Adam, scheduler: torch.optim.lr_scheduler.MultiStepLR, epoch:int = 1, best_f1:float = 0.0, path:str = "./", model_name: str = "BERT", task:str = "train", is_checkpoint=False, final_export=False) -> None: - """ Used by RelCAT.save() and RelCAT.train() - Saves the RelCAT model state. 
- For checkpointing multiple files are created, best_f1, loss etc. score. - If you want to export the model after training set final_export=True and leave is_checkpoint=False. - - Args: - model (BertModel_RelationExtraction): model - optimizer (torch.optim.Adam, optional): Defaults to None. - scheduler (torch.optim.lr_scheduler.MultiStepLR, optional): Defaults to None. - epoch (int): Defaults to None. - best_f1 (float): Defaults to None. - path (str):Defaults to "./". - model_name (str): . Defaults to "BERT". This is used to checkpointing only. - task (str): Defaults to "train". This is used to checkpointing only. - is_checkpoint (bool): Defaults to False. - final_export (bool): Defaults to False, if True then is_checkpoint must be False also. Exports model.state_dict(), out into"model.dat". - """ - - model_name = model_name.replace("/", "_") - file_name = "%s_checkpoint_%s.dat" % (task, model_name) - - if not is_checkpoint: - file_name = "%s_best_%s.dat" % (task, model_name) - if final_export: - file_name = "model.dat" - torch.save(model.state_dict(), os.path.join(path, file_name)) - - if is_checkpoint: - torch.save({ - 'epoch': epoch, - 'state_dict': model.state_dict(), - 'best_f1': best_f1, - 'optimizer': optimizer.state_dict(), - 'scheduler': scheduler.state_dict() - }, os.path.join(path, file_name)) - - -def load_state(model: BertModel_RelationExtraction, optimizer, scheduler, path="./", model_name="BERT", file_prefix="train", load_best=False, device: torch.device =torch.device("cpu"), config: ConfigRelCAT = ConfigRelCAT()) -> Tuple[int, int]: - """ Used by RelCAT.load() and RelCAT.train() - - Args: - model (BertModel_RelationExtraction): model, it has to be initialized before calling this method via BertModel_RelationExtraction(...) - optimizer (_type_): optimizer - scheduler (_type_): scheduler - path (str, optional): Defaults to "./". - model_name (str, optional): Defaults to "BERT". - file_prefix (str, optional): Defaults to "train". 
- load_best (bool, optional): Defaults to False. - device (torch.device, optional): Defaults to torch.device("cpu"). - config (ConfigRelCAT): Defaults to ConfigRelCAT(). - - Returns: - Tuple (int, int): last epoch and f1 score. - """ - - model_name = model_name.replace("/", "_") - logging.info("Attempting to load RelCAT model on device: " + str(device)) - checkpoint_path = os.path.join( - path, file_prefix + "_checkpoint_%s.dat" % model_name) - best_path = os.path.join( - path, file_prefix + "_best_%s.dat" % model_name) - start_epoch, best_f1, checkpoint = 0, 0, None - - if load_best is True and os.path.isfile(best_path): - checkpoint = torch.load(best_path, map_location=device) - logging.info("Loaded best model.") - elif os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location=device) - logging.info("Loaded checkpoint model.") - - if checkpoint is not None: - start_epoch = checkpoint['epoch'] - best_f1 = checkpoint['best_f1'] - model.load_state_dict(checkpoint['state_dict']) - model.to(device) - - if optimizer is None: - optimizer = torch.optim.Adam( - [{"params": model.module.parameters(), "lr": config.train.lr}]) - - if scheduler is None: - scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, - milestones=config.train.multistep_milestones, - gamma=config.train.multistep_lr_gamma) - optimizer.load_state_dict(checkpoint['optimizer']) - scheduler.load_state_dict(checkpoint['scheduler']) - logging.info("Loaded model and optimizer.") - - return start_epoch, best_f1 - - -def save_results(data, model_name: str = "BERT", path: str = "./", file_prefix: str = "train"): - save_bin_file(file_prefix + "_losses_accuracy_f1_per_epoch_%s.dat" % - model_name, data, path) - - -def load_results(path, model_name: str = "BERT", file_prefix: str = "train") -> Tuple[List, List, List]: - data_dict_path = os.path.join( - path, file_prefix + "_losses_accuracy_f1_per_epoch_%s.dat" % model_name) - - data_dict: Dict = {"losses_per_epoch": [], - 
"accuracy_per_epoch": [], "f1_per_epoch": []} - if os.path.isfile(data_dict_path): - data_dict = load_bin_file(data_dict_path) - - return data_dict["losses_per_epoch"], data_dict["accuracy_per_epoch"], data_dict["f1_per_epoch"] - - -def put_blanks(relation_data: List, blanking_threshold: float = 0.5) -> List: - """ - Args: - relation_data (List): tuple containing token (sentence_token_span , ent1 , ent2) - Puts blanks randomly in the relation. Used for pre-training. - blanking_threshold (float): % threshold to blank token ids. Defaults to 0.5. - - Returns: - List: data - """ - - blank_ent1 = np.random.uniform() - blank_ent2 = np.random.uniform() - - blanked_relation = relation_data - - sentence_token_span, ent1, ent2, label, label_id, ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id = ( - *relation_data, ) - - if blank_ent1 >= blanking_threshold: - blanked_relation = [sentence_token_span, "[BLANK]", ent2, label, label_id, - ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id] - - if blank_ent2 >= blanking_threshold: - blanked_relation = [sentence_token_span, ent1, "[BLANK]", label, label_id, - ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id] - - return blanked_relation - - -def create_tokenizer_pretrain(tokenizer: TokenizerWrapperBERT, tokenizer_path: str): - """ - This method simply adds special tokens that we encounter - - Args: - tokenizer (TokenizerWrapperBERT): BERT tokenizer. - tokenizer_path (str): path where tokenizer is to be saved. 
- """ - - - tokenizer.hf_tokenizers.add_tokens( - ["[BLANK]", "[ENT1]", "[ENT2]", "[/ENT1]", "[/ENT2]"], special_tokens=True) - tokenizer.hf_tokenizers.add_tokens( - ["[s1]", "[e1]", "[s2]", "[e2]"], special_tokens=True) - tokenizer.save(tokenizer_path) - - -# Used for creating data sets for pretraining -def tokenize(relations_dataset: Series, tokenizer: TokenizerWrapperBERT, mask_probability: float = 0.5) -> Tuple: - (tokens, span_1_pos, span_2_pos), ent1_text, ent2_text, label, label_id, ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id = relations_dataset - - cls_token = tokenizer.hf_tokenizers.cls_token - sep_token = tokenizer.hf_tokenizers.sep_token - - tokens = [token.lower() for token in tokens if tokens != '[BLANK]'] - - forbidden_indices = [i for i in range( - span_1_pos[0], span_1_pos[1])] + [i for i in range(span_2_pos[0], span_2_pos[1])] - - pool_indices = [i for i in range( - len(tokens)) if i not in forbidden_indices] - - masked_indices = np.random.choice(pool_indices, - size=round(mask_probability * - len(pool_indices)), - replace=False) - - masked_for_pred = [token.lower() for idx, token in enumerate( - tokens) if (idx in masked_indices)] - - tokens = [token if (idx not in masked_indices) - else tokenizer.hf_tokenizers.mask_token for idx, token in enumerate(tokens)] - - if (ent1_text == "[BLANK]") and (ent2_text != "[BLANK]"): - tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]", "[BLANK]", "[/ENT1]"] + \ - tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]"] + tokens[span_2_pos[0]:span_2_pos[1]] + ["[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token] - - elif (ent1_text == "[BLANK]") and (ent2_text == "[BLANK]"): - tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]", "[BLANK]", "[/ENT1]"] + \ - tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]", "[BLANK]", - "[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token] - - elif (ent1_text != "[BLANK]") and (ent2_text == "[BLANK]"): - tokens = [cls_token] + tokens[:span_1_pos[0]] + 
["[ENT1]"] + tokens[span_1_pos[0]:span_1_pos[1]] + ["[/ENT1]"] + \ - tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]", "[BLANK]", - "[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token] - - elif (ent1_text != "[BLANK]") and (ent2_text != "[BLANK]"): - tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]"] + tokens[span_1_pos[0]:span_1_pos[1]] + ["[/ENT1]"] + \ - tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]"] + tokens[span_2_pos[0]:span_2_pos[1]] + ["[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token] - - ent1_ent2_start = ([i for i, e in enumerate(tokens) if e == "[ENT1]"][0], [ - i for i, e in enumerate(tokens) if e == "[ENT2]"][0]) - - token_ids = tokenizer.hf_tokenizers.convert_tokens_to_ids(tokens) - masked_for_pred = tokenizer.hf_tokenizers.convert_tokens_to_ids( - masked_for_pred) - - return token_ids, masked_for_pred, ent1_ent2_start diff --git a/medcat/vocab.py b/medcat/vocab.py index 88350c945..04528fd40 100644 --- a/medcat/vocab.py +++ b/medcat/vocab.py @@ -1,6 +1,6 @@ import numpy as np import pickle -from typing import Optional, List, Dict +from typing import Optional, List, Dict, cast import logging @@ -190,8 +190,19 @@ def make_unigram_table(self, table_size: int = -1) -> None: "the creation of a massive array. 
So therefore, there " "is no need to pass the `table_size` parameter anymore.") freqs = [] - for word in self.vec_index2word.values(): + # index list maps the slot in which a word index + # sits in vec_index2word to the actual index for said word + # e.g: + # if we have words indexed 0, 1, and 2 + # but only 0, and 2 have corresponding vectors + # then only 0 and 2 will occur in vec_index2word + # and while 0 will be in the 0th position (as expected) + # in the final probability list, 2 will be in 1st position + # so we need to mark that conversion down + index_list = [] + for word_index, word in self.vec_index2word.items(): freqs.append(self[word]) + index_list.append(word_index) # Power and normalize frequencies freqs = np.array(freqs) ** (3/4) @@ -199,6 +210,8 @@ def make_unigram_table(self, table_size: int = -1) -> None: # Calculate cumulative probabilities self.cum_probs = np.cumsum(freqs) + # the mapping from vector index order to word indices + self._index_list = index_list def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) -> List[int]: """Get N negative samples. @@ -216,7 +229,11 @@ def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) - if len(self.cum_probs) == 0: self.make_unigram_table() random_vals = np.random.rand(n) - inds = np.searchsorted(self.cum_probs, random_vals).tolist() + # NOTE: These indices are in terms of the cum_probs array + # which only has word data for words with vectors. + vec_slots = cast(List[int], np.searchsorted(self.cum_probs, random_vals).tolist()) + # so we need to translate these back to word indices + inds = list(map(self._index_list.__getitem__, vec_slots)) if ignore_punct_and_num: # Do not return anything that does not have letters in it diff --git a/requirements-dev.txt b/requirements-dev.txt index 6b954afc9..fe487b560 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ . 
-https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.6.0/en_core_web_md-3.6.0-py3-none-any.whl +https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.7.1/en_core_web_md-3.7.1-py3-none-any.whl flake8~=7.0.0 darglint~=1.8.1 mypy>=1.7.0,<1.12.0 diff --git a/requirements.txt b/requirements.txt index 45842566e..1e2b5efc0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ . -https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.6.0/en_core_web_md-3.6.0-py3-none-any.whl +https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.7.2/en_core_web_md-3.7.2-py3-none-any.whl diff --git a/setup.py b/setup.py index 08440b9ec..f46ce23ea 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], diff --git a/tests/ner/test_transformers_ner.py b/tests/ner/test_transformers_ner.py index de9eae32c..14579711c 100644 --- a/tests/ner/test_transformers_ner.py +++ b/tests/ner/test_transformers_ner.py @@ -48,3 +48,20 @@ def on_epoch_end(self, *args, **kwargs) -> None: assert dataset["train"].num_rows == 48 assert dataset["test"].num_rows == 12 self.assertEqual(tracker.call.call_count, 2) + + def test_expand_model_with_concepts(self): + original_num_labels = self.undertest.model.num_labels + original_out_features = self.undertest.model.classifier.out_features + original_label_map_size = len(self.undertest.tokenizer.label_map) + cui2preferred_name = { + "concept_1" : "Preferred Name 1", + "concept_2" : "Preferred Name 2", + } + + self.undertest.expand_model_with_concepts(cui2preferred_name) + + assert self.undertest.model.num_labels == original_num_labels + len(cui2preferred_name) + assert self.undertest.model.classifier.out_features == 
original_out_features + len(cui2preferred_name) + assert len(self.undertest.tokenizer.label_map) == original_label_map_size + len(cui2preferred_name) + assert self.undertest.tokenizer.cui2name.get("concept_1") == "Preferred Name 1" + assert self.undertest.tokenizer.cui2name.get("concept_2") == "Preferred Name 2" diff --git a/tests/resources/regression/creation/vocab_data.txt b/tests/resources/regression/creation/vocab_data.txt index 036e651a8..0a3ca8053 100644 --- a/tests/resources/regression/creation/vocab_data.txt +++ b/tests/resources/regression/creation/vocab_data.txt @@ -1,7 +1,7 @@ severe 10000 1.0 0 0 1 0 0 0 minor 10000 -1.0 0 0 1 0 0 0 acute 6500 0 1.0 0 1 0 0 0 -chronic 6500 0 -1.0 0 0 1 0 0 0 +chronic 6500 0 -1.0 0 0 1 0 0 heavy 4000 0 0 1.0 1 0 0 0 light 4000 0 0 -1.0 1 0 0 0 considered 1000 0.1 -0.2 0 0 0.9 0 0 @@ -15,4 +15,4 @@ are 12000 0 0 0 0 1.1 0 0 has 11000 0 0 0 0 0.98 0 0 presence 1000 0 0 0 0 0 0 0.4 indication 500 0 0 0 0 0 0 0.3 -time 450 0 0 0 0 0 0 0.1 \ No newline at end of file +time 450 0 0 0 0 0 0 0.1 diff --git a/tests/resources/regression/run_regression.sh b/tests/resources/regression/run_regression.sh index 0e9911434..1b3bcfe0f 100644 --- a/tests/resources/regression/run_regression.sh +++ b/tests/resources/regression/run_regression.sh @@ -14,6 +14,10 @@ echo "$output" model_path=$(echo "$output" | tail -n 1) # NOTE: this file should be tagged with the python version we're using +# test the vocab to make sure it's all good +python tests/resources/regression/test_vocab.py +# TODO: test other things as well? 
+ # run the regression_checker with the captured file path # if any of the regression cases fail, this will return a non-zero exit status python -m medcat.utils.regression.regression_checker \ diff --git a/tests/resources/regression/test_vocab.py b/tests/resources/regression/test_vocab.py new file mode 100644 index 000000000..47e400dcb --- /dev/null +++ b/tests/resources/regression/test_vocab.py @@ -0,0 +1,31 @@ +import os + +from medcat.vocab import Vocab + +import unittest + + +class RegressionModelVocabTests(unittest.TestCase): + VOCAB_DATA_PATH = os.path.join(os.path.dirname(__file__), + 'creation', 'vocab_data.txt') + + @classmethod + def setUpClass(cls): + cls.vocab = Vocab() + cls.vocab.add_words(cls.VOCAB_DATA_PATH) + + def test_has_same_vector_lengths(self): + all_lengths = set() + for w in self.vocab.vec_index2word.values(): + all_lengths.add(len(self.vocab.vec(w))) + self.assertEqual(len(all_lengths), 1, f"Expected equal lengths. Got: {all_lengths}") + + def test_all_words_have_vectors(self): + for w in self.vocab.vocab: + with self.subTest(f"Word: {repr(w)}"): + # NOTE: if not there, will raise an exception + self.assertIsNotNone(self.vocab.vec(w)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_cat.py b/tests/test_cat.py index 17cdd2819..432073822 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -383,6 +383,27 @@ def test_train_supervised_can_retain_MCT_filters(self, extra_cui_filter=None, re with self.subTest(f'CUI: {filtered_cui}'): self.assertTrue(filtered_cui in self.undertest.config.linking.filters.cuis) + def _test_train_sup_with_meta_cat(self, train_meta_cats: bool): + # def side_effect(doc, *args, **kwargs): + # raise ValueError() + # # return doc + meta_cat = _get_meta_cat(self.meta_cat_dir) + cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat]) + with patch.object(MetaCAT, "train_raw") as mock_train: + with patch.object(MetaCAT, "__call__", side_effect=lambda doc: doc): + 
cat.train_supervised_raw(get_fixed_meta_cat_data(), never_terminate=True, + train_meta_cats=train_meta_cats) + if train_meta_cats: + mock_train.assert_called() + else: + mock_train.assert_not_called() + + def test_train_supervised_does_not_train_meta_cat_by_default(self): + self._test_train_sup_with_meta_cat(False) + + def test_train_supervised_can_train_meta_cats(self): + self._test_train_sup_with_meta_cat(True) + def test_train_supervised_no_leak_extra_cui_filters(self): self.test_train_supervised_does_not_retain_MCT_filters_default(extra_cui_filter={'C123', 'C111'}) @@ -799,6 +820,9 @@ def test_loading_model_pack_without_any_config_raises_exception(self): CAT.load_model_pack(self.temp_dir.name) +META_CAT_JSON_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json") + + def _get_meta_cat(meta_cat_dir): config = ConfigMetaCAT() config.general["category_name"] = "Status" @@ -808,11 +832,31 @@ def _get_meta_cat(meta_cat_dir): embeddings=None, config=config) os.makedirs(meta_cat_dir, exist_ok=True) - json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json") + json_path = META_CAT_JSON_PATH meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir) return meta_cat +def get_fixed_meta_cat_data(path: str = META_CAT_JSON_PATH): + with open(path) as f: + data = json.load(f) + for proj_num, project in enumerate(data['projects']): + if 'name' not in project: + project['name'] = f"Proj_{proj_num}" + if 'cuis' not in project: + project['cuis'] = '' + if 'id' not in project: + project['id'] = f'P{proj_num}' + for doc in project['documents']: + if 'entities' in doc and 'annotations' not in doc: + ents = doc.pop("entities") + doc['annotations'] = list(ents.values()) + for ann in doc['annotations']: + if 'pretty_name' in ann and 'value' not in ann: + ann['value'] = ann.pop('pretty_name') + return data + + class TestLoadingOldWeights(unittest.TestCase): 
cdb_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb_old_broken_weights_in_config.dat") diff --git a/tests/test_pipe.py b/tests/test_pipe.py index e626d139c..3abfc35b9 100644 --- a/tests/test_pipe.py +++ b/tests/test_pipe.py @@ -13,7 +13,8 @@ from medcat.ner.vocab_based_ner import NER from medcat.linking.context_based_linker import Linker from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT -from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT as RelTokenizerWrapperBERT +from medcat.utils.relation_extraction.bert.tokenizer import ( + BaseTokenizerWrapper_RelationExtraction as RelTokenizerWrapperBERT) from transformers import AutoTokenizer @@ -43,9 +44,10 @@ def setUpClass(cls) -> None: cls.linker = Linker(cls.cdb, cls.vocab, cls.config) _tokenizer = TokenizerWrapperBERT(hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased")) - _tokenizer_rel = RelTokenizerWrapperBERT(hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased")) + # _tokenizer_rel = RelTokenizerWrapperBERT(hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased")) cls.meta_cat = MetaCAT(tokenizer=_tokenizer) - cls.rel_cat = RelCAT(cls.cdb, tokenizer=_tokenizer_rel, init_model=True) + cls.rel_cat = RelCAT(cls.cdb, # tokenizer=_tokenizer_rel, + init_model=True) cls.text = "stop of CDB - I was running and then Movar Virus attacked and CDb" cls.undertest = Pipe(tokenizer=spacy_split_all, config=cls.config) diff --git a/tests/test_rel_cat.py b/tests/test_rel_cat.py index 8f6db4261..916e22b95 100644 --- a/tests/test_rel_cat.py +++ b/tests/test_rel_cat.py @@ -2,20 +2,20 @@ import shutil import unittest import json +import logging from medcat.cdb import CDB from medcat.config_rel_cat import ConfigRelCAT from medcat.rel_cat import RelCAT +from medcat.utils.relation_extraction.bert.tokenizer import BaseTokenizerWrapper_RelationExtraction from medcat.utils.relation_extraction.rel_dataset import RelData -from 
medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT -from medcat.utils.relation_extraction.models import BertModel_RelationExtraction from transformers.models.auto.tokenization_auto import AutoTokenizer -from transformers.models.bert.configuration_bert import BertConfig import spacy from spacy.tokens import Span, Doc + class RelCATTests(unittest.TestCase): @classmethod @@ -27,8 +27,9 @@ def setUpClass(cls) -> None: config.train.nclasses = 3 config.model.hidden_size= 256 config.model.model_size = 2304 + config.general.log_level = logging.DEBUG - tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained( + tokenizer = BaseTokenizerWrapper_RelationExtraction(AutoTokenizer.from_pretrained( pretrained_model_name_or_path=config.general.model_name, config=config), add_special_tokens=True) @@ -54,26 +55,50 @@ def setUpClass(cls) -> None: cls.mct_file_test = json.loads(f.read())["projects"][0]["documents"][1] cls.config_rel_cat: ConfigRelCAT = config - cls.rel_cat: RelCAT = RelCAT(cdb, tokenizer=tokenizer, config=config, init_model=True) + cls.rel_cat: RelCAT = RelCAT(cdb, config=config, init_model=True) - cls.rel_cat.model.bert_model.resize_token_embeddings(len(tokenizer.hf_tokenizers)) + cls.rel_cat.component.model.hf_model.resize_token_embeddings(len(tokenizer.hf_tokenizers)) + cls.rel_cat.component.model_config.hf_model_config.vocab_size = tokenizer.get_size() cls.finished = False cls.tokenizer = tokenizer + def test_dataset_relation_parser(self) -> None: + + samples = [ + "The [s1]45-year-old male[e1] was diagnosed with [s2]hypertension[e2] during his routine check-up.", + "The patient’s [s1]chest pain[e1] was associated with [s2]shortness of breath[e2].", + "[s1]Blood pressure[e1] readings of [s2]160/90 mmHg[e2] indicated possible hypertension.", + "His elevated [s1]blood glucose[e1] level of [s2]220 mg/dL[e2] raised concerns about his diabetes management.", + "The doctor recommended a [s1]cardiac enzyme test[e1] to assess the risk of 
[s2]myocardial infarction[e2].", + "The patient’s [s1]ECG[e1] showed signs of [s2]ischemia[e2]", + "To manage his [s1]hypertension[e1], the patient was advised to [s2]reduce salt intake[e2].", + "[s1]Increased physical activity[e1][s2]type 2 diabetes[e2]." + ] + + rel_dataset = RelData(cdb=self.rel_cat.cdb, config=self.config_rel_cat, tokenizer=self.tokenizer) + + rels = [] + + for idx in range(len(samples)): + tkns = self.tokenizer(samples[idx])["tokens"] + ent1_ent2_tokens_start_pos = (tkns.index("[s1]"), tkns.index("[s2]")) + rels.append(rel_dataset.create_base_relations_from_doc(samples[idx], idx, + ent1_ent2_tokens_start_pos=ent1_ent2_tokens_start_pos)) + + self.assertEqual(len(rels), len(samples)) + def test_train_csv_no_tags(self) -> None: - self.rel_cat.config.train.epochs = 2 + self.rel_cat.component.relcat_config.train.epochs = 2 self.rel_cat.train(train_csv_path=self.medcat_rels_csv_path_train, test_csv_path=self.medcat_rels_csv_path_test, checkpoint_path=self.tmp_dir) self.rel_cat.save(self.save_model_path) def test_train_mctrainer(self) -> None: self.rel_cat = RelCAT.load(self.save_model_path) - self.rel_cat.config.general.mct_export_create_addl_rels = True - self.rel_cat.config.general.mct_export_max_non_rel_sample_size = 10 - self.rel_cat.config.train.test_size = 0.1 - self.rel_cat.config.train.nclasses = 3 - self.rel_cat.model.relcat_config.train.nclasses = 3 - self.rel_cat.model.bert_model.resize_token_embeddings(len(self.tokenizer.hf_tokenizers)) + self.rel_cat.component.relcat_config.general.create_addl_rels = True + self.rel_cat.component.relcat_config.general.addl_rels_max_sample_size = 10 + self.rel_cat.component.relcat_config.train.test_size = 0.1 + self.rel_cat.component.relcat_config.train.nclasses = 3 self.rel_cat.train(export_data_path=self.medcat_export_with_rels_path, checkpoint_path=self.tmp_dir) @@ -95,17 +120,19 @@ def test_train_predict(self) -> None: entity._.cui = ann["cui"] doc._.ents.append(entity) - 
self.rel_cat.model.bert_model.resize_token_embeddings(len(self.tokenizer.hf_tokenizers)) + self.rel_cat.component.model.hf_model.resize_token_embeddings(len(self.tokenizer.hf_tokenizers)) doc = self.rel_cat(doc) self.finished = True - assert len(doc._.relations) > 0 + self.assertGreater(len(doc._.relations), 0) + def tearDown(self) -> None: if self.finished: if os.path.exists(self.tmp_dir): shutil.rmtree(self.tmp_dir) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_transformers_ner.py b/tests/test_transformers_ner.py new file mode 100644 index 000000000..6071695d5 --- /dev/null +++ b/tests/test_transformers_ner.py @@ -0,0 +1,242 @@ +import unittest +import tempfile +import json +import os +import shutil +from medcat.cdb import CDB +from medcat.ner.transformers_ner import TransformersNER +from medcat.config_transformers_ner import ConfigTransformersNER + +class TestTransformersNER(unittest.TestCase): + def setUp(self): + # Create a temporary directory for the test + self.tmp_dir = tempfile.TemporaryDirectory() + # Create results dir for training outputs + self.results_dir = './results' + os.makedirs(self.results_dir, exist_ok=True) + + # Create a minimal CDB + self.cdb = CDB() + + # Create initial training data with 2 labels and multiple examples + self.initial_data = { + "projects": [{ + "documents": [ + { + "text": "Patient has diabetes and hypertension.", + "annotations": [ + { + "cui": "C0011849", # Diabetes + "start": 14, + "end": 22, + "value": "diabetes" + }, + { + "cui": "C0020538", # Hypertension + "start": 27, + "end": 39, + "value": "hypertension" + } + ] + }, + { + "text": "History of diabetes with hypertension.", + "annotations": [ + { + "cui": "C0011849", # Diabetes + "start": 12, + "end": 20, + "value": "diabetes" + }, + { + "cui": "C0020538", # Hypertension + "start": 26, + "end": 38, + "value": "hypertension" + } + ] + }, + { + "text": "Diagnosed with hypertension and diabetes.", + "annotations": [ + { + "cui": "C0020538", # 
Hypertension + "start": 15, + "end": 27, + "value": "hypertension" + }, + { + "cui": "C0011849", # Diabetes + "start": 32, + "end": 40, + "value": "diabetes" + } + ] + } + ] + }] + } + + # Create new training data with an extra label + self.new_data = { + "projects": [{ + "documents": [ + { + "text": "Patient has diabetes, hypertension, and asthma.", + "annotations": [ + { + "cui": "C0011849", # Diabetes + "start": 14, + "end": 22, + "value": "diabetes" + }, + { + "cui": "C0020538", # Hypertension + "start": 24, + "end": 36, + "value": "hypertension" + }, + { + "cui": "C0004096", # Asthma + "start": 42, + "end": 48, + "value": "asthma" + } + ] + }, + { + "text": "History of asthma with diabetes and hypertension.", + "annotations": [ + { + "cui": "C0004096", # Asthma + "start": 12, + "end": 18, + "value": "asthma" + }, + { + "cui": "C0011849", # Diabetes + "start": 24, + "end": 32, + "value": "diabetes" + }, + { + "cui": "C0020538", # Hypertension + "start": 37, + "end": 49, + "value": "hypertension" + } + ] + }, + { + "text": "Diagnosed with asthma, diabetes, and hypertension.", + "annotations": [ + { + "cui": "C0004096", # Asthma + "start": 15, + "end": 21, + "value": "asthma" + }, + { + "cui": "C0011849", # Diabetes + "start": 23, + "end": 31, + "value": "diabetes" + }, + { + "cui": "C0020538", # Hypertension + "start": 37, + "end": 49, + "value": "hypertension" + } + ] + } + ] + }] + } + + # Save initial training data + self.initial_data_path = os.path.join(self.tmp_dir.name, 'initial_data.json') + with open(self.initial_data_path, 'w') as f: + json.dump(self.initial_data, f) + + # Save new training data + self.new_data_path = os.path.join(self.tmp_dir.name, 'new_data.json') + with open(self.new_data_path, 'w') as f: + json.dump(self.new_data, f) + + def tearDown(self): + # Clean up the temporary directory + self.tmp_dir.cleanup() + # Clean up results directory if it exists + if os.path.exists(self.results_dir): + shutil.rmtree(self.results_dir) + # Clean up 
logs directory if it exists + if os.path.exists('./logs'): + shutil.rmtree('./logs') + + def test_ignore_extra_labels(self): + # Create and train initial model with tiny BERT + config = ConfigTransformersNER() + config.general['model_name'] = 'prajjwal1/bert-tiny' + # Set to single epoch and small test size for faster testing + config.general['num_train_epochs'] = 1 + config.general['test_size'] = 0.1 + + # Create training arguments with reduced epochs + from transformers import TrainingArguments + training_args = TrainingArguments( + output_dir=self.results_dir, # Use the class results_dir + num_train_epochs=1 + ) + + ner = TransformersNER(self.cdb, config=config, training_arguments=training_args) + ner.train(self.initial_data_path) + + # Save the model + model_path = os.path.join(self.tmp_dir.name, 'model') + ner.save(model_path) + + # Load the saved model + loaded_ner = TransformersNER.load(model_path) + + # Get initial number of labels + initial_num_labels = len(loaded_ner.tokenizer.label_map) + + # Train with ignore_extra_labels=True + loaded_ner.train(self.new_data_path, ignore_extra_labels=True) + + # Verify number of labels hasn't changed + self.assertEqual( + len(loaded_ner.tokenizer.label_map), + initial_num_labels, + "Number of labels changed despite ignore_extra_labels=True" + ) + + # Verify only original labels are present (including special tokens) + expected_labels = {"C0011849", "C0020538", "O", "X"} + self.assertEqual( + set(loaded_ner.tokenizer.label_map.keys()), + expected_labels, + "Label map contains unexpected labels" + ) + + # Train with ignore_extra_labels=False + loaded_ner.train(self.new_data_path, ignore_extra_labels=False) + + # Verify new label was added + self.assertEqual( + len(loaded_ner.tokenizer.label_map), + initial_num_labels + 1, + "New label was not added when ignore_extra_labels=False" + ) + + # Verify all labels are present (including special tokens) + expected_labels = {"C0011849", "C0020538", "C0004096", "O", "X"} + 
self.assertEqual( + set(loaded_ner.tokenizer.label_map.keys()), + expected_labels, + "Label map missing expected labels" + ) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_vocab.py b/tests/test_vocab.py index 5e4f8e25e..dd1203ad8 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -43,11 +43,16 @@ class VocabUnigramTableTests(unittest.TestCase): "..", "examples", "vocab_data.txt") UNIGRAM_TABLE_SIZE = 10_000 # found that this seed had the closest frequency at the sample size we're at - RANDOM_SEED = 4976 + RANDOM_SEED = 32 NUM_SAMPLES = 20 # NOTE: 3, 9, 18, and 27 at a time are regular due to context vector sizes NUM_TIMES = 200 - # based on the counts on vocab_data.txt and the one set in setUpClass - EXPECTED_FREQUENCIES = [0.62218692, 0.32422858, 0.0535845] + # based on the counts on vocab_data.txt and the ones set in setUpClass + # plus the power of 3/4 + EXPECTED_FREQUENCIES = { + 0: 0.61078822, 1: 0.3182886, + 2: 0.05260281, + # NOTE: no 3 since that's got no vectors + 4: 0.01832037} TOLERANCE = 0.001 @classmethod @@ -55,25 +60,50 @@ def setUpClass(cls): cls.vocab = Vocab() cls.vocab.add_words(cls.EXAMPLE_DATA_PATH) cls.vocab.add_word("test", cnt=1310, vec=[1.42, 1.44, 1.55]) + cls.vocab.add_word("vectorless", cnt=1234, vec=None) + cls.vocab.add_word("withvector", cnt=321, vec=[1.3, 1.2, 0.8]) cls.vocab.make_unigram_table(table_size=cls.UNIGRAM_TABLE_SIZE) def setUp(self): np.random.seed(self.RANDOM_SEED) @classmethod - def _get_freqs(cls) -> list[float]: + def _get_freqs(cls) -> dict[int, float]: c = Counter() for _ in range(cls.NUM_TIMES): got = cls.vocab.get_negative_samples(cls.NUM_SAMPLES) c += Counter(got) - total = sum(c[i] for i in c) - got_freqs = [c[i]/total for i in range(len(cls.EXPECTED_FREQUENCIES))] + total = sum(c.values()) + got_freqs = {index: val/total for index, val in c.items()} return got_freqs - def assert_accurate_enough(self, got_freqs: list[float]): + @classmethod + def 
_get_abs_max_diff(cls, dict1: dict[int, float], + dict2: dict[int, float]): + assert dict1.keys() == dict2.keys() + vals1, vals2 = [], [] + for index in dict1: + vals1.append(dict1[index]) + vals2.append(dict2[index]) + return np.max(np.abs(np.array(vals1) - np.array(vals2))) + + def assert_accurate_enough(self, got_freqs: dict[int, float]): + self.assertEqual(got_freqs.keys(), self.EXPECTED_FREQUENCIES.keys()) self.assertTrue( - np.max(np.abs(np.array(got_freqs) - self.EXPECTED_FREQUENCIES)) < self.TOLERANCE - ) + self._get_abs_max_diff(self.EXPECTED_FREQUENCIES, got_freqs) < self.TOLERANCE) + + def test_does_not_include_vectorless_indices(self, num_samples: int = 100): + inds = self.vocab.get_negative_samples(num_samples) + for index in inds: + with self.subTest(f"Index: {index}"): + # in the right list + self.assertIn(index, self.vocab.vec_index2word) + word = self.vocab.vec_index2word[index] + info = self.vocab.vocab[word] + # the info has vector + self.assertIn("vec", info) + # the vector is an array or a list + self.assertIsInstance(self.vocab.vec(word), (np.ndarray, list),) def test_negative_sampling(self): got_freqs = self._get_freqs() diff --git a/tests/utils/ner/test_deid.py b/tests/utils/ner/test_deid.py index 0eed7b6da..9eda6c973 100644 --- a/tests/utils/ner/test_deid.py +++ b/tests/utils/ner/test_deid.py @@ -90,6 +90,13 @@ def test_training(self): self.assertIsNotNone(examples) self.assertIsNotNone(dataset) + def test_add_new_concepts(self): + self.deid_model.add_new_concepts({'CONCEPT': "Concept"}, with_random_init=True) + self.assertTrue("CONCEPT" in self.deid_model.cat.cdb.cui2names) + self.assertEqual(self.deid_model.cat.cdb.cui2names["CONCEPT"], {"concept"}) + self.assertTrue("CONCEPT" in self.deid_model.cat._addl_ner[0].model.config.label2id) + self.assertTrue("CONCEPT" in self.deid_model.cat._addl_ner[0].tokenizer.label_map) + self.assertTrue("CONCEPT" in self.deid_model.cat._addl_ner[0].tokenizer.cui2name) input_text = ''' James Joyce diff 
--git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py new file mode 100644 index 000000000..dca0c6318 --- /dev/null +++ b/tests/utils/test_data_utils.py @@ -0,0 +1,90 @@ +import os +import json +from copy import deepcopy + +from medcat.utils import data_utils +from medcat.stats.mctexport import count_all_annotations, count_all_docs + +from unittest import TestCase + + +class FakeCDB: + + def __init__(self): + self.cui2tui = {} + + def get_name(self, cui: str) -> str: + return cui + + +class TestTrainSplitTestsBase(TestCase): + file_name = os.path.join(os.path.dirname(__file__), + "..", "resources", "medcat_trainer_export.json") + allowed_doc_ids = {3204, 3205} + test_size = 0.2 + expect_empty_train_set = False + expect_empty_test_set = False + seed = None + + @classmethod + def setUpClass(cls): + with open(cls.file_name) as f: + cls.data = json.load(f) + cls.undertest = cls.data + cls.cdb = FakeCDB() + + def setUp(self): + if self.seed is not None: + data_utils.set_all_seeds(self.seed) + (self.train_set, self.test_set, + self.num_test_anns, + self.num_total_anns) = data_utils.make_mc_train_test( + self.undertest, self.cdb, test_size=self.test_size) + + +class TestTrainSplitUnfilteredTests(TestTrainSplitTestsBase): + + def test_all_docs_accounted_for(self): + self.assertEqual(count_all_docs(self.undertest), + count_all_docs(self.train_set) + count_all_docs(self.test_set)) + + def test_all_anns_accounted_for(self): + self.assertEqual(count_all_annotations(self.undertest), + count_all_annotations(self.train_set) + count_all_annotations(self.test_set)) + + def test_total_anns_match(self): + total = count_all_annotations(self.undertest) + self.assertEqual(self.num_total_anns, total) + self.assertEqual(self.num_test_anns + count_all_annotations(self.train_set), + total) + + def test_nonempty_train(self): + if not self.expect_empty_train_set: + self.assertTrue(self.train_set) + self.assertTrue(self.num_total_anns - self.num_test_anns) + 
self.assertEqual(self.num_total_anns - self.num_test_anns, + count_all_annotations(self.train_set)) + + def test_nonempty_test(self): + if not self.expect_empty_test_set: + self.assertTrue(self.test_set) + self.assertTrue(self.num_test_anns) + self.assertEqual(self.num_test_anns, + count_all_annotations(self.test_set)) + + +class TestTrainSplitFilteredTestsBase(TestTrainSplitUnfilteredTests): + expect_empty_test_set = True + # would work with previous version: + # seed = 332378110 + # was guaranteed to fail with previous version: + seed = 73607120 + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.filtered = deepcopy(cls.data) + for proj in cls.filtered['projects']: + proj['documents'] = [doc for doc in proj['documents'] + if doc['id'] in cls.allowed_doc_ids] + cls.undertest = cls.filtered diff --git a/tests/utils/test_memory_optimiser.py b/tests/utils/test_memory_optimiser.py index 5f59f5274..472029d45 100644 --- a/tests/utils/test_memory_optimiser.py +++ b/tests/utils/test_memory_optimiser.py @@ -254,8 +254,8 @@ def test_optimisation_round_trip_cuis(self): with self.subTest(f'{name}'): self.assertIsInstance(before, dict) self.assertIsInstance(after, dict) - self.assertEquals(len(before), len(after)) - self.assertEquals(before, after) + self.assertEqual(len(before), len(after)) + self.assertEqual(before, after) def test_optimisation_round_trip_snames(self): snames_before = self.cdb.snames @@ -264,8 +264,8 @@ def test_optimisation_round_trip_snames(self): snames_after = self.cdb.snames self.assertIsInstance(snames_before, set) self.assertIsInstance(snames_after, set) - self.assertEquals(len(snames_before), len(snames_after)) - self.assertEquals(snames_before, snames_after) + self.assertEqual(len(snames_before), len(snames_after)) + self.assertEqual(snames_before, snames_after) def test_optimisation_round_trip_dirty(self): memory_optimiser.perform_optimisation(self.cdb) From 0273234909a108fcd447b27e0858fe0b2e6a1de5 Mon Sep 17 00:00:00 2001 From: 
mart-r Date: Fri, 25 Apr 2025 13:37:26 +0100 Subject: [PATCH 9/9] Use Ubuntu 24.04 for production workflow --- .github/workflows/production.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml index 44fc53ebf..0db12f53c 100644 --- a/.github/workflows/production.yml +++ b/.github/workflows/production.yml @@ -8,7 +8,7 @@ on: jobs: build-n-publish-to-pypi: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 concurrency: build-n-publish-to-pypi if: github.repository == 'CogStack/MedCAT'