Skip to content

Commit ccbde64

Browse files
committed
inference
1 parent bc068e7 commit ccbde64

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

templates/titanic/tutorial.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,21 @@
8989
# %%
9090
df_test = pd.read_csv(csv_test)
9191

92-
predictions = model.predict(csv_test)
93-
print(predictions[0])
92+
dm = TabularClassificationData.from_data_frame(
93+
predict_data_frame=df_test,
94+
parameters=datamodule.parameters,
95+
batch_size=datamodule.batch_size,
96+
)
97+
preds = trainer.predict(model, datamodule=dm, output="classes")
98+
print(preds[0][:10])
9499

95100
# %%
101+
import itertools # noqa: E402]
102+
96103
import numpy as np # noqa: E402]
97104

98-
assert len(df_test) == len(predictions)
105+
predictions = list(itertools.chain(*preds))
106+
# assert len(df_test) == len(predictions)
99107

100108
df_test["Survived"] = np.argmax(predictions, axis=-1)
101109
df_test.set_index("PassengerId", inplace=True)

0 commit comments

Comments
 (0)