Skip to content

Commit a181bcb

Browse files
authored
Fix xgb string data contains slash (#2869)
* fix xgb string data contains slash * fix ci
1 parent 6cd7b24 commit a181bcb

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

python/runtime/local/xgboost_submitter/predict.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from runtime.feature.derivation import get_ordered_field_descs
2424
from runtime.feature.field_desc import DataType
2525
from runtime.model.model import Model
26-
from runtime.xgboost.dataset import xgb_dataset
26+
from runtime.xgboost.dataset import DMATRIX_FILE_SEP, xgb_dataset
2727

2828

2929
def pred(datasource, select, result_table, pred_label_name, model):
@@ -148,8 +148,8 @@ def _store_predict_result(preds, result_table, result_column_names,
148148
break
149149

150150
row = [
151-
item for i, item in enumerate(line.strip().split("/"))
152-
if i != train_label_idx
151+
item for i, item in enumerate(line.strip().split(
152+
DMATRIX_FILE_SEP)) if i != train_label_idx
153153
]
154154
row.append(str(preds[line_no]))
155155
w.write(row)

python/runtime/xgboost/dataset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from scipy.sparse import vstack
2626
from sklearn.datasets import load_svmlight_file, load_svmlight_files
2727

28+
DMATRIX_FILE_SEP = "\t"
29+
2830

2931
def xgb_dataset(datasource,
3032
fn,
@@ -130,7 +132,8 @@ def dump_dmatrix(filename,
130132
feature_metas)
131133

132134
if raw_data_fid is not None:
133-
raw_data_fid.write("/".join([str(r) for r in row]) + "\n")
135+
raw_data_fid.write(
136+
DMATRIX_FILE_SEP.join([str(r) for r in row]) + "\n")
134137

135138
if transform_fn:
136139
features = transform_fn(features)
@@ -163,7 +166,7 @@ def dump_dmatrix(filename,
163166
if has_label:
164167
row_data = [str(label)] + row_data
165168

166-
f.write("\t".join(row_data) + "\n")
169+
f.write(DMATRIX_FILE_SEP.join(row_data) + "\n")
167170
row_id += 1
168171
# batch_size == None means use all data in generator
169172
if batch_size is None:

python/runtime/xgboost/evaluate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from runtime import db
1818
from runtime.dbapi.paiio import PaiIOConnection
1919
from runtime.model.metadata import load_metadata
20-
from runtime.xgboost.dataset import xgb_dataset
20+
from runtime.xgboost.dataset import DMATRIX_FILE_SEP, xgb_dataset
2121

2222
SKLEARN_METRICS = [
2323
'accuracy_score',
@@ -118,7 +118,7 @@ def evaluate_and_store_result(bst, dpred, feature_file_id, validation_metrics,
118118

119119
y_test_list = []
120120
for line in feature_file_read:
121-
row = [i for i in line.strip().split("\t")]
121+
row = [i for i in line.strip().split(DMATRIX_FILE_SEP)]
122122
# DMatrix store label in the first column
123123
if label_meta["dtype"] == "float32":
124124
label = float(row[0])

python/runtime/xgboost/predict.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from runtime import db
1717
from runtime.dbapi.paiio import PaiIOConnection
1818
from runtime.model.metadata import load_metadata
19-
from runtime.xgboost.dataset import xgb_dataset
19+
from runtime.xgboost.dataset import DMATRIX_FILE_SEP, xgb_dataset
2020

2121
DEFAULT_PREDICT_BATCH_SIZE = 10000
2222

@@ -123,7 +123,8 @@ def predict_and_store_result(bst, dpred, feature_file_id, model_params,
123123
# FIXME(typhoonzero): how to output columns that are not used
124124
# as features, like ids?
125125
row = [
126-
item for i, item in enumerate(line.strip().split("/"))
126+
item
127+
for i, item in enumerate(line.strip().split(DMATRIX_FILE_SEP))
127128
if i != train_label_index
128129
]
129130
row.append(str(preds[line_no]))

0 commit comments

Comments
 (0)