Skip to content

Commit

Permalink
fix hardcoded filenames in convert_awq_to_bin.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ankan-ban committed Sep 5, 2023
1 parent 30d643b commit a71e4c2
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions convert_awq_to_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,20 @@
# Import os module
import os

# Load the file
data = torch.load("llama2-7b-awq.pt")
# Import sys module
import sys

# Create a sub-directory named 'bin_wt' if it does not exist
os.makedirs('bin_wt', exist_ok=True)
# Get the filename from the command line argument
filename = sys.argv[1]

# Get the directory name from the command line argument
dirname = sys.argv[2]

# Load the data from the filename
data = torch.load(filename)

# Create a sub-directory with the given name if it does not exist
os.makedirs(dirname, exist_ok=True)

# If the data is a dictionary, iterate over its keys and values
if isinstance(data, dict):
Expand All @@ -17,6 +26,6 @@
# If the value is a tensor, check its shape and dtype
if isinstance(value, torch.Tensor):
print(value.shape, value.dtype)
# Dump the tensor to a binary file with the same name as the key
with open(os.path.join('bin_wt', key + '.bin'), 'wb') as f:
# Dump the tensor to a binary file with the same name as the key in the given directory
with open(os.path.join(dirname, key + '.bin'), 'wb') as f:
f.write(value.numpy().tobytes())

0 comments on commit a71e4c2

Please sign in to comment.