Skip to content

Commit

Permalink
print the estimated tree model size
Browse files Browse the repository at this point in the history
  • Loading branch information
ericliu8168 committed Jul 31, 2024
1 parent 2982ed3 commit fecc11d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
33 changes: 33 additions & 0 deletions libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sklearn.cluster
import sklearn.preprocessing
from tqdm import tqdm
import psutil

from . import linear

Expand Down Expand Up @@ -135,13 +136,25 @@ def train_tree(
root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax)

num_nodes = 0
label_feature_used = ((y.T != 0) * (x != 0)).tocsr()

def count(node):
nonlocal num_nodes
num_nodes += 1
node.num_nnz_feat = np.count_nonzero(label_feature_used[node.label_map,:].sum(axis=0) != 0)

root.dfs(count)

# Calculate the total memory (excluding swap) on the local machine
total_memory = psutil.virtual_memory().total
print(f'{total_memory / (1024**3):.3f} GB')

model_size = get_estimated_model_size(root, num_nodes)
print(f'*** model_size: {model_size / (1024**3):.3f} GB')

if (total_memory <= model_size):
raise MemoryError(f'Not enough memory to train the model. model_size: {model_size / (1024**3):.3f} GB')

pbar = tqdm(total=num_nodes, disable=not verbose)

def visit(node):
Expand Down Expand Up @@ -195,6 +208,26 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
return Node(label_map=label_map, children=children)


def get_estimated_model_size(root, num_nodes):
num_nnz_feat, num_branches = np.zeros(num_nodes), np.zeros(num_nodes)
num_nodes = 0
def collect_stat(node: Node):
nonlocal num_nodes
num_nnz_feat[num_nodes] = node.num_nnz_feat

if node.isLeaf():
num_branches[num_nodes] = len(node.label_map)
else:
num_branches[num_nodes] = len(node.children)

num_nodes += 1

root.dfs(collect_stat)

# 16 is because when storing sparse matrices, indices (int64) require 8 bytes and floats require 8 bytes
return np.dot(num_nnz_feat, num_branches) * 16


def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: Node):
"""If node is internal, computes the metalabels representing each child and trains
on the metalabels. Otherwise, train on y.
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ PyYAML
scikit-learn
scipy
tqdm
psutil

0 comments on commit fecc11d

Please sign in to comment.