-
Notifications
You must be signed in to change notification settings - Fork 47
feat: implement dynamic batch sizing for ML validation #816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0bf606e
e3ad6d8
c812910
1d4bd69
05aa20d
928f751
b75d492
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import gc | ||
| import logging | ||
| import psutil | ||
| from typing import Optional | ||
| from dataclasses import dataclass | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
@dataclass
class MemoryInfo:
    """Point-in-time snapshot of system and current-process memory usage.

    All sizes are in MiB (bytes / 1024 / 1024), as produced by
    MemoryMonitor.get_memory_info().
    """
    # system-wide figures taken from psutil.virtual_memory()
    total_mb: float
    available_mb: float
    used_mb: float
    # system memory usage percentage (psutil's 'percent' field, 0..100)
    percent_used: float
    # resident set size (RSS) of the current process
    process_mb: float
|
|
||
class MemoryMonitor:
    """Reports system/process memory via psutil, keeping a configurable reserve.

    NOTE(review): Python's gc cannot reclaim memory allocated natively (e.g. by
    onnxruntime), so the figures here are advisory only.
    """

    def __init__(self, safety_margin: float = 0.2):
        """
        Args:
            safety_margin: fraction of memory treated as off-limits (0.2 = 20%).
        """
        self.safety_margin = safety_margin
        self.process = psutil.Process()

    @staticmethod
    def _to_mb(size_bytes) -> float:
        # bytes -> MiB; two successive divisions keep float rounding
        # identical to the conventional `x / 1024 / 1024` form
        return size_bytes / 1024 / 1024

    def get_memory_info(self) -> MemoryInfo:
        """Return a MemoryInfo snapshot of system and current-process memory."""
        system = psutil.virtual_memory()
        own = self.process.memory_info()
        return MemoryInfo(total_mb=self._to_mb(system.total),
                          available_mb=self._to_mb(system.available),
                          used_mb=self._to_mb(system.used),
                          percent_used=system.percent,
                          process_mb=self._to_mb(own.rss))

    def get_available_memory_mb(self) -> float:
        """MB still usable after applying the safety margin and subtracting
        this process's own footprint; never negative."""
        snapshot = self.get_memory_info()
        headroom = snapshot.available_mb * (1 - self.safety_margin) - snapshot.process_mb
        return max(0, headroom)

    def is_memory_pressure_high(self) -> bool:
        """True when system usage crosses the (1 - safety_margin) * 100 percent mark."""
        threshold_percent = (1 - self.safety_margin) * 100
        return self.get_memory_info().percent_used > threshold_percent

    def force_garbage_collection(self) -> float:
        """Run gc.collect() and report MB of process memory released.

        Returns:
            Difference in RSS before/after collection; may be zero or negative
            because RSS can grow concurrently for unrelated reasons.
        """
        before_mb = self.get_memory_info().process_mb
        gc.collect()
        after_mb = self.get_memory_info().process_mb
        freed_mb = before_mb - after_mb
        if freed_mb > 0:
            logger.debug(f"Garbage collection freed {freed_mb:.2f} MB")
        return freed_mb
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,56 @@ class MlValidator: | |
    # NOTE(review): presumably a padding/terminator placeholder for model input — confirm
    ZERO_CHAR = '\x00'
    # applied for unknown characters
    FAKE_CHAR = '\x01'

    # Memory usage estimates per batch size (in MB)
    # NOTE(review): values look empirical; growth is roughly linear-plus-overhead,
    # not measured here — confirm against profiling data
    MEMORY_ESTIMATES = {
        1: 50,  # ~50MB for batch size 1
        2: 75,  # ~75MB for batch size 2
        4: 125,  # ~125MB for batch size 4
        8: 200,  # ~200MB for batch size 8
        16: 350,  # ~350MB for batch size 16
        32: 600,  # ~600MB for batch size 32
        64: 1000,  # ~1GB for batch size 64
        128: 1800,  # ~1.8GB for batch size 128
        256: 3200,  # ~3.2GB for batch size 256
        512: 6000  # ~6GB for batch size 512
    }
|
|
||
| @classmethod | ||
| def get_memory_estimate(cls, batch_size: int) -> int: | ||
| """Get estimated memory usage for a batch size in MB""" | ||
| if batch_size in cls.MEMORY_ESTIMATES: | ||
| return cls.MEMORY_ESTIMATES[batch_size] | ||
|
|
||
| # Find nearest estimate and interpolate | ||
| sizes = sorted(cls.MEMORY_ESTIMATES.keys()) | ||
| for size in sizes: | ||
| if batch_size <= size: | ||
| return cls.MEMORY_ESTIMATES[size] | ||
|
|
||
| # Extrapolate for very large batch sizes | ||
| largest_size = sizes[-1] | ||
| largest_memory = cls.MEMORY_ESTIMATES[largest_size] | ||
| ratio = batch_size / largest_size | ||
| return int(largest_memory * ratio) | ||
|
|
||
| @classmethod | ||
| def get_memory_info_text(cls) -> str: | ||
| """Get formatted memory information for help text""" | ||
| info_lines = ["Memory usage estimates for different batch sizes:"] | ||
| for batch_size, memory_mb in sorted(cls.MEMORY_ESTIMATES.items()): | ||
| if memory_mb < 1000: | ||
| info_lines.append(f" Batch size {batch_size:3d}: ~{memory_mb}MB") | ||
| else: | ||
| info_lines.append(f" Batch size {batch_size:3d}: ~{memory_mb//1000}GB") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. RAM uses x^2 measurement |
||
|
|
||
| info_lines.extend([ | ||
| "", | ||
| "Use smaller batch sizes if you encounter memory limitations.", | ||
| "Example: --ml_batch_size 4 for ~125MB memory usage" | ||
| ]) | ||
|
|
||
| return "\n".join(info_lines) | ||
|
|
||
| _dir_path = Path(__file__).parent | ||
|
|
||
|
|
@@ -245,30 +295,31 @@ def validate_groups(self, group_list: List[Tuple[CandidateKey, List[Candidate]]] | |
| Return: | ||
| Boolean numpy array with decision based on the threshold, | ||
| and numpy array with probability predicted by the model | ||
|
|
||
| """ | ||
| line_input_list = [] | ||
| variable_input_list = [] | ||
| value_input_list = [] | ||
| features_list = [] | ||
| probability: np.ndarray = np.zeros(len(group_list), dtype=np.float32) | ||
| head = tail = 0 | ||
|
|
||
| for _group_key, candidates in group_list: | ||
| line_input, variable_input, value_input, feature_array = self.get_group_features(candidates) | ||
| line_input_list.append(line_input) | ||
| variable_input_list.append(variable_input) | ||
| value_input_list.append(value_input) | ||
| features_list.append(feature_array) | ||
| tail += 1 | ||
|
|
||
| if 0 == tail % batch_size: | ||
| # use the approach to reduce memory consumption for huge candidates list | ||
| probability[head:tail] = self._batch_call_model(line_input_list, variable_input_list, value_input_list, | ||
| features_list) | ||
| head = tail | ||
| line_input_list.clear() | ||
| variable_input_list.clear() | ||
| value_input_list.clear() | ||
| features_list.clear() | ||
|
|
||
| if head != tail: | ||
| probability[head:tail] = self._batch_call_model(line_input_list, variable_input_list, value_input_list, | ||
| features_list) | ||
|
|
@@ -277,5 +328,4 @@ def validate_groups(self, group_list: List[Tuple[CandidateKey, List[Candidate]]] | |
| for i, decision in enumerate(is_cred): | ||
| logger.debug("ML decision: %s with prediction: %s for value: %s", decision, probability[i], | ||
| group_list[i][0]) | ||
| # apply cast to float to avoid json export issue | ||
| return is_cred, probability.astype(float) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| import unittest | ||
| from unittest.mock import Mock, patch | ||
| import psutil | ||
|
|
||
| from credsweeper.ml_model.memory_monitor import MemoryMonitor, MemoryInfo | ||
|
|
||
|
|
||
class TestMemoryMonitor(unittest.TestCase):
    """Unit tests for MemoryMonitor with psutil fully mocked.

    The monitor is constructed INSIDE each test, after the @patch decorators
    are active: a monitor built in setUp() would capture a real
    psutil.Process before the patch applies, so the mocked process values
    would never be used and assertions would compare against the real RSS.
    """

    GIB = 1024 * 1024 * 1024

    @staticmethod
    def _system_memory(percent: float = 50.0) -> Mock:
        """Build a psutil.virtual_memory() stand-in.

        All size fields must be numeric: a bare Mock attribute does not
        implement division, so get_memory_info() would raise TypeError
        on `total / 1024`.
        """
        memory = Mock()
        memory.total = 8 * TestMemoryMonitor.GIB
        memory.available = 4 * TestMemoryMonitor.GIB
        memory.used = 4 * TestMemoryMonitor.GIB
        memory.percent = percent
        return memory

    @patch('psutil.virtual_memory')
    @patch('psutil.Process')
    def test_get_memory_info(self, mock_process, mock_virtual_memory):
        mock_virtual_memory.return_value = self._system_memory()
        mock_process.return_value.memory_info.return_value.rss = self.GIB
        monitor = MemoryMonitor(safety_margin=0.2)

        info = monitor.get_memory_info()

        self.assertEqual(info.total_mb, 8192.0)
        self.assertEqual(info.available_mb, 4096.0)
        self.assertEqual(info.used_mb, 4096.0)
        self.assertEqual(info.percent_used, 50.0)
        self.assertEqual(info.process_mb, 1024.0)

    @patch('psutil.virtual_memory')
    @patch('psutil.Process')
    def test_get_available_memory_mb(self, mock_process, mock_virtual_memory):
        mock_virtual_memory.return_value = self._system_memory()
        mock_process.return_value.memory_info.return_value.rss = self.GIB
        monitor = MemoryMonitor(safety_margin=0.2)

        available = monitor.get_available_memory_mb()

        # 4096MB available * 0.8 safety factor, minus 1024MB own footprint
        expected = (4096.0 * 0.8) - 1024.0
        self.assertAlmostEqual(available, expected, places=2)

    @patch('psutil.virtual_memory')
    @patch('psutil.Process')
    def test_is_memory_pressure_high(self, mock_process, mock_virtual_memory):
        mock_process.return_value.memory_info.return_value.rss = self.GIB
        monitor = MemoryMonitor(safety_margin=0.2)

        # threshold is (1 - safety_margin) * 100 == 80%
        mock_virtual_memory.return_value = self._system_memory(percent=85.0)
        self.assertTrue(monitor.is_memory_pressure_high())

        mock_virtual_memory.return_value = self._system_memory(percent=75.0)
        self.assertFalse(monitor.is_memory_pressure_high())

    @patch('gc.collect')
    @patch('psutil.Process')
    @patch.object(MemoryMonitor, 'get_memory_info')
    def test_force_garbage_collection(self, mock_get_memory_info, mock_process, mock_gc_collect):
        mock_gc_collect.return_value = None
        before = MemoryInfo(8192, 4096, 4096, 50, 1024)
        after = MemoryInfo(8192, 4096, 4096, 50, 512)
        mock_get_memory_info.side_effect = [before, after]
        monitor = MemoryMonitor(safety_margin=0.2)

        freed = monitor.force_garbage_collection()

        self.assertEqual(freed, 512.0)
        mock_gc_collect.assert_called_once()


if __name__ == '__main__':
    unittest.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use a constant for the minimal memory required by the default batch size. That constant must also be used in the tests that check the memory limitation. Try subprocess + resource to enforce the limit (those tests may be skipped on Windows).