 71 |  71 |   "from typing import Tuple\n",
 72 |  72 |   "\n",
 73 |  73 |   "import ray\n",
 74 |     | - "from ray.train.lightgbm import LightGBMPredictor\n",
 75 |     | - "from ray.data.preprocessors.chain import Chain\n",
 76 |     | - "from ray.data.preprocessors.encoder import Categorizer\n",
    |  74 | + "from ray.data import Dataset, Preprocessor\n",
    |  75 | + "from ray.data.preprocessors import Categorizer, StandardScaler\n",
 77 |  76 |   "from ray.train.lightgbm import LightGBMTrainer\n",
 78 |     | - "from ray.train import Result, ScalingConfig\n",
 79 |     | - "from ray.data import Dataset\n",
 80 |     | - "from ray.data.preprocessors import StandardScaler"
    |  77 | + "from ray.train import Result, ScalingConfig"
 81 |  78 |   ]
 82 |  79 |   },
 83 |  80 |   {

124 | 121 |   "\n",
125 | 122 |   "    # Scale some random columns, and categorify the categorical_column,\n",
126 | 123 |   "    # allowing LightGBM to use its built-in categorical feature support\n",
127 |     | - "    preprocessor = Chain(\n",
128 |     | - "        Categorizer([\"categorical_column\"]),\n",
129 |     | - "        StandardScaler(columns=[\"mean radius\", \"mean texture\"])\n",
130 |     | - "    )\n",
    | 124 | + "    scaler = StandardScaler(columns=[\"mean radius\", \"mean texture\"])\n",
    | 125 | + "    categorizer = Categorizer([\"categorical_column\"])\n",
    | 126 | + "\n",
    | 127 | + "    train_dataset = categorizer.fit_transform(scaler.fit_transform(train_dataset))\n",
    | 128 | + "    valid_dataset = categorizer.transform(scaler.transform(valid_dataset))\n",
131 | 129 |   "\n",
132 | 130 |   "    # LightGBM specific params\n",
133 | 131 |   "    params = {\n",

140 | 138 |   "        label_column=\"target\",\n",
141 | 139 |   "        params=params,\n",
142 | 140 |   "        datasets={\"train\": train_dataset, \"valid\": valid_dataset},\n",
143 |     | - "        preprocessor=preprocessor,\n",
144 | 141 |   "        num_boost_round=100,\n",
    | 142 | + "        metadata={\"scaler_pkl\": scaler.serialize(), \"categorizer_pkl\": categorizer.serialize()},\n",
145 | 143 |   "    )\n",
146 | 144 |   "    result = trainer.fit()\n",
147 | 145 |   "    print(result.metrics)\n",

173 | 171 |   "class Predict:\n",
174 | 172 |   "\n",
175 | 173 |   "    def __init__(self, checkpoint: Checkpoint):\n",
176 |     | - "        self.predictor = LightGBMPredictor.from_checkpoint(checkpoint)\n",
    | 174 | + "        self.model = LightGBMTrainer.get_model(checkpoint)\n",
    | 175 | + "        self.scaler = Preprocessor.deserialize(checkpoint.get_metadata()[\"scaler_pkl\"])\n",
    | 176 | + "        self.categorizer = Preprocessor.deserialize(checkpoint.get_metadata()[\"categorizer_pkl\"])\n",
177 | 177 |   "\n",
178 | 178 |   "    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:\n",
179 |     | - "        return self.predictor.predict(batch)\n",
    | 179 | + "        preprocessed_batch = self.categorizer.transform_batch(self.scaler.transform_batch(batch))\n",
    | 180 | + "        return {\"predictions\": self.model.predict(preprocessed_batch)}\n",
180 | 181 |   "\n",
181 | 182 |   "\n",
182 | 183 |   "def predict_lightgbm(result: Result):\n",
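
The change above drops LightGBMPredictor and Chain in favor of fitting the preprocessors directly, storing their serialized state on the checkpoint through the trainer's metadata argument, and rebuilding them at inference time with Preprocessor.deserialize. Below is a minimal sketch of that serialize/deserialize round trip on its own, without running the trainer; the toy DataFrame and its values are assumptions for illustration, while the column names and the Preprocessor, Categorizer, and StandardScaler calls mirror the ones in the diff.

import pandas as pd

import ray
from ray.data import Preprocessor
from ray.data.preprocessors import Categorizer, StandardScaler

# Toy stand-in for the breast cancer data used in the notebook (assumption).
df = pd.DataFrame(
    {
        "mean radius": [14.1, 20.6, 12.4, 18.2],
        "mean texture": [19.3, 25.7, 15.1, 22.8],
        "categorical_column": ["A", "B", "A", "B"],
        "target": [0, 1, 0, 1],
    }
)
train_dataset = ray.data.from_pandas(df)

# Fit the preprocessors on the training dataset, as in the training function.
scaler = StandardScaler(columns=["mean radius", "mean texture"])
categorizer = Categorizer(["categorical_column"])
train_dataset = categorizer.fit_transform(scaler.fit_transform(train_dataset))

# The fitted state travels to inference time as strings, e.g. via the trainer's
# metadata argument, which is stored on the resulting checkpoint.
metadata = {"scaler_pkl": scaler.serialize(), "categorizer_pkl": categorizer.serialize()}

# At inference time, rebuild the preprocessors and apply them to a pandas batch.
restored_scaler = Preprocessor.deserialize(metadata["scaler_pkl"])
restored_categorizer = Preprocessor.deserialize(metadata["categorizer_pkl"])
batch = restored_categorizer.transform_batch(restored_scaler.transform_batch(df))
print(batch.dtypes)  # "categorical_column" is now a pandas categorical dtype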