Skip to content

Commit

Permalink
feat: add one hot code data
Browse files Browse the repository at this point in the history
  • Loading branch information
james397520 committed Oct 24, 2023
2 parents b805c3e + 4ada68d commit 84c04e0
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 0 additions & 2 deletions dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def one_hot_encode(ids, num_classes):
class HousePriceTrainDataset(Dataset):
def __init__(self, dataframe, target_column, normalize_columns=None):
# Load area data
area_file_path = 'data/area.csv'
self.area_df = pd.read_csv(area_file_path)
self.dataframe = dataframe.copy() # Creating a copy to avoid modifying the original dataframe

feature_list=[]
Expand Down
8 changes: 5 additions & 3 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@ def inference():
with torch.no_grad():
for batch in data_loader:
if gpu:
features = batch['features'].cuda()
data = batch['features']
data[0] = data[0].cuda()
data[1] = data[1].cuda()
else:
features = batch['features']
data = batch['features']

predictions = model(features).cpu().numpy().flatten()
predictions = model(data).cpu().numpy().flatten()

# Create a DataFrame to hold the IDs and predicted prices
ids = [f"PU-{i}" for i in range(1, len(predictions) + 1)] # Adjust ID format as needed
Expand Down
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def main():
# print(batch['features'].shape)
# print(batch['target'].shape)
if gpu:
data = batch['features'].cuda()
data = batch['features']
data[0] = data[0].cuda()
data[1] = data[1].cuda()
targets = batch['target'].cuda()
else:
data = batch['features']
Expand Down

0 comments on commit 84c04e0

Please sign in to comment.