Skip to content

Commit 9d1f13e

Browse files
update example code of FunctionalBatchTransform
1 parent cfdb3f3 commit 9d1f13e

File tree

1 file changed

+32
-25
lines changed

1 file changed

+32
-25
lines changed

ppsci/data/process/batch_transform/preprocess.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from typing import Callable
1919
from typing import Dict
2020
from typing import List
21+
from typing import Optional
22+
from typing import Tuple
2123

2224
import numpy as np
2325

@@ -26,31 +28,37 @@ class FunctionalBatchTransform:
2628
"""Functional data transform class, which allows to use custom data transform function from given transform_func for special cases.
2729
2830
Args:
29-
transform_func (Callable): Function of data transform.
31+
transform_func (Callable): Function of batch data transform.
3032
3133
Examples:
32-
>>> # This is the transform_func function. It takes three dictionaries as input: data_dict, label_dict, and weight_dict.
33-
>>> # The function will perform some transformations on the data in data_dict, convert all labels in label_dict to uppercase,
34-
>>> # and modify the weights in weight_dict by dividing each weight by 10.
35-
>>> # Finally, it returns the transformed data, labels, and weights as a tuple.
3634
>>> import ppsci
37-
>>> def transform_func(data_dict, label_dict, weight_dict):
38-
... for key in data_dict:
39-
... data_dict[key] = data_dict[key] * 2
40-
... for key in label_dict:
41-
... label_dict[key] = label_dict[key] + 1.0
42-
... for key in weight_dict:
43-
... weight_dict[key] = weight_dict[key] / 10
44-
... return data_dict, label_dict, weight_dict
45-
>>> transform = ppsci.data.transform.FunctionalTransform(transform_func)
35+
>>> from typing import Tuple, Dict, Optional
36+
>>> def batch_transform_func(
37+
... data_list: List[
38+
... Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]
39+
... ],
40+
... ) -> List[Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]]:
41+
... input_dicts, label_dicts, weight_dicts = zip(*data_list)
42+
...
43+
... for input_dict in input_dicts:
44+
... for key in input_dict:
45+
... input_dict[key] = input_dict[key] * 2
46+
...
47+
... for label_dict in label_dicts:
48+
... for key in label_dict:
49+
... label_dict[key] = label_dict[key] + 1.0
50+
...
51+
... return list(zip(input_dicts, label_dicts, weight_dicts))
52+
...
53+
>>> # Create a FunctionalBatchTransform object with the batch_transform_func function
54+
>>> transform = ppsci.data.batch_transform.FunctionalBatchTransform(batch_transform_func)
4655
>>> # Define some sample data, labels, and weights
47-
>>> data = {'feature1': np.array([1, 2, 3]), 'feature2': np.array([4, 5, 6])}
48-
>>> label = {'class': 0.0, 'instance': 0.1}
49-
>>> weight = {'weight1': 0.5, 'weight2': 0.5}
50-
>>> # Apply the transform function to the data, labels, and weights using the FunctionalTransform instance
51-
>>> transformed_data = transform(data, label, weight)
52-
>>> print(transformed_data)
53-
({'feature1': array([2, 4, 6]), 'feature2': array([ 8, 10, 12])}, {'class': 1.0, 'instance': 1.1}, {'weight1': 0.05, 'weight2': 0.05})
56+
>>> data = [({'x': 1}, {'y': 2}, None), ({'x': 11}, {'y': 22}, None)]
57+
>>> transformed_data = transform(data)
58+
>>> for tuple in transformed_data:
59+
... print(tuple)
60+
({'x': 2}, {'y': 3.0}, None)
61+
({'x': 22}, {'y': 23.0}, None)
5462
"""
5563

5664
def __init__(
@@ -61,7 +69,6 @@ def __init__(
6169

6270
def __call__(
6371
self,
64-
list_data: List[List[Dict[str, np.ndarray]]],
65-
# [{'u': arr, 'y': arr}, {'u': arr, 'y': arr}, {'u': arr, 'y': arr}], [{'s': arr}, {'s': arr}, {'s': arr}], [{}, {}, {}]
66-
) -> List[Dict[str, np.ndarray]]:
67-
return self.transform_func(list_data)
72+
data_list: List[Tuple[Optional[Dict[str, np.ndarray]], ...]],
73+
) -> List[Tuple[Optional[Dict[str, np.ndarray]], ...]]:
74+
return self.transform_func(data_list)

0 commit comments

Comments
 (0)