Commit

add some update
wyddmw committed May 5, 2021
1 parent ac15f81 commit 2731d08
Showing 4 changed files with 25 additions and 18 deletions.
Empty file added dataset_geoerror.py
39 changes: 23 additions & 16 deletions model/dataset_geoerror.py
@@ -71,20 +71,26 @@ def generate_data(self):
np.savetxt(self.numpy_file, all_data, fmt='%f', delimiter=',')

# split training data
# for i in range(4): # split the data with a 4:1 ratio
# if i > 0:
# self.train_data = np.concatenate((self.train_data, all_data[i::5]), axis=0)
# else:
# self.train_data = all_data[::5]
for i in range(4): # split the data with a 4:1 ratio
if i > 0:
self.train_data = np.concatenate((self.train_data, all_data[i::5]), axis=0)
else:
self.train_data = all_data[::5]
self.test_data = all_data[4::5]

num_train, _ = self.train_data.shape
num_test, _ = self.test_data.shape
# split the dataset with a 4:1 ratio
train_index = [i for i in range(4)]
N, _ = all_data.shape
#train_index = np.array([i for i in range(4)])
#print(train_index)
#print(type(train_index))

if self.random_select:
print("random select")

self.train_data = all_data[train_index::5].reshape(N, -1, 4)
self.test_data = all_data[4::5].reshape(N, -1, 4)
#self.train_data = all_data[train_index::5].reshape(N, -1, 4)
self.train_data = self.train_data.reshape(num_train, -1, 4)
self.test_data = all_data[4::5].reshape(num_test, -1, 4)

print("train_data shape is ", self.train_data.shape)
print("test data shape is ", self.test_data.shape)
@@ -94,15 +100,16 @@ def generate_lable(self):
# only use
all_label = np.loadtxt(self.label_path, delimiter=',', dtype=np.float32)
all_label = all_label[:, 0].astype(np.int) - 1
print(all_label.shape)
# all_label = all_label[:, 0] - 1

# for i in range(4):
# if i > 0:
# self.train_label = np.concatenate((self.train_label, all_label[i::5]), axis=0)
# else:
# self.train_label = all_label[::5]
train_index = [i for i in range(4)]
self.train_label = all_label[train_index::4]
for i in range(4):
if i > 0:
self.train_label = np.concatenate((self.train_label, all_label[i::5]), axis=0)
else:
self.train_label = all_label[::5]
# train_index = [i for i in range(4)]
# self.train_label = all_label[train_index::4]
self.test_label = all_label[4::5]

return self.train_label, self.test_label
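The same caveat applies to the labels: the old all_label[train_index::4] used a list as the slice start and a stride of 4, so it could not line up with the 5-row stride used for the data. A small sanity check, assuming the four arrays returned by generate_data and generate_lable (the function and argument names here are illustrative):

def check_split(train_data, train_label, test_data, test_label):
    # every sample needs exactly one label, on both splits
    assert train_data.shape[0] == train_label.shape[0]
    assert test_data.shape[0] == test_label.shape[0]
    # with a 4:1 split, about 80% of the samples should land in the training set
    total = train_data.shape[0] + test_data.shape[0]
    print("train fraction: %.2f" % (train_data.shape[0] / total))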
2 changes: 1 addition & 1 deletion train.sh
@@ -1 +1 @@
python train_classification_geoerror.py --data_path '/home/spyder/hazel/pointnet_geoerror/dataload-V2/origin/' --nepoch=5 --batchSize 4 --label_path '/home/spyder/hazel/pointnet_geoerror/dataload-V2/label13.csv' --numpy_path '/home/spyder/hazel/pointnet_geoerror/dataload-V2/all_data.txt'
python train_classification_geoerror.py --data_path '/home/spyder/hazel/dataload-V2/origin/' --nepoch=5 --batchSize 4 --label_path '/home/spyder/hazel/dataload-V2/label13.csv' --numpy_path '/home/spyder/hazel/dataload-V2/all_data.txt'
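The train.sh change only shortens the dataset paths; the flags themselves are unchanged. For reference, they suggest an argparse setup roughly like the sketch below (the flag names are taken from the command line, while the defaults and help strings are assumptions):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, help='directory holding the raw geo-error data')
parser.add_argument('--label_path', type=str, help='CSV file with one label row per sample')
parser.add_argument('--numpy_path', type=str, help='cached all_data.txt written by generate_data')
parser.add_argument('--nepoch', type=int, default=5, help='number of training epochs')
parser.add_argument('--batchSize', type=int, default=4, help='samples per batch')
opt = parser.parse_args()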
2 changes: 1 addition & 1 deletion train_classification_geoerror.py
@@ -84,7 +84,7 @@
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
classifier.cuda()

num_batch = len(dataset) / opt.batchSize
num_batch = len(train_dataset) / opt.batchSize

for epoch in range(opt.nepoch):
scheduler.step()
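The fix switches num_batch from dataset to train_dataset; note that plain / still yields a float in Python 3. A sketch of an integer batch count that includes a partial final batch (the variable names follow the script, and the dataloader note assumes a PyTorch DataLoader built with drop_last=False):

import math

num_batch = math.ceil(len(train_dataset) / opt.batchSize)
# equivalently, len(train_dataloader) when the DataLoader uses drop_last=False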
