Skip to content

Commit 1082fbf

Browse files
authored
[Doc] Move Torch Tensors to GPU (#286)
1 parent ec61a8e commit 1082fbf

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

docs/QuickStart.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ model = bitblas.Linear(
151151
)
152152

153153
# Create an integer weight tensor
154-
intweight = torch.randint(-7, 7, (1024, 1024), dtype=torch.int8)
154+
intweight = torch.randint(-7, 7, (1024, 1024), dtype=torch.int8).cuda()
155155

156156
# Load and transform weights into the BitBLAS linear module
157157
model.load_and_transform_weight(intweight)
@@ -166,7 +166,7 @@ model.load_state_dict(torch.load("./model.pth"))
166166
model.eval()
167167

168168
# Create a dummy input tensor
169-
dummpy_input = torch.randn(1, 1024, dtype=torch.float16)
169+
dummpy_input = torch.randn(1, 1024, dtype=torch.float16).cuda()
170170

171171
# Perform inference
172172
output = model(dummpy_input)

0 commit comments

Comments
 (0)