Skip to content

Commit

Permalink
Optimize data processor logic (autogluon#2125)
Browse files Browse the repository at this point in the history
* optimize data processor logic

* fix

* update

* fix lint

* update

* fix

* remove the data_types class attribute
  • Loading branch information
zhiqiangdon authored Sep 14, 2022
1 parent d55f491 commit 210212e
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from nptyping import NDArray
from torch import nn

from ..constants import CATEGORICAL, COLUMN
from .collator import Stack, Tuple
Expand All @@ -16,7 +17,7 @@ class CategoricalProcessor:

def __init__(
self,
prefix: str,
model: nn.Module,
requires_column_info: bool = False,
):
"""
Expand All @@ -27,7 +28,7 @@ def __init__(
requires_column_info
Whether to require feature column information in dataloader.
"""
self.prefix = prefix
self.prefix = model.prefix
self.requires_column_info = requires_column_info

@property
Expand Down
44 changes: 15 additions & 29 deletions multimodal/src/autogluon/multimodal/data/process_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
import PIL
import torch
from timm import create_model
from timm.data.constants import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
Expand All @@ -15,7 +14,6 @@
)
from torch import nn
from torchvision import transforms
from transformers import AutoConfig

from .randaug import RandAugment

Expand Down Expand Up @@ -71,16 +69,14 @@ class ImageProcessor:

def __init__(
self,
prefix: str,
model: nn.Module,
train_transform_types: List[str],
val_transform_types: List[str],
checkpoint_name: Optional[str] = None,
norm_type: Optional[str] = None,
size: Optional[int] = None,
max_img_num_per_col: Optional[int] = 1,
missing_value_strategy: Optional[str] = "skip",
requires_column_info: bool = False,
model: Optional[nn.Module] = None,
):
"""
Parameters
Expand Down Expand Up @@ -122,24 +118,20 @@ def __init__(
model
The model using this data processor.
"""
self.checkpoint_name = checkpoint_name
self.prefix = prefix
self.train_transform_types = train_transform_types
self.val_transform_types = val_transform_types
logger.debug(f"image training transform type: {train_transform_types}")
logger.debug(f"image validation transform type: {val_transform_types}")

self.prefix = model.prefix
self.missing_value_strategy = missing_value_strategy
self.requires_column_info = requires_column_info
self.size = None
self.mean = None
self.std = None

if checkpoint_name is not None:
if self.prefix == MMDET_IMAGE or self.prefix == MMOCR_TEXT_DET:
self.size, self.mean, self.std = self.extract_default(checkpoint_name, cfg=model.model.cfg)
else:
self.size, self.mean, self.std = self.extract_default(checkpoint_name)
if model is not None:
self.size, self.mean, self.std = self.extract_default(model.config)
if self.size is None:
if size is not None:
self.size = size
Expand Down Expand Up @@ -246,7 +238,7 @@ def mean_std(norm_type: str):
else:
raise ValueError(f"unknown image normalization: {norm_type}")

def extract_default(self, checkpoint_name, cfg=None):
def extract_default(self, config=None):
"""
Extract some default hyper-parameters, e.g., image size, mean, and std,
from a pre-trained (timm or huggingface) checkpoint.
Expand All @@ -266,26 +258,20 @@ def extract_default(self, checkpoint_name, cfg=None):
Image normalizaiton std.
"""
if self.prefix.lower().startswith(MMDET_IMAGE):
image_size = cfg.test_pipeline[1]["img_scale"][0]
mean = cfg.test_pipeline[1]["transforms"][2]["mean"]
std = cfg.test_pipeline[1]["transforms"][2]["std"]
image_size = config.test_pipeline[1]["img_scale"][0]
mean = config.test_pipeline[1]["transforms"][2]["mean"]
std = config.test_pipeline[1]["transforms"][2]["std"]
elif self.prefix.lower().startswith(MMOCR_TEXT_DET):
image_size = cfg.data.test.pipeline[1]["img_scale"][0]
mean = cfg.data.test.pipeline[1]["transforms"][1]["mean"]
std = cfg.data.test.pipeline[1]["transforms"][1]["std"]
image_size = config.data.test.pipeline[1]["img_scale"][0]
mean = config.data.test.pipeline[1]["transforms"][1]["mean"]
std = config.data.test.pipeline[1]["transforms"][1]["std"]
elif self.prefix.lower().startswith(TIMM_IMAGE):
model = create_model(
checkpoint_name,
pretrained=True,
num_classes=0,
)
image_size = model.default_cfg["input_size"][-1]
mean = model.default_cfg["mean"]
std = model.default_cfg["std"]
image_size = config["input_size"][-1]
mean = config["mean"]
std = config["std"]
elif self.prefix.lower().startswith(CLIP):
config = AutoConfig.from_pretrained(checkpoint_name).to_diff_dict()
extracted = extract_value_from_config(
config=config,
config=config.to_diff_dict(),
keys=("image_size",),
)
if len(extracted) == 0:
Expand Down
5 changes: 3 additions & 2 deletions multimodal/src/autogluon/multimodal/data/process_label.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, Union

from nptyping import NDArray
from torch import nn

from ..constants import LABEL
from .collator import Stack
Expand All @@ -15,15 +16,15 @@ class LabelProcessor:

def __init__(
self,
prefix: str,
model: nn.Module,
):
"""
Parameters
----------
prefix
The prefix connecting a processor to its corresponding model.
"""
self.prefix = prefix
self.prefix = model.prefix

@property
def label_key(self):
Expand Down
5 changes: 3 additions & 2 deletions multimodal/src/autogluon/multimodal/data/process_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from nptyping import NDArray
from torch import nn

from ..constants import COLUMN, NUMERICAL
from .collator import Stack
Expand All @@ -16,7 +17,7 @@ class NumericalProcessor:

def __init__(
self,
prefix: str,
model: nn.Module,
merge: Optional[str] = "concat",
requires_column_info: bool = False,
):
Expand All @@ -33,7 +34,7 @@ def __init__(
requires_column_info
Whether to require feature column information in dataloader.
"""
self.prefix = prefix
self.prefix = model.prefix
self.merge = merge
self.requires_column_info = requires_column_info

Expand Down
17 changes: 7 additions & 10 deletions multimodal/src/autogluon/multimodal/data/process_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from nptyping import NDArray
from omegaconf import DictConfig
from torch import nn
from transformers import AutoConfig, AutoTokenizer, BertTokenizer, CLIPTokenizer, ElectraTokenizer

from ..constants import AUTOMM, CHOICES_IDS, COLUMN, TEXT, TEXT_SEGMENT_IDS, TEXT_TOKEN_IDS, TEXT_VALID_LENGTH
Expand Down Expand Up @@ -78,8 +78,7 @@ class TextProcessor:

def __init__(
self,
prefix: str,
checkpoint_name: str,
model: nn.Module,
tokenizer_name: Optional[str] = "hf_auto",
max_len: Optional[int] = None,
insert_sep: Optional[bool] = True,
Expand Down Expand Up @@ -119,13 +118,12 @@ def __init__(
train_augment_types
All possible augmentation operations
"""
self.prefix = prefix
self.prefix = model.prefix
self.tokenizer_name = tokenizer_name
self.checkpoint_name = checkpoint_name
self.requires_column_info = requires_column_info
self.tokenizer = self.get_pretrained_tokenizer(
tokenizer_name=tokenizer_name,
checkpoint_name=checkpoint_name,
checkpoint_name=model.checkpoint_name,
)
if hasattr(self.tokenizer, "deprecation_warnings"):
# Disable the warning "Token indices sequence length is longer than the specified maximum sequence..."
Expand All @@ -139,16 +137,15 @@ def __init__(
if max_len < self.tokenizer.model_max_length:
warnings.warn(
f"provided max length: {max_len} "
f"is smaller than {checkpoint_name}'s default: {self.tokenizer.model_max_length}"
f"is smaller than {model.checkpoint_name}'s default: {self.tokenizer.model_max_length}"
)
self.max_len = min(max_len, self.tokenizer.model_max_length)
logger.debug(f"text max length: {self.max_len}")

self.insert_sep = insert_sep
self.eos_only = self.cls_token_id == self.sep_token_id == self.eos_token_id

config = AutoConfig.from_pretrained(checkpoint_name).to_diff_dict()
extracted = extract_value_from_config(config=config, keys=("type_vocab_size",))
extracted = extract_value_from_config(config=model.config.to_diff_dict(), keys=("type_vocab_size",))
if len(extracted) == 0:
default_segment_num = 1
elif len(extracted) == 1:
Expand All @@ -162,7 +159,7 @@ def __init__(
if text_segment_num < default_segment_num:
warnings.warn(
f"provided text_segment_num: {text_segment_num} "
f"is smaller than {checkpoint_name}'s default: {default_segment_num}"
f"is smaller than {model.checkpoint_name}'s default: {default_segment_num}"
)
self.text_segment_num = min(text_segment_num, default_segment_num)
assert self.text_segment_num >= 1
Expand Down
1 change: 1 addition & 0 deletions multimodal/src/autogluon/multimodal/models/timm_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
self.checkpoint_name = checkpoint_name
self.pretrained = pretrained
self.model = create_model(checkpoint_name, pretrained=pretrained, num_classes=num_classes)
self.config = self.model.default_cfg
self.num_classes = self.model.num_classes
self.out_features = self.model.num_features
self.head = get_model_head(model=self.model)
Expand Down
33 changes: 19 additions & 14 deletions multimodal/src/autogluon/multimodal/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
average_checkpoints,
compute_num_gpus,
compute_score,
create_fusion_data_processors,
create_model,
data_to_df,
extract_from_output,
Expand All @@ -96,7 +97,6 @@
infer_dtypes_by_model_names,
infer_metrics,
infer_scarcity_mode_by_data_size,
init_data_processors,
init_df_preprocessor,
init_pretrained,
load_text_tokenizers,
Expand Down Expand Up @@ -860,16 +860,6 @@ def _fit(

config = select_model(config=config, df_preprocessor=df_preprocessor)

if self._data_processors is None:
data_processors = init_data_processors(
config=config,
)
else: # continuing training
data_processors = self._data_processors

data_processors_count = {k: len(v) for k, v in data_processors.items()}
logger.debug(f"data_processors_count: {data_processors_count}")

if self._model is None:
model = create_model(
config=config,
Expand All @@ -880,6 +870,17 @@ def _fit(
else: # continuing training
model = self._model

if self._data_processors is None:
data_processors = create_fusion_data_processors(
config=config,
model=model,
)
else: # continuing training
data_processors = self._data_processors

data_processors_count = {k: len(v) for k, v in data_processors.items()}
logger.debug(f"data_processors_count: {data_processors_count}")

pos_label = try_to_infer_pos_label(
data_config=config.data,
label_encoder=df_preprocessor.label_generator,
Expand Down Expand Up @@ -1908,9 +1909,7 @@ def _load_metadata(
# Only keep the modalities with non-empty processors.
data_processors = {k: v for k, v in data_processors.items() if len(v) > 0}
except: # backward compatibility. reconstruct the data processor in case something went wrong.
data_processors = init_data_processors(
config=config,
)
data_processors = None

predictor._label_column = assets["label_column"]
predictor._problem_type = assets["problem_type"]
Expand Down Expand Up @@ -1969,6 +1968,12 @@ def load(
pretrained=False, # set "pretrain=False" to prevent downloading online models
)

if predictor._data_processors is None:
predictor._data_processors = create_fusion_data_processors(
config=predictor._config,
model=model,
)

resume_ckpt_path = os.path.join(path, LAST_CHECKPOINT)
final_ckpt_path = os.path.join(path, MODEL_CHECKPOINT)
if resume: # resume training which crashed before
Expand Down
Loading

0 comments on commit 210212e

Please sign in to comment.