Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 61 additions & 15 deletions deepmd/pt/utils/auto_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@


class AutoBatchSize(AutoBatchSizeBase):
"""Auto batch size.

Parameters
----------
initial_batch_size : int, default: 1024
initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
is not set
factor : float, default: 2.
increased factor

"""

def __init__(
self,
initial_batch_size: int = 1024,
factor: float = 2.0,
):
super().__init__(
initial_batch_size=initial_batch_size,
factor=factor,
)

def is_gpu_available(self) -> bool:
"""Check if GPU is available.

Expand Down Expand Up @@ -78,26 +100,50 @@ def execute_with_batch_size(
)

index = 0
results = []
results = None
returned_dict = None
while index < total_size:
n_batch, result = self.execute(execute_with_batch_size, index, natoms)
if not isinstance(result, tuple):
result = (result,)
returned_dict = (
isinstance(result, dict) if returned_dict is None else returned_dict
)
if not returned_dict:
result = (result,) if not isinstance(result, tuple) else result
index += n_batch
if n_batch:
for rr in result:
rr.reshape((n_batch, -1))
results.append(result)
r_list = []
for r in zip(*results):

def append_to_list(res_list, res):
if n_batch:
res_list.append(res)
return res_list

if not returned_dict:
results = [] if results is None else results
results = append_to_list(results, result)
else:
results = (
{kk: [] for kk in result.keys()} if results is None else results
)
results = {
kk: append_to_list(results[kk], result[kk]) for kk in result.keys()
}
assert results is not None
assert returned_dict is not None

def concate_result(r):
if isinstance(r[0], np.ndarray):
r_list.append(np.concatenate(r, axis=0))
ret = np.concatenate(r, axis=0)
elif isinstance(r[0], torch.Tensor):
r_list.append(torch.cat(r, dim=0))
ret = torch.cat(r, dim=0)
else:
raise RuntimeError(f"Unexpected result type {type(r[0])}")
r = tuple(r_list)
if len(r) == 1:
# avoid returning tuple if callable doesn't return tuple
r = r[0]
return ret

if not returned_dict:
r_list = [concate_result(r) for r in zip(*results)]
r = tuple(r_list)
if len(r) == 1:
# avoid returning tuple if callable doesn't return tuple
r = r[0]
else:
r = {kk: concate_result(vv) for kk, vv in results.items()}
return r
37 changes: 37 additions & 0 deletions source/tests/pt/test_auto_batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np

from deepmd.pt.utils.auto_batch_size import (
AutoBatchSize,
)


class TestAutoBatchSize(unittest.TestCase):
def test_execute_all(self):
dd0 = np.zeros((10000, 2, 1, 3, 4))
dd1 = np.ones((10000, 2, 1, 3, 4))
auto_batch_size = AutoBatchSize(256, 2.0)

def func(dd1):
return np.zeros_like(dd1), np.ones_like(dd1)

dd2 = auto_batch_size.execute_all(func, 10000, 2, dd1)
np.testing.assert_equal(dd0, dd2[0])
np.testing.assert_equal(dd1, dd2[1])

def test_execute_all_dict(self):
dd0 = np.zeros((10000, 2, 1, 3, 4))
dd1 = np.ones((10000, 2, 1, 3, 4))
auto_batch_size = AutoBatchSize(256, 2.0)

def func(dd1):
return {
"foo": np.zeros_like(dd1),
"bar": np.ones_like(dd1),
}

dd2 = auto_batch_size.execute_all(func, 10000, 2, dd1)
np.testing.assert_equal(dd0, dd2["foo"])
np.testing.assert_equal(dd1, dd2["bar"])