77
88# pyre-strict
99
10+ import copy
1011import random
1112from dataclasses import dataclass
1213from typing import Any , cast , Dict , List , Optional , Tuple , Type , Union
@@ -239,10 +240,16 @@ def _validate_pooling_factor(
239240 else None
240241 )
241242
242- global_float = torch .rand (
243- (batch_size * world_size , num_float_features ), device = device
244- )
245- global_label = torch .rand (batch_size * world_size , device = device )
243+ if randomize_indices :
244+ global_float = torch .rand (
245+ (batch_size * world_size , num_float_features ), device = device
246+ )
247+ global_label = torch .rand (batch_size * world_size , device = device )
248+ else :
249+ global_float = torch .zeros (
250+ (batch_size * world_size , num_float_features ), device = device
251+ )
252+ global_label = torch .zeros (batch_size * world_size , device = device )
246253
247254 # Split global batch into local batches.
248255 local_inputs = []
@@ -939,6 +946,7 @@ def __init__(
939946 max_feature_lengths_list : Optional [List [Dict [str , int ]]] = None ,
940947 feature_processor_modules : Optional [Dict [str , torch .nn .Module ]] = None ,
941948 over_arch_clazz : Type [nn .Module ] = TestOverArch ,
949+ preproc_module : Optional [nn .Module ] = None ,
942950 ) -> None :
943951 super ().__init__ (
944952 tables = cast (List [BaseEmbeddingConfig ], tables ),
@@ -960,13 +968,22 @@ def __init__(
960968 embedding_names = (
961969 list (embedding_groups .values ())[0 ] if embedding_groups else None
962970 )
971+ self ._embedding_names : List [str ] = (
972+ embedding_names
973+ if embedding_names
974+ else [feature for table in tables for feature in table .feature_names ]
975+ )
976+ self ._weighted_features : List [str ] = [
977+ feature for table in weighted_tables for feature in table .feature_names
978+ ]
963979 self .over : nn .Module = over_arch_clazz (
964980 tables , weighted_tables , embedding_names , dense_device
965981 )
966982 self .register_buffer (
967983 "dummy_ones" ,
968984 torch .ones (1 , device = dense_device ),
969985 )
986+ self .preproc_module = preproc_module
970987
971988 def sparse_forward (self , input : ModelInput ) -> KeyedTensor :
972989 return self .sparse (
@@ -993,6 +1010,8 @@ def forward(
9931010 self ,
9941011 input : ModelInput ,
9951012 ) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
1013+ if self .preproc_module :
1014+ input = self .preproc_module (input )
9961015 return self .dense_forward (input , self .sparse_forward (input ))
9971016
9981017
@@ -1409,3 +1428,246 @@ def _post_ebc_test_wrap_function(kt: KeyedTensor) -> KeyedTensor:
14091428 continue
14101429
14111430 return kt
1431+
1432+
class TestPreprocNonWeighted(nn.Module):
    """
    Test helper that narrows a KeyedJaggedTensor to a fixed subset of
    non-weighted features.

    Args: None
    Examples:
        >>> TestPreprocNonWeighted()
    Returns:
        List[KeyedJaggedTensor]
    """

    def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
        """
        Selects features 0, 1 and 2 from the given KJT, dropping everything
        else (e.g. feature_3).
        """
        # Pull out the three JaggedTensors we keep, then rebuild one KJT.
        kept = {
            name: kjt[name]
            for name in ("feature_0", "feature_1", "feature_2")
        }
        return [KeyedJaggedTensor.from_jt_dict(kept)]
1463+
1464+
class TestPreprocWeighted(nn.Module):
    """
    Test helper that narrows a weighted KeyedJaggedTensor to a single feature.

    Args: None
    Examples:
        >>> TestPreprocWeighted()
    Returns:
        List[KeyedJaggedTensor]
    """

    def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
        """
        Keeps only ``weighted_feature_0`` from the weighted KJT.
        """
        subset = KeyedJaggedTensor.from_jt_dict(
            {"weighted_feature_0": kjt["weighted_feature_0"]}
        )
        return [subset]
1492+
1493+
class TestModelWithPreproc(nn.Module):
    """
    Test model wiring up to three preproc modules:
      - a preproc over idlist_features feeding the non-weighted EBC
      - a preproc over idscore_features feeding the weighted EBC
      - an optional preproc applied to the whole model input before both

    Args:
        tables: configs for the non-weighted EmbeddingBagCollection.
        weighted_tables: configs for the weighted EmbeddingBagCollection.
        device: device for the dense and sparse modules.
        preproc_module: optional module run on the full ModelInput first.
        num_float_features: width of the dense float features.
        run_preproc_inline: when True (and no preproc_module is given),
            rebuilds the idlist KJT inline via ``from_lengths_sync``.

    Example:
        >>> TestModelWithPreproc(tables, weighted_tables, device)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]
    """

    def __init__(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        device: torch.device,
        preproc_module: Optional[nn.Module] = None,
        num_float_features: int = 10,
        run_preproc_inline: bool = False,
    ) -> None:
        super().__init__()
        self.dense = TestDenseArch(num_float_features, device)
        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
            tables=tables,
            device=device,
        )
        self.weighted_ebc = EmbeddingBagCollection(
            tables=weighted_tables,
            is_weighted=True,
            device=device,
        )
        self.preproc_nonweighted = TestPreprocNonWeighted()
        self.preproc_weighted = TestPreprocWeighted()
        self._preproc_module = preproc_module
        self._run_preproc_inline = run_preproc_inline

    def forward(
        self,
        input: ModelInput,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs preproc for the EBC and weighted EBC; optionally runs a preproc
        over the whole input first.

        Args:
            input: the ModelInput to transform and embed.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(pred.sum(), pred)``
        """
        transformed = input
        if self._preproc_module is not None:
            transformed = self._preproc_module(transformed)
        elif self._run_preproc_inline:
            # Rebuild the idlist KJT in place to simulate an inline preproc.
            transformed.idlist_features = KeyedJaggedTensor.from_lengths_sync(
                transformed.idlist_features.keys(),
                transformed.idlist_features.values(),
                transformed.idlist_features.lengths(),
            )

        idlist_kjts = self.preproc_nonweighted(transformed.idlist_features)
        idscore_kjts = self.preproc_weighted(transformed.idscore_features)

        ebc_out = self.ebc(idlist_kjts[0])
        weighted_ebc_out = self.weighted_ebc(idscore_kjts[0])

        pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
        return pred.sum(), pred
1576+
1577+
class TestNegSamplingModule(torch.nn.Module):
    """
    Simulates a feature-augmentation preproc (e.g. negative sampling) by
    appending a fixed extra batch to every incoming ModelInput.

    Args:
        extra_input: ModelInput whose rows/features are appended to each input.
        has_params: when True, registers a Linear layer so the module carries
            trainable parameters.

    Example:
        >>> preproc = TestNegSamplingModule(extra_input)
        >>> out = preproc(in)

    Returns:
        ModelInput
    """

    def __init__(
        self,
        extra_input: ModelInput,
        has_params: bool = False,
    ) -> None:
        super().__init__()
        self._extra_input = extra_input
        if has_params:
            self._linear: nn.Module = nn.Linear(30, 30)

    def forward(self, input: ModelInput) -> ModelInput:
        """
        Returns a deep copy of ``input`` with the extra features appended.

        Args:
            input: batch to augment.
        Returns:
            ModelInput: augmented copy; the batch dimension grows by the
            extra input's batch size.
        """
        # Work on a copy so the caller's input is untouched.
        augmented = copy.deepcopy(input)

        # dim=0 (batch dimension) grows by self._extra_input.float_features.shape[0]
        augmented.float_features = torch.concat(
            (augmented.float_features, self._extra_input.float_features), dim=0
        )

        # Stride stays the same; the sparse features are joined.
        augmented.idlist_features = KeyedJaggedTensor.concat(
            [augmented.idlist_features, self._extra_input.idlist_features]
        )
        if self._extra_input.idscore_features is not None:
            # Same for the weighted features, when present.
            augmented.idscore_features = KeyedJaggedTensor.concat(
                # pyre-ignore
                [augmented.idscore_features, self._extra_input.idscore_features]
            )

        # dim=0 (batch dimension) grows by self._extra_input.label.shape[0]
        augmented.label = torch.concat(
            (augmented.label, self._extra_input.label), dim=0
        )

        return augmented
1639+
1640+
class TestPositionWeightedPreprocModule(torch.nn.Module):
    """
    Test preproc that applies a PositionWeightedProcessor to the
    idlist_features of a ModelInput.

    Args: None
    Example:
        >>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
        >>> out = preproc(in)
    Returns:
        ModelInput
    """

    def __init__(
        self, max_feature_lengths: Dict[str, int], device: torch.device
    ) -> None:
        super().__init__()
        self.fp_proc = PositionWeightedProcessor(
            max_feature_lengths=max_feature_lengths,
            device=device,
        )

    def forward(self, input: ModelInput) -> ModelInput:
        """
        Deep-copies the input and runs the position-weighted processor over
        its idlist_features.

        Args:
            input: batch to process.
        Returns:
            ModelInput: copy with processed idlist_features.
        """
        processed = copy.deepcopy(input)
        processed.idlist_features = self.fp_proc(processed.idlist_features)
        return processed
0 commit comments