Linter fixes
mawong-amd committed May 7, 2024
1 parent 517fe3d commit ec6df1c
Showing 10 changed files with 529 additions and 361 deletions.
73 changes: 43 additions & 30 deletions gradlib/gemm_runner.py
@@ -1,62 +1,75 @@
-import torch
-import rocsolidxgemm
-import hipbsolidxgemm
-import numpy as np
-import torch.nn.functional as F
 import sys
+
+import hipbsolidxgemm
 import pandas as pd
-import timeit
+import rocsolidxgemm
+import torch
+import torch.nn.functional as F
 
 rocsolidxgemm.rocb_create_extension()
 hipbsolidxgemm.hipb_create_extension()
 
+
 class TunedGemm:
-    def __init__(self,tuned_csv_file):
-        self.bestsols = pd.read_csv(tuned_csv_file,index_col=[0])
+
+    def __init__(self, tuned_csv_file):
+        self.bestsols = pd.read_csv(tuned_csv_file, index_col=[0])
         self.create_ds()
+
     def create_ds(self):
         df = self.bestsols
         solds = {}
         for i in range(len(df)):
             ds = df.iloc[i]
-            key = (ds['M'],ds['N'],ds['K'])
-            if ds['libtype']=='hipblaslt': soltype = 1
-            elif ds['libtype']=='rocblas': soltype = 2
-            solds[key] = (soltype,int(ds['solidx']))
+            key = (ds['M'], ds['N'], ds['K'])
+            if ds['libtype'] == 'hipblaslt':
+                soltype = 1
+            elif ds['libtype'] == 'rocblas':
+                soltype = 2
+            solds[key] = (soltype, int(ds['solidx']))
         #print(solds)
         self.solids = solds
-    def query_sol(self,m,n,k):
-        return self.solids.get((m,n,k),(0,0))
-    def mm(self,inp,weights):
-        soltype,solidx = self.query_sol(m=weights.shape[0],n=inp.shape[0],k=inp.shape[1])
-        if soltype==1:
-            out = hipbsolidxgemm.hipb_mm(inp,weights.t(),solidx)
-        elif soltype==2:
-            out = rocsolidxgemm.rocb_mm(inp,weights.t(),solidx)
+
+    def query_sol(self, m, n, k):
+        return self.solids.get((m, n, k), (0, 0))
+
+    def mm(self, inp, weights):
+        soltype, solidx = self.query_sol(m=weights.shape[0],
+                                         n=inp.shape[0],
+                                         k=inp.shape[1])
+        if soltype == 1:
+            out = hipbsolidxgemm.hipb_mm(inp, weights.t(), solidx)
+        elif soltype == 2:
+            out = rocsolidxgemm.rocb_mm(inp, weights.t(), solidx)
         else:
-            out = F.linear(inp,weights)
+            out = F.linear(inp, weights)
         return out
 
     def run_all_tuned_sols(self):
         for i in range(len(self.bestsols)):
             ds = self.bestsols.iloc[i]
             print('>>> Running tuned solution')
             print(ds)
-            inp = torch.randn((ds['N'], ds['K']), dtype=get_dtype(ds['dtype']), device='cuda')
-            weights = torch.randn((ds['M'], ds['K']), dtype=get_dtype(ds['dtype']), device='cuda')
-            self.mm(inp,weights)
+            inp = torch.randn((ds['N'], ds['K']),
+                              dtype=get_dtype(ds['dtype']),
+                              device='cuda')
+            weights = torch.randn((ds['M'], ds['K']),
+                                  dtype=get_dtype(ds['dtype']),
+                                  device='cuda')
+            self.mm(inp, weights)
 
+
 def get_dtype(dtype_csv):
-    if dtype_csv=='torch.float16':
+    if dtype_csv == 'torch.float16':
         dtype = torch.float16
-    elif dtype_csv=='torch.bfloat16':
+    elif dtype_csv == 'torch.bfloat16':
         dtype = torch.bfloat16
-    elif dtype_csv=='torch.float32':
+    elif dtype_csv == 'torch.float32':
         dtype = torch.float32
     return dtype
 
+
 if __name__ == '__main__':
-    tgemm = TunedGemm(sys.argv[1]) #csv file with tuned sols goes in argv[1]
+    tgemm = TunedGemm(sys.argv[1])  #csv file with tuned sols goes in argv[1]
    print(tgemm.bestsols)
    tgemm.run_all_tuned_sols()
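
Note: for orientation, here is a small self-contained sketch (not part of this commit) of the lookup table that TunedGemm builds from the tuned CSV and queries in mm(). The table contents, shapes, and solution index below are invented for illustration, and the ROCm-specific hipb_mm/rocb_mm calls are replaced by the F.linear fallback so the snippet runs on any PyTorch install:

# Illustrative only: mimics TunedGemm.create_ds()/query_sol() with made-up data.
import pandas as pd
import torch
import torch.nn.functional as F

bestsols = pd.DataFrame({
    'M': [15360], 'N': [1], 'K': [5120],
    'dtype': ['torch.float16'], 'libtype': ['hipblaslt'], 'solidx': [4242],
})
# Same structure create_ds() builds: (M, N, K) -> (soltype, solidx),
# where soltype 1 means hipBLASLt and 2 means rocBLAS.
solds = {(r['M'], r['N'], r['K']):
         (1 if r['libtype'] == 'hipblaslt' else 2, int(r['solidx']))
         for _, r in bestsols.iterrows()}

inp = torch.randn(1, 5120)          # (N, K) activations
weights = torch.randn(15360, 5120)  # (M, K) weights
soltype, solidx = solds.get(
    (weights.shape[0], inp.shape[0], inp.shape[1]), (0, 0))
print(soltype, solidx)  # 1 4242 -> mm() would dispatch to hipbsolidxgemm.hipb_mm
out = F.linear(inp, weights)  # the soltype == 0 fallback path, used here
print(out.shape)  # torch.Size([1, 15360])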


101 changes: 65 additions & 36 deletions gradlib/gemm_tuner.py
@@ -1,33 +1,32 @@
-import torch
-import os
 import argparse
-from gradlib.GemmTuner import GemmTuner
-import rocsolidxgemm
-import hipbsolidxgemm
-import numpy as np
-import torch.nn.functional as F
-import sys
-import pandas as pd
 import json
-import random
+import os
 from pathlib import Path
+
+import hipbsolidxgemm
+import pandas as pd
+import rocsolidxgemm
+import torch
+
+from gradlib.GemmTuner import GemmTuner
 
 rocsolidxgemm.rocb_create_extension()
 hipbsolidxgemm.hipb_create_extension()
 
 '''
 {'architectures': ['LlamaForCausalLM'], 'bos_token_id': 1, 'eos_token_id': 2, 'hidden_act': 'silu', 'hidden_size': 5120, 'initializer_range': 0.02,
 'intermediate_size': 13824, 'max_position_embeddings': 2048, 'model_type': 'llama', 'num_attention_heads': 40, 'num_hidden_layers': 40, 'num_key_value_heads': 40,
 'pretraining_tp': 1, 'rms_norm_eps': 1e-05, 'rope_scaling': None, 'tie_word_embeddings': False, 'torch_dtype': 'float16', 'transformers_version': '4.33.0.dev0', 'use_cache': True, 'vocab_size': 32000}
 '''
 
 def generate_mk_sets(model_dir, tp=1):
-    f = open(f'{model_dir}/config.json')
-    data = json.load(f)
-    hidden_size = data['hidden_size']
-    intermediate_size = data['intermediate_size']
-    total_num_heads = data['num_attention_heads']
-    total_num_kv_heads = data['num_key_value_heads']
-    head_dim = hidden_size // total_num_heads
-    return [((total_num_heads + (2*total_num_kv_heads)) * head_dim // tp, hidden_size), (hidden_size, hidden_size // tp), (intermediate_size *2 // tp, hidden_size), (hidden_size, intermediate_size // tp) ], hidden_size
+    with open(f'{model_dir}/config.json') as f:
+        data = json.load(f)
+    hidden_size = data['hidden_size']
+    intermediate_size = data['intermediate_size']
+    total_num_heads = data['num_attention_heads']
+    total_num_kv_heads = data['num_key_value_heads']
+    head_dim = hidden_size // total_num_heads
+    return [((total_num_heads + (2 * total_num_kv_heads)) * head_dim // tp,
+             hidden_size), (hidden_size, hidden_size // tp),
+            (intermediate_size * 2 // tp, hidden_size),
+            (hidden_size, intermediate_size // tp)], hidden_size
 
+
 def get_dtype(dtype_str):
     dtype = torch.float16
@@ -38,28 +37,55 @@ def get_dtype(dtype_str):
     elif dtype_str == 'f16':
         dtype = torch.float16
     else:
-        print('>>> Warning! Invalid dtype', dtype_str, 'using default dtype f16')
+        print('>>> Warning! Invalid dtype', dtype_str,
+              'using default dtype f16')
     return dtype
 
 
 def list_of_ints(arg):
     return list(map(int, arg.split(',')))
 
 
 def load_input_gemms(input_file):
     if Path(input_file).is_file():
-        return
+        return
 
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_dir", type=str, default=os.getenv('GTUNE_MODEL', ""), help="Enter the location of your model directory")
-    parser.add_argument("--tuned_file", type=str, default=os.getenv('GTUNE_TUNED', "tuned.csv"), help="output file for tuned gemm solutions")
-    parser.add_argument("--input_file", type=str, default=os.getenv('GTUNE_INPUT', None), help="list of gemms to tune for, mutually exclusive with model_dir")
-    parser.add_argument("--tp", type=int, default=os.getenv('GTUNE_TP', 1), help="Tensor parallelism to be used.")
-    parser.add_argument("--dtype", type=str, default='f16', help="dtype f32 f16 bf16")
-    parser.add_argument("--rocblas-decode", action="store_true", default=False, help="forces rocblas solution on decode N=1")
-    parser.add_argument("--batch_size", type=int, default=os.getenv('GTUNE_BATCH_SIZE', 1), help="Batch size to tune for")
-    parser.add_argument("--nsets", type=list_of_ints, default=[1, 512, 1024, 2048, 3072, 4096, 8192, 16384], help="N sizes to tune for: 1,128,2048")
+    parser.add_argument("--model_dir",
+                        type=str,
+                        default=os.getenv('GTUNE_MODEL', ""),
+                        help="Enter the location of your model directory")
+    parser.add_argument("--tuned_file",
+                        type=str,
+                        default=os.getenv('GTUNE_TUNED', "tuned.csv"),
+                        help="output file for tuned gemm solutions")
+    parser.add_argument(
+        "--input_file",
+        type=str,
+        default=os.getenv('GTUNE_INPUT', None),
+        help="list of gemms to tune for, mutually exclusive with model_dir")
+    parser.add_argument("--tp",
+                        type=int,
+                        default=os.getenv('GTUNE_TP', 1),
+                        help="Tensor parallelism to be used.")
+    parser.add_argument("--dtype",
+                        type=str,
+                        default='f16',
+                        help="dtype f32 f16 bf16")
+    parser.add_argument("--rocblas-decode",
+                        action="store_true",
+                        default=False,
+                        help="forces rocblas solution on decode N=1")
+    parser.add_argument("--batch_size",
+                        type=int,
+                        default=os.getenv('GTUNE_BATCH_SIZE', 1),
+                        help="Batch size to tune for")
+    parser.add_argument("--nsets",
+                        type=list_of_ints,
+                        default=[1, 512, 1024, 2048, 3072, 4096, 8192, 16384],
+                        help="N sizes to tune for: 1,128,2048")
     args = parser.parse_args()
 
     dtype = get_dtype(args.dtype)
@@ -74,16 +100,19 @@ def load_input_gemms(input_file):
         shapes = pd.read_csv(args.input_file)
         for i in range(len(shapes)):
             ds = shapes.iloc[i]
-            gtuner.add_gemm(ds['M'],ds['N'],ds['K'])
+            gtuner.add_gemm(ds['M'], ds['N'], ds['K'])
     else:
         if not args.model_dir:
             print(">>> Warning! NO MODEL SPECIFIED. Tuning for LL2 13B TP1")
             #LL2 13B sizes
-            mksets = [(15360, 5120), (5120, 5120), (27648, 5120), (5120, 13824)]
+            mksets = [(15360, 5120), (5120, 5120), (27648, 5120),
+                      (5120, 13824)]
             gtuner.add_gemm(m=32000, n=1, k=5120) # logits gemm
         else:
             mksets, hidden_size = generate_mk_sets(args.model_dir, args.tp)
-        gtuner.add_gemm(m=32000//args.tp, n=1 * args.batch_size, k=hidden_size) #TODO: Handle cases where vocab_size is not divisible by tp
+        gtuner.add_gemm(
+            m=32000 // args.tp, n=1 * args.batch_size, k=hidden_size
+        )  #TODO: Handle cases where vocab_size is not divisible by tp
 
     for n in sorted(nsets):
         for m, k in mksets:
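
Note: as a quick sanity check on generate_mk_sets() (again illustrative, not taken from the commit), plugging in the Llama-2 13B config.json values quoted in the module docstring reproduces the hard-coded "LL2 13B sizes" fallback above; the projection labels in the comments are my reading of the shapes, not annotations from the source:

# Hand-evaluate generate_mk_sets() for hidden_size=5120, intermediate_size=13824,
# 40 attention heads, 40 KV heads, tp=1 (values from the docstring above).
hidden_size, intermediate_size = 5120, 13824
total_num_heads = total_num_kv_heads = 40
tp = 1
head_dim = hidden_size // total_num_heads  # 5120 // 40 = 128

mksets = [
    ((total_num_heads + 2 * total_num_kv_heads) * head_dim // tp, hidden_size),  # fused QKV projection
    (hidden_size, hidden_size // tp),            # attention output projection
    (intermediate_size * 2 // tp, hidden_size),  # fused gate/up projection
    (hidden_size, intermediate_size // tp),      # MLP down projection
]
print(mksets)
# [(15360, 5120), (5120, 5120), (27648, 5120), (5120, 13824)]

With a model directory, these (M, K) pairs are crossed with the --nsets N values (plus the extra m=32000 logits GEMM) to enumerate the GEMM shapes handed to GemmTuner; a typical invocation would presumably look something like: python gemm_tuner.py --model_dir <model_dir> --tp 1 --tuned_file tuned.csv, using the flags defined above.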
