Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
KumoLiu committed Dec 27, 2023
1 parent 90d2acb commit 859575a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 39 deletions.
49 changes: 13 additions & 36 deletions monai/networks/nets/transchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,17 @@
from __future__ import annotations

import math
import os
import shutil
import tarfile
import tempfile
from collections.abc import Sequence

import torch
from torch import nn

from monai.config.type_definitions import PathLike
from monai.utils import optional_import

transformers = optional_import("transformers")
load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert")[0]
cached_path = optional_import("transformers.file_utils", name="cached_path")[0]
cached_file = optional_import("transformers.utils", name="cached_file")[0]
BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0]
BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0]

Expand Down Expand Up @@ -63,44 +60,16 @@ def from_pretrained(
state_dict=None,
cache_dir=None,
from_tf=False,
path_or_repo_id="bert-base-uncased",
filename="pytorch_model.bin",
*inputs,
**kwargs,
):
archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
tempdir = None
if os.path.isdir(resolved_archive_file) or from_tf:
serialization_dir = resolved_archive_file
else:
tempdir = tempfile.mkdtemp()
with tarfile.open(resolved_archive_file, "r:gz") as archive:

def is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)

prefix = os.path.commonprefix([abs_directory, abs_target])

return prefix == abs_directory

def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise Exception("Attempted Path Traversal in Tar File")

tar.extractall(path, members, numeric_owner=numeric_owner)

safe_extract(archive, tempdir)
serialization_dir = tempdir
weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)
model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
if state_dict is None and not from_tf:
weights_path = os.path.join(serialization_dir, "pytorch_model.bin")
state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)
if tempdir:
shutil.rmtree(tempdir)
if from_tf:
weights_path = os.path.join(serialization_dir, "model.ckpt")
return load_tf_weights_in_bert(model, weights_path)
old_keys = []
new_keys = []
Expand Down Expand Up @@ -304,6 +273,8 @@ def __init__(
chunk_size_feed_forward: int = 0,
is_decoder: bool = False,
add_cross_attention: bool = False,
path_or_repo_id: str | PathLike = "bert-base-uncased",
filename: str = "pytorch_model.bin",
) -> None:
"""
Args:
Expand All @@ -315,6 +286,10 @@ def __init__(
num_vision_layers: number of vision transformer layers.
num_mixed_layers: number of mixed transformer layers.
drop_out: fraction of the input units to drop.
path_or_repo_id: This can be either:
- a string, the *model id* of a model repo on huggingface.co.
- a path to a *directory* potentially containing the file.
filename: The name of the file to locate in `path_or_repo_id`.
The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.
Expand Down Expand Up @@ -369,6 +344,8 @@ def __init__(
num_vision_layers=num_vision_layers,
num_mixed_layers=num_mixed_layers,
bert_config=bert_config,
path_or_repo_id=path_or_repo_id,
filename=filename,
)

self.patch_size = patch_size
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin"
pandas
requests
einops
transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157
transformers>4.36.0
mlflow>=1.28.0
clearml>=1.10.0rc0
matplotlib!=3.5.0
Expand Down
3 changes: 1 addition & 2 deletions tests/test_transchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from monai.networks import eval_mode
from monai.networks.nets.transchex import Transchex
from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick
from tests.utils import skip_if_quick

TEST_CASE_TRANSCHEX = []
for drop_out in [0.4]:
Expand Down Expand Up @@ -46,7 +46,6 @@


@skip_if_quick
@SkipIfAtLeastPyTorchVersion((1, 10))
class TestTranschex(unittest.TestCase):
@parameterized.expand(TEST_CASE_TRANSCHEX)
def test_shape(self, input_param, expected_shape):
Expand Down

0 comments on commit 859575a

Please sign in to comment.