3535WORKDIR = os .path .join (home , ".local/shark_tank/" )
3636print (WORKDIR )
3737
38+
3839# Checks whether the directory and files exists.
3940def check_dir_exists (model_name , frontend = "torch" , dynamic = "" ):
4041 model_dir = os .path .join (WORKDIR , model_name )
4142
4243 # Remove the _tf keyword from end.
4344 if frontend in ["tf" , "tensorflow" ]:
4445 model_name = model_name [:- 3 ]
46+ elif frontend in ["tflite" ]:
47+ model_name = model_name [:- 7 ]
48+ elif frontend in ["torch" , "pytorch" ]:
49+ model_name = model_name [:- 6 ]
4550
4651 if os .path .isdir (model_dir ):
4752 if (
4853 os .path .isfile (
49- os .path .join (model_dir , model_name + dynamic + ".mlir" )
54+ os .path .join (
55+ model_dir ,
56+ model_name + dynamic + "_" + str (frontend ) + ".mlir" ,
57+ )
5058 )
5159 and os .path .isfile (os .path .join (model_dir , "function_name.npy" ))
5260 and os .path .isfile (os .path .join (model_dir , "inputs.npz" ))
@@ -65,27 +73,28 @@ def download_torch_model(model_name, dynamic=False):
6573 model_name = model_name .replace ("/" , "_" )
6674 dyn_str = "_dynamic" if dynamic else ""
6775 os .makedirs (WORKDIR , exist_ok = True )
76+ model_dir_name = model_name + "_torch"
6877
6978 def gs_download_model ():
7079 gs_command = (
7180 'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
7281 + "/"
73- + model_name
82+ + model_dir_name
7483 + " "
7584 + WORKDIR
7685 )
7786 if os .system (gs_command ) != 0 :
7887 raise Exception ("model not present in the tank. Contact Nod Admin" )
7988
80- if not check_dir_exists (model_name , dyn_str ):
89+ if not check_dir_exists (model_dir_name , frontend = "torch" , dynamic = dyn_str ):
8190 gs_download_model ()
8291 else :
83- model_dir = os .path .join (WORKDIR , model_name )
92+ model_dir = os .path .join (WORKDIR , model_dir_name )
8493 local_hash = str (np .load (os .path .join (model_dir , "hash.npy" )))
8594 gs_hash = (
8695 'gsutil -o "GSUtil:parallel_process_count=1" cp gs://shark_tank'
8796 + "/"
88- + model_name
97+ + model_dir_name
8998 + "/hash.npy"
9099 + " "
91100 + os .path .join (model_dir , "upstream_hash.npy" )
@@ -98,8 +107,10 @@ def gs_download_model():
98107 if local_hash != upstream_hash :
99108 gs_download_model ()
100109
101- model_dir = os .path .join (WORKDIR , model_name )
102- with open (os .path .join (model_dir , model_name + dyn_str + ".mlir" )) as f :
110+ model_dir = os .path .join (WORKDIR , model_dir_name )
111+ with open (
112+ os .path .join (model_dir , model_name + dyn_str + "_torch.mlir" )
113+ ) as f :
103114 mlir_file = f .read ()
104115
105116 function_name = str (np .load (os .path .join (model_dir , "function_name.npy" )))
@@ -115,18 +126,21 @@ def gs_download_model():
115126def download_tflite_model (model_name , dynamic = False ):
116127 dyn_str = "_dynamic" if dynamic else ""
117128 os .makedirs (WORKDIR , exist_ok = True )
118- if not check_dir_exists (model_name , dyn_str ):
129+ model_dir_name = model_name + "_tflite"
130+ if not check_dir_exists (
131+ model_dir_name , frontend = "tflite" , dynamic = dyn_str
132+ ):
119133 gs_command = (
120134 'gsutil -o "GSUtil:parallel_process_count=1" cp -r gs://shark_tank'
121135 + "/"
122- + model_name
136+ + model_dir_name
123137 + " "
124138 + WORKDIR
125139 )
126140 if os .system (gs_command ) != 0 :
127141 raise Exception ("model not present in the tank. Contact Nod Admin" )
128142
129- model_dir = os .path .join (WORKDIR , model_name )
143+ model_dir = os .path .join (WORKDIR , model_dir_name )
130144 with open (
131145 os .path .join (model_dir , model_name + dyn_str + "_tflite.mlir" )
132146 ) as f :
0 commit comments