Skip to content

Commit

Permalink
Fix local artifact recognition and usage by SHARK downloader. (nod-ai…
Browse files Browse the repository at this point in the history
…#286)

* Fix local artifact recognition and usage by SHARK downloader.

* Update generate_sharktank.py

* Update generate_sharktank.py
  • Loading branch information
monorimet authored Aug 24, 2022
1 parent f79a6bf commit 1485777
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
15 changes: 9 additions & 6 deletions generate_sharktank.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
"""SHARK Tank"""
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
# will generate local shark tank folder like this:
# /SHARK
# /gen_shark_tank
# /albert_lite_base
# /...model_name...
# HOME
# /.local
# /shark_tank
# /albert_lite_base
# /...model_name...
#

import os
Expand All @@ -16,6 +17,7 @@
import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path

visible_default = tf.config.list_physical_devices("GPU")
try:
Expand All @@ -28,7 +30,8 @@
pass

# All generated models and metadata will be saved under this directory.
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
home = str(Path.home())
WORKDIR = os.path.join(home, ".local/shark_tank/")


def create_hash(file_name):
Expand Down Expand Up @@ -237,5 +240,5 @@ def is_valid_file(arg):
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
print("uploading files to gs://shark_tank/" + git_hash)
os.system(
"gsutil cp -r ./gen_shark_tank/* gs://shark_tank/" + git_hash
"gsutil cp -r ~/.local/shark_tank/* gs://shark_tank/" + git_hash
)
12 changes: 9 additions & 3 deletions shark/shark_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def gs_download_model():
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
gs_download_model()
print(
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
)

model_dir = os.path.join(WORKDIR, model_dir_name)
with open(
Expand Down Expand Up @@ -167,7 +169,9 @@ def gs_download_model():
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
gs_download_model()
print(
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
)

model_dir = os.path.join(WORKDIR, model_dir_name)
with open(
Expand Down Expand Up @@ -221,7 +225,9 @@ def gs_download_model():
np.load(os.path.join(model_dir, "upstream_hash.npy"))
)
if local_hash != upstream_hash:
gs_download_model()
print(
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
)

model_dir = os.path.join(WORKDIR, model_dir_name)
with open(os.path.join(model_dir, model_name + "_tf.mlir")) as f:
Expand Down

0 comments on commit 1485777

Please sign in to comment.