1818from typing import Callable
1919from typing import Dict
2020from typing import List
21+ from typing import Optional
22+ from typing import Tuple
2123
2224import numpy as np
2325
@@ -26,31 +28,37 @@ class FunctionalBatchTransform:
2628 """Functional data transform class, which allows to use custom data transform function from given transform_func for special cases.
2729
2830 Args:
29- transform_func (Callable): Function of data transform.
31+ transform_func (Callable): Function of batch data transform.
3032
3133 Examples:
32- >>> # This is the transform_func function. It takes three dictionaries as input: data_dict, label_dict, and weight_dict.
33- >>> # The function will perform some transformations on the data in data_dict, convert all labels in label_dict to uppercase,
34- >>> # and modify the weights in weight_dict by dividing each weight by 10.
35- >>> # Finally, it returns the transformed data, labels, and weights as a tuple.
3634 >>> import ppsci
37- >>> def transform_func(data_dict, label_dict, weight_dict):
38- ... for key in data_dict:
39- ... data_dict[key] = data_dict[key] * 2
40- ... for key in label_dict:
41- ... label_dict[key] = label_dict[key] + 1.0
42- ... for key in weight_dict:
43- ... weight_dict[key] = weight_dict[key] / 10
44- ... return data_dict, label_dict, weight_dict
45- >>> transform = ppsci.data.transform.FunctionalTransform(transform_func)
35+ >>> from typing import Tuple, Dict, Optional
36+ >>> def batch_transform_func(
37+ ... data_list: List[
38+ ... Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]
39+ ... ],
40+ ... ) -> List[Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]]:
41+ ... input_dicts, label_dicts, weight_dicts = zip(*data_list)
42+ ...
43+ ... for input_dict in input_dicts:
44+ ... for key in input_dict:
45+ ... input_dict[key] = input_dict[key] * 2
46+ ...
47+ ... for label_dict in label_dicts:
48+ ... for key in label_dict:
49+ ... label_dict[key] = label_dict[key] + 1.0
50+ ...
51+ ... return list(zip(input_dicts, label_dicts, weight_dicts))
52+ ...
53+ >>> # Create a FunctionalBatchTransform object with the batch_transform_func function
54+ >>> transform = ppsci.data.batch_transform.FunctionalBatchTransform(batch_transform_func)
4655 >>> # Define some sample data, labels, and weights
47- >>> data = {'feature1': np.array([1, 2, 3]), 'feature2': np.array([4, 5, 6])}
48- >>> label = {'class': 0.0, 'instance': 0.1}
49- >>> weight = {'weight1': 0.5, 'weight2': 0.5}
50- >>> # Apply the transform function to the data, labels, and weights using the FunctionalTransform instance
51- >>> transformed_data = transform(data, label, weight)
52- >>> print(transformed_data)
53- ({'feature1': array([2, 4, 6]), 'feature2': array([ 8, 10, 12])}, {'class': 1.0, 'instance': 1.1}, {'weight1': 0.05, 'weight2': 0.05})
56+ >>> data = [({'x': 1}, {'y': 2}, None), ({'x': 11}, {'y': 22}, None)]
57+ >>> transformed_data = transform(data)
58+ >>> for tuple in transformed_data:
59+ ... print(tuple)
60+ ({'x': 2}, {'y': 3.0}, None)
61+ ({'x': 22}, {'y': 23.0}, None)
5462 """
5563
5664 def __init__ (
@@ -61,7 +69,6 @@ def __init__(
6169
6270 def __call__ (
6371 self ,
64- list_data : List [List [Dict [str , np .ndarray ]]],
65- # [{'u': arr, 'y': arr}, {'u': arr, 'y': arr}, {'u': arr, 'y': arr}], [{'s': arr}, {'s': arr}, {'s': arr}], [{}, {}, {}]
66- ) -> List [Dict [str , np .ndarray ]]:
67- return self .transform_func (list_data )
72+ data_list : List [Tuple [Optional [Dict [str , np .ndarray ]], ...]],
73+ ) -> List [Tuple [Optional [Dict [str , np .ndarray ]], ...]]:
74+ return self .transform_func (data_list )
0 commit comments