diff --git a/generate_sharktank.py b/generate_sharktank.py index 9a331ea5d3..1e015c82f0 100644 --- a/generate_sharktank.py +++ b/generate_sharktank.py @@ -2,10 +2,11 @@ """SHARK Tank""" # python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url] # will generate local shark tank folder like this: -# /SHARK -# /gen_shark_tank -# /albert_lite_base -# /...model_name... +# HOME +# /.local +# /shark_tank +# /albert_lite_base +# /...model_name... # import os @@ -16,6 +17,7 @@ import subprocess as sp import hashlib import numpy as np +from pathlib import Path visible_default = tf.config.list_physical_devices("GPU") try: @@ -28,7 +30,8 @@ pass # All generated models and metadata will be saved under this directory. -WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank") +home = str(Path.home()) +WORKDIR = os.path.join(home, ".local/shark_tank/") def create_hash(file_name): @@ -237,5 +240,5 @@ def is_valid_file(arg): git_hash = sp.getoutput("git log -1 --format='%h'") + "/" print("uploading files to gs://shark_tank/" + git_hash) os.system( - "gsutil cp -r ./gen_shark_tank/* gs://shark_tank/" + git_hash + "gsutil cp -r ~/.local/shark_tank/* gs://shark_tank/" + git_hash ) diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py index 5afb1fb7a3..fc1ad6f296 100644 --- a/shark/shark_downloader.py +++ b/shark/shark_downloader.py @@ -110,7 +110,9 @@ def gs_download_model(): np.load(os.path.join(model_dir, "upstream_hash.npy")) ) if local_hash != upstream_hash: - gs_download_model() + print( + "Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended." + ) model_dir = os.path.join(WORKDIR, model_dir_name) with open( @@ -167,7 +169,9 @@ def gs_download_model(): np.load(os.path.join(model_dir, "upstream_hash.npy")) ) if local_hash != upstream_hash: - gs_download_model() + print( + "Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended." + ) model_dir = os.path.join(WORKDIR, model_dir_name) with open( @@ -221,7 +225,9 @@ def gs_download_model(): np.load(os.path.join(model_dir, "upstream_hash.npy")) ) if local_hash != upstream_hash: - gs_download_model() + print( + "Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended." + ) model_dir = os.path.join(WORKDIR, model_dir_name) with open(os.path.join(model_dir, model_name + "_tf.mlir")) as f: