Skip to content

Commit db1f09e

Browse files
unit tests
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 3dbdd9b commit db1f09e

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -749,17 +749,19 @@ def _accumulate_mean(
749749
return (prev_sum + sum_added) / new_count, new_count
750750

751751

752-
def get_lowest_common_parent(names: List[str], module: Module) -> str:
752+
def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
753753
"""
754754
Given a list of names, returns the lowest-scope common parent,
755755
excluding parents of type ModuleList, which don't seem to play
756756
nicely with hooks.
757-
Slight alteration from os.path.commonprefix
757+
Returns name of parent and pointer to parent module
758+
759+
Implementation is a small alteration of os.path.commonprefix
758760
https://docs.python.org/3/library/os.path.html#os.path.commonprefix
759761
"""
760762
s1 = min(names)
761763
s2 = max(names)
762-
parent_name = s1
764+
parent_name = ""
763765
for i, c in enumerate(s1):
764766
if c != s2[i]:
765767
parent_name = s1[:i].rstrip(".")

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pydantic import ValidationError
55

66
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
7+
from llmcompressor.modifiers.awq.base import get_lowest_common_parent
78
from llmcompressor.modifiers.factory import ModifierFactory
89
from tests.llmcompressor.modifiers.conf import setup_modifier_factory
910

@@ -172,3 +173,55 @@ def test_validate():
172173
),
173174
}
174175
)
176+
177+
178+
@pytest.mark.unit
179+
def test_get_lowest_common_parent():
180+
mlp = torch.nn.ModuleDict(
181+
{
182+
"experts": torch.nn.ModuleList(
183+
[
184+
torch.nn.ModuleDict(
185+
{
186+
"gate_proj": torch.nn.Linear(4, 2),
187+
"down_proj": torch.nn.Linear(4, 2),
188+
}
189+
)
190+
for _ in range(10)
191+
]
192+
)
193+
}
194+
)
195+
self_attn = torch.nn.ModuleDict(
196+
{
197+
"q_proj": torch.nn.Linear(4, 2),
198+
"k_proj": torch.nn.Linear(4, 2),
199+
"v_proj": torch.nn.Linear(4, 2),
200+
"o_proj": torch.nn.Linear(4, 4),
201+
}
202+
)
203+
model = torch.nn.ModuleDict(
204+
{
205+
"decoder": torch.nn.ModuleDict(
206+
{
207+
"self_attn": self_attn,
208+
"mlp": mlp,
209+
}
210+
)
211+
}
212+
)
213+
214+
parent_name, parent = get_lowest_common_parent(
215+
["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model
216+
)
217+
assert parent_name == "decoder.mlp" and parent == mlp
218+
219+
parent_name, parent = get_lowest_common_parent(
220+
["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model
221+
)
222+
assert parent_name == "decoder.self_attn" and parent == self_attn
223+
224+
parent_name, parent = get_lowest_common_parent(
225+
["decoder.mlp.experts.1.gate_proj", "decoder.self_attn.v_proj"], model
226+
)
227+
assert parent_name == "decoder" and parent == model["decoder"]

0 commit comments

Comments
 (0)