This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Use weights_only for load #3073

Status: Open. Wants to merge 4 commits into base: master.
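Background for reviewers: `torch.load` defaults to full pickle deserialization, which can execute arbitrary code embedded in a checkpoint file. Passing `weights_only=True` (available since PyTorch 1.13) restricts unpickling to tensors, primitive containers, and an allowlist of safe globals. A minimal sketch of the pattern this PR applies across the examples, using a placeholder checkpoint path:

```python
import torch

# Default load: full pickle; a malicious checkpoint can run arbitrary code.
# state = torch.load("checkpoint.pt", map_location="cpu")

# Restricted load: only tensors, primitive containers, and allowlisted
# globals are deserialized; anything else raises an UnpicklingError.
state = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
```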
121 changes: 75 additions & 46 deletions examples/FasterTransformer_HuggingFace_Bert/Bert_FT_trace.py
@@ -4,23 +4,11 @@
 import logging
 import os
 import random
-import timeit
 
 import numpy as np
 import torch
-from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
-from tqdm import tqdm, trange
 
-from transformers import (
-    BertConfig,
-    BertTokenizer,
-)
-from utils.modeling_bert import BertForSequenceClassification, BertForQuestionAnswering
-from transformers import glue_compute_metrics as compute_metrics
-from transformers import glue_convert_examples_to_features as convert_examples_to_features
-from transformers import glue_output_modes as output_modes
-from transformers import glue_processors as processors
+from transformers import BertTokenizer
+from utils.modeling_bert import BertForQuestionAnswering, BertForSequenceClassification
 
 logger = logging.getLogger(__name__)
 
@@ -30,6 +18,7 @@ def set_seed(args):
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
 
+
 def main():
     parser = argparse.ArgumentParser()
 
@@ -42,7 +31,10 @@ def main():
     )
 
     parser.add_argument(
-        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
+        "--config_name",
+        default="",
+        type=str,
+        help="Pretrained config name or path if not the same as model_name",
     )
     parser.add_argument(
         "--tokenizer_name",
@@ -63,33 +55,52 @@ def main():
         help="The maximum total input sequence length after tokenization. Sequences longer "
         "than this will be truncated, sequences shorter will be padded.",
     )
-    parser.add_argument("--mode", default= "sequence_classification", help=" Set the model for sequence classification or question answering")
     parser.add_argument(
-        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
+        "--mode",
+        default="sequence_classification",
+        help=" Set the model for sequence classification or question answering",
+    )
+    parser.add_argument(
+        "--do_lower_case",
+        action="store_true",
+        help="Set this flag if you are using an uncased model.",
     )
 
     parser.add_argument(
-        "--batch_size", default=8, type=int, help="Batch size for tracing.",
+        "--batch_size",
+        default=8,
+        type=int,
+        help="Batch size for tracing.",
     )
 
-    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
+    parser.add_argument(
+        "--seed", type=int, default=42, help="random seed for initialization"
+    )
     # parser.add_arument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
     parser.add_argument("--model_type", type=str, help="ori, ths, thsext")
    parser.add_argument("--data_type", type=str, help="fp32, fp16")
-    parser.add_argument('--ths_path', type=str, default='./lib/libpyt_fastertransformer.so',
-                        help='path of the pyt_fastertransformer dynamic lib file')
-    parser.add_argument('--remove_padding', action='store_false',
-                        help='Remove the padding of sentences of encoder.')
-    parser.add_argument('--allow_gemm_test', action='store_false',
-                        help='per-channel quantization.')
+    parser.add_argument(
+        "--ths_path",
+        type=str,
+        default="./lib/libpyt_fastertransformer.so",
+        help="path of the pyt_fastertransformer dynamic lib file",
+    )
+    parser.add_argument(
+        "--remove_padding",
+        action="store_false",
+        help="Remove the padding of sentences of encoder.",
+    )
+    parser.add_argument(
+        "--allow_gemm_test", action="store_false", help="per-channel quantization."
+    )
 
     args = parser.parse_args()
 
     if torch.cuda.is_available():
         device = torch.device("cuda")
     args.device = device
 
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -110,32 +121,45 @@
 
     checkpoints = [args.model_name_or_path]
     for checkpoint in checkpoints:
-        use_ths = args.model_type.startswith('ths')
+        use_ths = args.model_type.startswith("ths")
         if args.mode == "sequence_classification":
-            model = BertForSequenceClassification.from_pretrained(checkpoint, torchscript=use_ths)
+            model = BertForSequenceClassification.from_pretrained(
+                checkpoint, torchscript=use_ths
+            )
         elif args.mode == "question_answering":
-            model = BertForQuestionAnswering.from_pretrained(checkpoint, torchscript=use_ths)
+            model = BertForQuestionAnswering.from_pretrained(
+                checkpoint, torchscript=use_ths
+            )
         model.to(args.device)
 
-        if args.data_type == 'fp16':
+        if args.data_type == "fp16":
             logger.info("Use fp16")
             model.half()
-        if args.model_type == 'thsext':
+        if args.model_type == "thsext":
             logger.info("Use custom BERT encoder for TorchScript")
-            from utils.encoder import EncoderWeights, CustomEncoder
+            from utils.encoder import CustomEncoder, EncoderWeights
 
             weights = EncoderWeights(
-                model.config.num_hidden_layers, model.config.hidden_size,
-                torch.load(os.path.join(checkpoint, 'pytorch_model.bin'), map_location='cpu'))
+                model.config.num_hidden_layers,
+                model.config.hidden_size,
+                torch.load(
+                    os.path.join(checkpoint, "pytorch_model.bin"),
+                    map_location="cpu",
+                    weights_only=True,
+                ),
+            )
             weights.to_cuda()
-            if args.data_type == 'fp16':
+            if args.data_type == "fp16":
                 weights.to_half()
-            enc = CustomEncoder(model.config.num_hidden_layers,
-                                model.config.num_attention_heads,
-                                model.config.hidden_size//model.config.num_attention_heads,
-                                weights,
-                                remove_padding=args.remove_padding,
-                                allow_gemm_test=(args.allow_gemm_test),
-                                path=os.path.abspath(args.ths_path))
+            enc = CustomEncoder(
+                model.config.num_hidden_layers,
+                model.config.num_attention_heads,
+                model.config.hidden_size // model.config.num_attention_heads,
+                weights,
+                remove_padding=args.remove_padding,
+                allow_gemm_test=(args.allow_gemm_test),
+                path=os.path.abspath(args.ths_path),
+            )
             enc_ = torch.jit.script(enc)
             model.replace_encoder(enc_)
         if use_ths:
@@ -145,14 +169,19 @@
         fake_input_id = fake_input_id.to(args.device)
         fake_mask = torch.ones(args.batch_size, args.max_seq_length).to(args.device)
         fake_type_id = fake_input_id.clone().detach()
-        if args.data_type == 'fp16':
+        if args.data_type == "fp16":
             fake_mask = fake_mask.half()
         model.eval()
         with torch.no_grad():
-            print("********** input id and mask sizes ******",fake_input_id.size(),fake_mask.size() )
+            print(
+                "********** input id and mask sizes ******",
+                fake_input_id.size(),
+                fake_mask.size(),
+            )
             model_ = torch.jit.trace(model, (fake_input_id, fake_mask))
             model = model_
-            torch.jit.save(model,"traced_model.pt")
+            torch.jit.save(model, "traced_model.pt")
+
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
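Note on this file: a `pytorch_model.bin` written by Hugging Face's `save_pretrained` is a plain state dict whose values are tensors, so the restricted load should succeed unchanged. A quick sanity check one could run locally (hypothetical path, not part of the diff):

```python
import torch

# Load the checkpoint with the restricted unpickler and confirm it really is
# a flat mapping of parameter names to tensors.
sd = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
assert all(isinstance(v, torch.Tensor) for v in sd.values())
print(f"loaded {len(sd)} tensors")
```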
93 changes: 43 additions & 50 deletions examples/MMF-activity-recognition/handler.py
@@ -1,40 +1,29 @@
-import json
+import ast
+import io
 import logging
 import os
 import pickle
 import sys
-import ast
 
+import pandas as pd
 import torch
-from ts.torch_handler.base_handler import BaseHandler
-import io
 import torchaudio
 import torchvision
 from omegaconf import OmegaConf
-import pandas as pd
-import csv
-
-from torchvision import transforms
-from torchvision.datasets.vision import VisionDataset
-from torchvision.io import (
-    read_video_timestamps,
-    read_video
-)
+from ts.torch_handler.base_handler import BaseHandler
 
 logger = logging.getLogger(__name__)
 
 from mmf.common.sample import Sample, SampleList
-from mmf.utils.env import set_seed, setup_imports
-from mmf.utils.logger import setup_logger, setup_very_basic_config
-from mmf.datasets.base_dataset import BaseDataset
-from mmf.utils.build import build_encoder, build_model, build_processors
-from mmf.datasets.mmf_dataset_builder import MMFDatasetBuilder
-from torch.utils.data import IterableDataset
-from mmf.utils.configuration import load_yaml
 from mmf.models.mmf_transformer import MMFTransformer
+from mmf.utils.build import build_processors
+from mmf.utils.env import setup_imports
+from mmf.utils.logger import setup_very_basic_config
+
 
 class MMFHandler(BaseHandler):
     """
     Transformers handler class for MMFTransformerWithVideoAudio model.
     """
+
     def __init__(self):
         super(MMFHandler, self).__init__()
         self.initialized = False
@@ -43,7 +32,7 @@ def initialize(self, ctx):
         self.manifest = ctx.manifest
         properties = ctx.system_properties
         model_dir = properties.get("model_dir")
-        serialized_file = self.manifest['model']['serializedFile']
+        serialized_file = self.manifest["model"]["serializedFile"]
         model_pt_path = os.path.join(model_dir, serialized_file)
         self.map_location = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = torch.device(
@@ -54,10 +43,10 @@ def initialize(self, ctx):
 
         # reading the csv file which include all the labels in the dataset to make the class/index mapping
         # and matching the output of the model with num labels from dataset
-        df = pd.read_csv('./charades_action_lables.csv')
+        df = pd.read_csv("./charades_action_lables.csv")
         label_set = set()
-        df['action_labels'] = df['action_labels'].str.replace('"','')
-        labels_initial = df['action_labels'].tolist()
+        df["action_labels"] = df["action_labels"].str.replace('"', "")
+        labels_initial = df["action_labels"].tolist()
         labels = []
         for sublist in labels_initial:
             new_sublist = ast.literal_eval(sublist)
@@ -69,68 +58,72 @@ def initialize(self, ctx):
         self.classes = classes
         self.labels = labels
         self.idx_to_class = classes
-        config = OmegaConf.load('config.yaml')
+        config = OmegaConf.load("config.yaml")
         print("*********** config keyssss **********", config.keys())
         setup_very_basic_config()
         setup_imports()
         self.model = MMFTransformer(config.model_config.mmf_transformer)
         self.model.build()
         self.model.init_losses()
-        self.processor = build_processors(
-            config.dataset_config["charades"].processors
+        self.processor = build_processors(config.dataset_config["charades"].processors)
+        state_dict = torch.load(
+            serialized_file, map_location=self.device, weights_only=True
         )
-        state_dict = torch.load(serialized_file, map_location=self.device)
         self.model.load_state_dict(state_dict)
         self.model.to(self.device)
         self.model.eval()
         self.initialized = True
-        print("********* files in temp direcotry that .mar file got extracted *********", os.listdir(model_dir))
+        print(
+            "********* files in temp direcotry that .mar file got extracted *********",
+            os.listdir(model_dir),
+        )
 
     def preprocess(self, requests):
-        """ Preprocessing, based on processor defined for MMF model.
-        """
-
-        def create_sample(video_transfomred,audio_transfomred,text_tensor, video_label):
+        """Preprocessing, based on processor defined for MMF model."""
+
+        def create_sample(
+            video_transfomred, audio_transfomred, text_tensor, video_label
+        ):
             label = [self.class_to_idx[l] for l in video_label]
 
             one_hot_label = torch.zeros(len(self.class_to_idx))
             one_hot_label[label] = 1
 
-            current_sample= Sample()
+            current_sample = Sample()
             current_sample.video = video_transfomred
             current_sample.audio = audio_transfomred
             current_sample.update(text_tensor)
             current_sample.targets = one_hot_label
-            current_sample.dataset_type = 'test'
-            current_sample.dataset_name = 'charades'
+            current_sample.dataset_type = "test"
+            current_sample.dataset_name = "charades"
             return SampleList([current_sample]).to(self.device)
 
         for idx, data in enumerate(requests):
-            raw_script = data.get('script')
-            script = raw_script.decode('utf-8')
-            raw_label = data.get('labels')
-            video_label = raw_label.decode('utf-8')
+            raw_script = data.get("script")
+            script = raw_script.decode("utf-8")
+            raw_label = data.get("labels")
+            video_label = raw_label.decode("utf-8")
             video_label = [video_label]
-            video = io.BytesIO(data['data'])
-            video_tensor, audio_tensor,info = torchvision.io.read_video(video)
-
+            video = io.BytesIO(data["data"])
+            video_tensor, audio_tensor, info = torchvision.io.read_video(video)
             text_tensor = self.processor["text_processor"]({"text": script})
             video_transformed = self.processor["video_test_processor"](video_tensor)
             audio_transformed = self.processor["audio_processor"](audio_tensor)
-            samples = create_sample(video_transformed,audio_transformed,text_tensor,video_label)
+            samples = create_sample(
+                video_transformed, audio_transformed, text_tensor, video_label
+            )
 
         return samples
 
     def inference(self, samples):
-        """ Predict the class (or classes) of the received text using the serialized transformers checkpoint.
-        """
+        """Predict the class (or classes) of the received text using the serialized transformers checkpoint."""
         if torch.cuda.is_available():
             with torch.cuda.device(samples.get_device()):
                 output = self.model(samples)
         else:
             output = self.model(samples)
 
         sigmoid_scores = torch.sigmoid(output["scores"])
         binary_scores = torch.round(sigmoid_scores)
         score = binary_scores[0]
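One caveat worth noting for this handler (a hedged sketch, not part of the diff): if a worker must also serve older checkpoints that pickle arbitrary objects, the restricted load can be attempted first, with a full load kept as an explicit fallback reserved for trusted files. `model.pt` below is a placeholder path:

```python
import pickle

import torch

try:
    # Preferred: restricted unpickler, safe against malicious payloads.
    state_dict = torch.load("model.pt", map_location="cpu", weights_only=True)
except pickle.UnpicklingError:
    # Fallback for legacy checkpoints; only acceptable when the file comes
    # from a source you fully trust.
    state_dict = torch.load("model.pt", map_location="cpu", weights_only=False)
```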
2 changes: 1 addition & 1 deletion examples/cpp/aot_inductor/llama2/compile.py
@@ -7,7 +7,7 @@
 
 def load_checkpoint(checkpoint):
     # load the provided model checkpoint
-    checkpoint_dict = torch.load(checkpoint, map_location="cpu")
+    checkpoint_dict = torch.load(checkpoint, map_location="cpu", weights_only=True)
     gptconf = ModelArgs(**checkpoint_dict["model_args"])
     model = Transformer(gptconf)
     state_dict = checkpoint_dict["model"]
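This checkpoint stores `model_args` as a plain dict, which the restricted unpickler accepts. If a checkpoint instead pickled a custom class instance, `weights_only=True` would reject it; on PyTorch 2.4+ the class can be allowlisted explicitly. A hedged sketch, assuming the `ModelArgs` dataclass lives in `model.py` and using a placeholder checkpoint name:

```python
import torch
from model import ModelArgs  # assumed import path for the dataclass

# Allowlist the class so the restricted unpickler may reconstruct it.
torch.serialization.add_safe_globals([ModelArgs])
checkpoint_dict = torch.load("stories15M.pt", map_location="cpu", weights_only=True)
```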