diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py
index e7fd371cc..ec88053ac 100644
--- a/medcat/meta_cat.py
+++ b/medcat/meta_cat.py
@@ -83,23 +83,32 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
                 The module
         """
         config = self.config
 
-        if config.model['model_name'] == 'lstm':
-            from medcat.utils.meta_cat.models import LSTM
-            model: nn.Module = LSTM(embeddings, config)
-            logger.info("LSTM model used for classification")
-
-        elif config.model['model_name'] == 'bert':
-            from medcat.utils.meta_cat.models import BertForMetaAnnotation
-            model = BertForMetaAnnotation(config)
-
-            if not config.model.model_freeze_layers:
-                peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
-                                         target_modules=["query", "value"], lora_dropout=0.2)
+        if config.model['model_name'] in ['lstm','bert','modernbert']:
+            if config.model['model_name'] == 'lstm':
+                from medcat.utils.meta_cat.models import LSTM
+                model: nn.Module = LSTM(embeddings, config)
+                logger.info("LSTM model used for classification")
+
+            elif config.model['model_name'] == 'bert':
+                from medcat.utils.meta_cat.models import BertForMetaAnnotation
+                model = BertForMetaAnnotation(config)
+                logger.info("BERT model used for classification")
+
+                if not config.model.model_freeze_layers:
+                    peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
+                                             target_modules=["query", "value"], lora_dropout=0.2)
+
+            elif config.model['model_name'] == 'modernbert':
+                from medcat.utils.meta_cat.models import ModernBertForMetaAnnotation
+                model = ModernBertForMetaAnnotation(config)
+                logger.info("ModernBERT model used for classification")
+
+                if not config.model.model_freeze_layers:
+                    peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
+                                             target_modules = ["Wqkv", "Wo"], lora_dropout=0.2)
                 model = get_peft_model(model, peft_config)
-            # model.print_trainable_parameters()
-
-            logger.info("BERT model used for classification")
+                # model.print_trainable_parameters()
 
         else:
             raise ValueError("Unknown model name %s" % config.model['model_name'])
@@ -419,6 +428,10 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCAT":
             from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
             tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model['model_variant'])
 
+        elif config.general['tokenizer_name'] == 'modernbert-tokenizer':
+            from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperModernBERT
+            tokenizer = TokenizerWrapperModernBERT.load(save_dir_path, config.model['model_variant'])
+
         # Create meta_cat
         meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)
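
For context, a minimal sketch (not part of the diff) of the configuration that routes MetaCAT to the new branch. The field names are the ones read by the code above; the ConfigMetaCAT import path and the 'answerdotai/ModernBERT-base' variant string are assumptions.

# Hypothetical usage sketch; field names come from the diff, the variant string is assumed.
from medcat.config_meta_cat import ConfigMetaCAT

config = ConfigMetaCAT()
config.general.tokenizer_name = 'modernbert-tokenizer'          # picked up by MetaCAT.load()
config.model.model_name = 'modernbert'                          # routes get_model() to ModernBertForMetaAnnotation
config.model.model_variant = 'answerdotai/ModernBERT-base'      # assumed Hugging Face model id
config.model.model_freeze_layers = False                        # False -> LoRA adapters on ["Wqkv", "Wo"]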
diff --git a/medcat/tokenizers/meta_cat_tokenizers.py b/medcat/tokenizers/meta_cat_tokenizers.py
index 4c4daf200..f7c8efe30 100644
--- a/medcat/tokenizers/meta_cat_tokenizers.py
+++ b/medcat/tokenizers/meta_cat_tokenizers.py
@@ -4,6 +4,7 @@
 from typing import List, Dict, Optional, Union, overload
 from tokenizers import Tokenizer, ByteLevelBPETokenizer
 from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
+from transformers import PreTrainedTokenizerFast
 
 
 class TokenizerWrapperBase(ABC):
@@ -209,3 +210,75 @@ def token_to_id(self, token: str) -> Union[int, List[int]]:
     def get_pad_id(self) -> Optional[int]:
         self.hf_tokenizers = self.ensure_tokenizer()
         return self.hf_tokenizers.pad_token_id
+
+
+class TokenizerWrapperModernBERT(TokenizerWrapperBase):
+    """Wrapper around a huggingface ModernBERT tokenizer so that it works with the
+    MetaCAT models.
+
+    Args:
+        hf_tokenizers (`transformers.PreTrainedTokenizerFast`):
+            A huggingface Fast tokenizer.
+    """
+    name = 'modernbert-tokenizer'
+
+    def __init__(self, hf_tokenizers: Optional[PreTrainedTokenizerFast] = None) -> None:
+        super().__init__(hf_tokenizers)
+
+    @overload
+    def __call__(self, text: str) -> Dict: ...
+
+    @overload
+    def __call__(self, text: List[str]) -> List[Dict]: ...
+
+    def __call__(self, text: Union[str, List[str]]) -> Union[Dict, List[Dict]]:
+        self.hf_tokenizers = self.ensure_tokenizer()
+        if isinstance(text, str):
+            result = self.hf_tokenizers.encode_plus(text, return_offsets_mapping=True,
+                                                    add_special_tokens=False)
+
+            return {'offset_mapping': result['offset_mapping'],
+                    'input_ids': result['input_ids'],
+                    'tokens': self.hf_tokenizers.convert_ids_to_tokens(result['input_ids']),
+                    }
+        elif isinstance(text, list):
+            results = self.hf_tokenizers._batch_encode_plus(text, return_offsets_mapping=True,
+                                                            add_special_tokens=False)
+            output = []
+            for ind in range(len(results['input_ids'])):
+                output.append({'offset_mapping': results['offset_mapping'][ind],
+                               'input_ids': results['input_ids'][ind],
+                               'tokens': self.hf_tokenizers.convert_ids_to_tokens(results['input_ids'][ind]),
+                               })
+            return output
+        else:
+            raise Exception("Unsupported input type, supported: text/list, but got: {}".format(type(text)))
+
+    def save(self, dir_path: str) -> None:
+        self.hf_tokenizers = self.ensure_tokenizer()
+        path = os.path.join(dir_path, self.name)
+        self.hf_tokenizers.save_pretrained(path)
+
+    @classmethod
+    def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "TokenizerWrapperModernBERT":
+        tokenizer = cls()
+        path = os.path.join(dir_path, cls.name)
+        try:
+            tokenizer.hf_tokenizers = PreTrainedTokenizerFast.from_pretrained(path, **kwargs)
+        except Exception as e:
+            logging.warning("Could not load tokenizer from path due to error: {}. Loading from library for model variant: {}".format(e, model_variant))
+            tokenizer.hf_tokenizers = PreTrainedTokenizerFast.from_pretrained(model_variant)
+
+        return tokenizer
+
+    def get_size(self) -> int:
+        self.hf_tokenizers = self.ensure_tokenizer()
+        return len(self.hf_tokenizers.vocab)
+
+    def token_to_id(self, token: str) -> Union[int, List[int]]:
+        self.hf_tokenizers = self.ensure_tokenizer()
+        return self.hf_tokenizers.convert_tokens_to_ids(token)
+
+    def get_pad_id(self) -> Optional[int]:
+        self.hf_tokenizers = self.ensure_tokenizer()
+        return self.hf_tokenizers.pad_token_id
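
For context, a minimal sketch of how the new wrapper could be exercised on its own, assuming the same ModernBERT variant as above ships a fast tokenizer; the keys of the returned dict are exactly those built in __call__.

# Hypothetical usage sketch; the model id is an assumption, the dict keys come from the class above.
from transformers import PreTrainedTokenizerFast
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperModernBERT

hf_tok = PreTrainedTokenizerFast.from_pretrained('answerdotai/ModernBERT-base')  # assumed variant
wrapper = TokenizerWrapperModernBERT(hf_tok)

out = wrapper("Patient denies chest pain.")
print(out['tokens'])          # sub-word tokens
print(out['input_ids'])       # ids fed to the model
print(out['offset_mapping'])  # character spans used to align tokens back to the text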
diff --git a/medcat/utils/meta_cat/models.py b/medcat/utils/meta_cat/models.py
index 543e0ca6b..4ea81bbca 100644
--- a/medcat/utils/meta_cat/models.py
+++ b/medcat/utils/meta_cat/models.py
@@ -2,7 +2,7 @@
 from collections import OrderedDict
 from typing import Optional, Any, List, Iterable
 from torch import nn, Tensor
-from transformers import BertModel, AutoConfig
+from transformers import BertModel, AutoConfig, ModernBertModel, ModernBertConfig
 from medcat.meta_cat import ConfigMetaCAT
 import logging
 logger = logging.getLogger(__name__)
@@ -214,3 +214,136 @@ def forward(
         # output layer
         x = self.fc4(x)
         return x
+
+
+class ModernBertForMetaAnnotation(nn.Module):
+    _keys_to_ignore_on_load_unexpected: List[str] = [r"pooler"]  # type: ignore
+
+    def __init__(self, config):
+        super(ModernBertForMetaAnnotation, self).__init__()
+        _modernbertconfig = AutoConfig.from_pretrained(config.model.model_variant, num_hidden_layers=config.model['num_layers'])
+        if config.model['input_size'] != _modernbertconfig.hidden_size:
+            logger.warning("Input size for %s model should be %d, provided input size is %d. Input size changed to %d",
+                           config.model.model_variant, _modernbertconfig.hidden_size, config.model['input_size'], _modernbertconfig.hidden_size)
+        modernbert_model = ModernBertModel.from_pretrained(config.model.model_variant, config=_modernbertconfig)
+        self.config = config
+        self.config.use_return_dict = False
+        self.modernbert = modernbert_model
+        self.num_labels = config.model["nclasses"]
+        for param in self.modernbert.parameters():
+            param.requires_grad = not config.model.model_freeze_layers
+
+        hidden_size_2 = int(config.model.hidden_size / 2)
+        # dropout layer
+        self.dropout = nn.Dropout(config.model.dropout)
+        # relu activation function
+        self.relu = nn.ReLU()
+        # dense layer 1
+        self.fc1 = nn.Linear(_modernbertconfig.hidden_size*2, config.model.hidden_size)
+        # dense layer 2
+        self.fc2 = nn.Linear(config.model.hidden_size, hidden_size_2)
+        # dense layer 3
+        self.fc3 = nn.Linear(hidden_size_2, hidden_size_2)
+        # dense layer 4 (output layer), defined below depending on the architecture config
+        model_arch_config = config.model.model_architecture_config
+
+        if model_arch_config['fc3'] is True and model_arch_config['fc2'] is False:
+            logger.warning("FC3 can only be used if FC2 is also enabled. Enabling FC2...")
Enabling FC2...") + config.model.model_architecture_config['fc2'] = True + + if model_arch_config is not None: + if model_arch_config['fc2'] is True: + self.fc4 = nn.Linear(hidden_size_2, self.num_labels) + else: + self.fc4 = nn.Linear(config.model.hidden_size, self.num_labels) + else: + self.fc4 = nn.Linear(hidden_size_2, self.num_labels) + # softmax activation function + self.softmax = nn.LogSoftmax(dim=1) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + center_positions: Iterable[Any] = [], + ignore_cpos: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ): + """labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + + Args: + input_ids (Optional[torch.LongTensor]): The input IDs. Defaults to None. + attention_mask (Optional[torch.FloatTensor]): The attention mask. Defaults to None. + token_type_ids (Optional[torch.LongTensor]): Type IDs of the tokens. Defaults to None. + position_ids (Optional[torch.LongTensor]): Position IDs. Defaults to None. + head_mask (Optional[torch.FloatTensor]): Head mask. Defaults to None. + inputs_embeds (Optional[torch.FloatTensor]): Input embeddings. Defaults to None. + labels (Optional[torch.LongTensor]): Labels. Defaults to None. + center_positions (Optional[Any]): Cennter positions. Defaults to None. + output_attentions (Optional[bool]): Output attentions. Defaults to None. + ignore_cpos: If center positions are to be ignored. + output_hidden_states (Optional[bool]): Output hidden states. Defaults to None. + return_dict (Optional[bool]): Whether to return a dict. Defaults to None. + + Returns: + TokenClassifierOutput: The token classifier output. + """ + # return_dict = return_dict if return_dict is not None else self.config.use_return_dict # type: ignore + + outputs = self.modernbert( # type: ignore + input_ids, + attention_mask=attention_mask, output_hidden_states=True + ) + + x_all = [] + for i, indices in enumerate(center_positions): + this_hidden: torch.Tensor = outputs.last_hidden_state[i, indices, :] + to_append, _ = torch.max(this_hidden, dim=0) + x_all.append(to_append) + + x = torch.stack(x_all) + + sequence_output = outputs.last_hidden_state + pooled_output, _ = torch.max(sequence_output, dim=1) + + x = torch.cat((x, pooled_output), dim=1) + + # fc1 + x = self.dropout(x) + x = self.fc1(x) + x = self.relu(x) + + if self.config.model.model_architecture_config is not None: + if self.config.model.model_architecture_config['fc2'] is True: + # fc2 + x = self.fc2(x) + x = self.relu(x) + x = self.dropout(x) + + if self.config.model.model_architecture_config['fc3'] is True: + # fc3 + x = self.fc3(x) + x = self.relu(x) + x = self.dropout(x) + else: + # fc2 + x = self.fc2(x) + x = self.relu(x) + x = self.dropout(x) + + # fc3 + x = self.fc3(x) + x = self.relu(x) + x = self.dropout(x) + + # output layer + x = self.fc4(x) + return x \ No newline at end of file