Commit

got accepted files

isamu-isozaki committed Apr 18, 2021
1 parent 15c39f3 commit 6f6c395
Showing 4 changed files with 97 additions and 17 deletions.
Binary file added dcgan_horrible_20_epochs.png
9 changes: 6 additions & 3 deletions gan.py
@@ -7,6 +7,8 @@
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
except:
    pass
+from sklearn.metrics import classification_report, confusion_matrix
+
print('desired shape')
print((z_max-z_min, y_max-y_min, x_max-x_min, 1 + len(atom_type) + len(atom_pos)))
# Create the discriminator
@@ -73,12 +75,12 @@ def train_step(self, real_images):
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
-        labels += 0.05 * tf.random.uniform(tf.shape(labels))
+        # labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
-            predictions = self.discriminator(combined_images)
-            d_loss = self.loss_fn(labels, predictions)
+            predictions_d = self.discriminator(combined_images)
+            d_loss = self.loss_fn(labels, predictions_d)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
@@ -97,5 +99,6 @@ def train_step(self, real_images):
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
+
        return {"d_loss": d_loss, "g_loss": g_loss}
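
Note on the change above: the commit disables the label-noise trick from the Keras DCGAN example, so the discriminator now trains against hard 0/1 targets. For reference, a minimal self-contained sketch of what the commented-out line did (batch_size here is illustrative, not from this repo):

import tensorflow as tf

batch_size = 4
# Hard targets as built in train_step: 1 for fake images, 0 for real ones.
labels = tf.concat(
    [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
)
# The disabled trick: perturb the targets with small uniform noise so the
# discriminator sees soft labels in [0, 1.05) instead of exact 0/1.
noisy_labels = labels + 0.05 * tf.random.uniform(tf.shape(labels))
print(noisy_labels.numpy().ravel())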

45 changes: 33 additions & 12 deletions loadptn.py
@@ -25,7 +25,7 @@
atom_pos_data = pd.Series(atom_pos)
atom_pos_encoder = np.array(pd.get_dummies(atom_pos_data))
dataset_file = 'ptn11H_10'

acceptable_protein_names = ['1aip18-28', '1ald78-88', '1bgy51-61', '1bkv10-21', '1c12109-119', '1cai146-156', '1d2y95-105', '1de488-98', '1deu198-208', '1djb255-265', '1dnu478-488', '1ds6142-152', '1dzt119-129', '1ej6870-880', '1eyv59-69', '1f9348-59', '1fbx62-72', '1fnu802-812', '1ga8125-135', '1gh0121-131', '1go4163-173', '1gwg10-20', '1h69116-126', '1hjv284-294', '1hk923-33', '1hl5116-126', '1hqj5-15', '1ilg199-209', '1irj50-60', '1irm99-109', '1iwo681-691', '1jqi534-544', '1k8i50-60', '1kei182-192', '1kfc75-85', '1lgb113-123', '1lk064-74', '1los132-142', '1lrh147-157', '1m5q42-52', '1mab51-61', '1mbu96-106', '1mlz125-135', '1mvw120-130', '1non126-136', '1o7t156-166', '1o7t73-83', '1og270-80', '1oi9170-180', '1oj7310-320', '1onx260-270', '1oow36-46', '1p1b86-96', '1q1g219-229', '1qe1253-263', '1qq1149-159', '1que77-87', '1qun19-29', '1rmt168-178', '1rqa16-26', '1rvx540-550', '1rzi126-136', '1syq203-213', '1tbu44-54', '1uxa296-306', '1v6o115-125', '1vge84-94', '1vio13-23', '1vqr87-102', '1vqu283-293', '1w5k24-34', '1wuh62-72', '1wxo22-32', '1xko123-133', '1y0w26-36', '1yde205-215', '1yi58-18', '1yk354-64', '1ynb33-43', '1z1j67-77', '1z6r45-56', '1zca187-197', '1zin196-206', '1zl21212-1222', '2a6969-79', '2a7w57-67', '2aaf351-361', '2aow168-178', '2ara115-125', '2bx4248-258', '2c10148-158', '2dxr408-418', '2e9b601-611', '2f1d125-135', '2fa472-82', '2fpt268-278', '2fsi280-290', '2g9t76-89', '2gpp393-403', '2gw1164-174', '2hnd250-260', '2hr3100-110', '2hyd110-120', '2i2q42-52', '2i2x327-337', '2if081-91', '2igk443-453', '2iou194-204', '2j1r43-53', '2j8840-50', '2j9615-25', '2jdh25-35', '2jj1257-267', '2nsp55-65', '2pff5-15', '2pw824-34', '2pzk268-278', '2q7l791-801', '2ql6106-116', '2qm1279-289', '2qr078-88', '2rhs205-215', '2rk750-60', '2uul62-72', '2uul75-85', '2v2l79-89', '2vef106-116', '2w7f258-268', '2wdr99-109', '2win588-598', '2wmy74-84', '2x5q98-108', '2xfe19-29', '2xin271-281', '2xla166-176', '2y5c68-78', '2yk320-30', '2yn4180-190', '2ynj55-65', '2yr5219-229', '2zaf281-291', '2zml73-83', '2zny59-69', '3aeq55-65', '3akf106-116', '3ao467-77', '3art542-552', '3azd11-21', '3bfg47-57', '3c2b158-168', '3c9121-31', '3cb9205-215', '3csm182-192', '3csu123-133', '3d2947-57', '3dqr404-414', '3ecy89-99', '3ee5303-313', '3euw285-295', '3evc141-151', '3exg344-354', '3f6s77-87', '3ftc50-60', '3fup865-875', '3fwu42-52', '3gw221-31', '3h1j20-30', '3hlr118-128', '3hws154-164', '3ias35-45', '3ic343-54', '3is817-27', '3ix039-49', '3jvy146-156', '3jwd169-179', '3k4l353-363', '3knb18-28', '3kpv184-194', '3kve321-331', '3kxf214-224', '3lin17-27', '3lke206-216', '3lkz79-89', '3lpl438-448', '3m6s54-64', '3m791-11', '3mbq90-100', '3muq56-66', '3n68368-378', '3ndo128-138', '3ngt111-121', '3oee139-149', '3om3234-244', '3owe27-37', '3pci425-435', '3q9n4-14', '3qu1121-131', '3qvz78-88', '3qxe76-86', '3r25292-302', '3r8b215-225', '3ril410-420', '3rq1328-338', '3ruc65-75', '3s1b31-41', '3s3823-33', '3so4286-296', '3t0y71-81', '3t2m103-113', '3t4a65-75', '3tky227-237', '3tr9205-215', '3tt2183-193', '3u2z324-334', '3u8j27-37', '3v5r106-116', '3vjh465-475', '3vng111-121', '3vop31-41', '3vrg9-19', '3w7v92-102', '3w8h386-396', '3wch120-130', '3wsa244-254', '3wsh80-90', '3wxf689-699', '3wyh96-106', '3x3c122-132', '3zlp106-116', '3zmf371-381', '3zou72-82', '3zpi336-346', '3zpj306-316', '3zvj111-121', '4b8c793-803', '4bls17-27', '4bt6101-111', '4c0c177-187', '4c0s228-238', '4c7r353-363', '4c90559-569', '4ckh92-102', '4cl7213-223', '4cs723-33', '4ctx382-392', '4d0m120-130', '4d1j215-225', '4dao105-115', '4das145-155', 
'4dr9148-158', '4dwz194-204', '4e52219-229', '4e5t109-119', '4eqc442-452', '4es5337-347', '4eux481-491', '4f0o2-12', '4f2p167-177', '4fd4100-110', '4fnp413-423', '4fq97-17', '4fxf504-514', '4gsl146-156', '4hac207-217', '4hb615-25', '4hg5192-202', '4hyr358-368', '4i1i130-140', '4igb346-356', '4inr63-73', '4itt21-31', '4jcr132-142', '4k1z83-93', '4kci558-568', '4ki776-86', '4knp178-188', '4knt88-98', '4ld7367-377', '4loc605-615', '4lrv49-62', '4lus85-95', '4lzw82-92', '4m0545-55', '4m0z397-407', '4m11408-418', '4mgk99-109', '4mpb378-388', '4mvm1148-1158', '4mz968-78', '4n4o177-187', '4n9i152-162', '4nrk143-153', '4nzv36-46', '4o6r293-303', '4ocl49-59', '4ojx105-115', '4oop49-59', '4ou8286-296', '4p9y84-94', '4po563-73', '4py3144-154', '4q7h191-201', '4qd8130-140', '4qux55-65', '4qz156-66', '4r41111-121', '4ree135-145', '4rhe95-105', '4rus63-73', '4to5379-389', '4uh4597-607', '4um3123-133', '4uud559-569', '4uwl837-847', '4v2n36-46', '4w7864-74', '4w8n187-197', '4w9h18-28', '4wjg171-181', '4wzb110-120', '4xcn67-77', '4xtw69-79', '4y8r28-38', '4ym1234-244', '4yqh692-702', '4z9o70-80', '4zci433-443', '4zgj180-190', '4zh4214-224', '4zlb137-147', '5a0q223-233', '5ad7459-469', '5afu303-313', '5aig108-118', '5apx36-46', '5apz98-108', '5aqv130-140', '5bps39-49', '5cax51-61', '5cpz41-51', '5d5o170-180', '5dy934-44', '5e7c26-36', '5eiz40-50', '5eno132-142', '5exf18-28', '5fj936-46', '5fli15-25', '5fm5587-597', '5fuc46-56', '5fue44-54', '5fw4171-181', '5h47101-111', '5hbb69-79', '5irg117-127', '5iw7770-780', '5j5g106-116', '5jh1176-186', '5jpi408-418', '5jsi62-72', '5kgt152-162', '5kyd17-27', '5l0f1029-1039', '5l0f1097-1107', '5l5q24-34', '5ldf321-331', '5lf446-56', '5mt310-20', '5nev178-188', '5nok636-646', '5nzr192-202', '5p8z91-101', '5poj131-141', '5pyc795-805', '5q37930-940', '5qm0170-180', '5szn395-405', '5tcg266-276', '5tja163-173', '5tou6-16', '5tv3148-158', '5u1552-62', '5udw69-79', '5ufl145-158', '5unr705-715', '5us6246-256', '5v9u92-102', '5vch133-143', '5vys96-106', '5w70312-322', '5w9a25-35', '5wbu1711-1721', '5wk1113-123', '5ws527-37', '5x8l89-99', '5xkg262-272', '5xnw165-175', '5xte165-175', '5xvi232-242', '5xxw304-314', '5y7292-102', '5yb441-51', '5yc8135-145', '5z0g84-94', '5z2l135-145', '5z2m67-77', '5zb4111-121', '5zb8172-182', '5zcp25-35', '5zwe285-295', '5zws5-15', '6a0n99-109', '6a2u75-85', '6aqh345-355', '6bed269-279', '6bgl25-35', '6brj840-850', '6bxb1-11', '6c2y532-542', '6cp095-105', '6dhi161-172', '6dnj194-204', '6dpv348-358', '6dpx50-60', '6e5b13-23', '6ea5153-163', '6eh1150-160', '6epc57-67', '6fdu255-265', '6flc365-375', '6fow415-425', '6g6b17-27', '6g6f9-19', '6gey93-103', '6gnq18-28', '6hb0134-144', '6hck14-24', '6huc61-71', '6hvw10-20', '6hw383-93', '6hwa78-88', '6hy08-18', '6i0x430-440', '6iwp67-77', '6jeb455-465', '6jeq65-75', '6jla177-187', '6jlm13-23', '6joh132-142', '6joh37-47', '6jwi13-23', '6jx5460-470', '6jy0433-443', '6ko717-27', '6l5618-28', '6l7c16-26', '6lrb509-519', '6m6665-75', '6m7g159-169', '6m8s50-60', '6mb275-85', '6mdv96-106', '6mfp7-17', '6mjg222-232', '6mo11058-1068', '6mtg261-271', '6mx2164-174', '6mx916-26', '6mx973-83', '6mya472-482', '6ner135-145', '6nl98-18', '6ofs74-84', '6oo2341-351', '6pbu189-199', '6pby275-285', '6pej178-188', '6pfa67-77', '6q0r311-321', '6qg3302-312', '6qsd15-25', '6r7l55-65', '6rco55-65', '6rdx103-113', '6re0191-201', '6re2113-123', '6rec113-123', '6rfc235-245', '6riq232-242', '6rya37-47', '6s79187-197', '6t23310-320', '6u65211-221', '6udu265-275', '6v16209-219', '6v2f134-144', '6vak166-176', '6vi4106-116', '6vnr628-638', 
'6vxc401-411', '6w9c64-74', '6w9d73-83']
# Given a set of files storing entry objects and their directory location, return their feature dimensions such as the positional atom types and the bounds for the matrix.
def load_feature_dimensions(files, fdir = 'ptndata_10H/'):
    x_min, y_min, z_min, x_max, y_max, z_max = CUBIC_LENGTH_CONSTRAINT, CUBIC_LENGTH_CONSTRAINT, CUBIC_LENGTH_CONSTRAINT, 0, 0, 0
@@ -41,6 +41,24 @@ def load_feature_dimensions(files, fdir = 'ptndata_10H/'):

    return atom_pos, x_min, y_min, z_min, x_max, y_max, z_max

+def load_acceptable_dimensions(fdir = 'ptndata_10H/'):
+    files = os.listdir(fdir)
+    output = []
+    x_min, y_min, z_min, x_max, y_max, z_max = 17, 22, 15, 49, 48, 52
+    atom_pos = ['CG', 'OG', 'OH', 'NH2', 'CE1', 'CD2', 'OG1', 'N', 'ND1', 'CD', 'SD', 'ND2', 'OD1', 'OE2', 'OE1', 'C', 'NE', 'OD2', 'SG', 'CB', 'CD1', 'CZ', 'NH1', 'CE2', 'CG1', 'NE2', 'NZ', 'CG2', 'CA', 'O', 'CE', 'None']
+    for f in tqdm(files):
+        entry = pickle.load(open(fdir + f, 'rb'))
+        new_x_min, new_y_min, new_z_min, new_x_max, new_y_max, new_z_max = find_bounds(grid2logical(entry.mat))
+        if x_min <= new_x_min and y_min <= new_y_min and z_min <= new_z_min and new_x_max <= x_max and new_y_max <= y_max and new_z_max <= z_max:
+            new_atom_pos = get_all_atoms(entry.mat, [])
+            no_new_protein_pos = True
+            for pos in new_atom_pos:
+                if pos not in atom_pos:
+                    no_new_protein_pos = False
+            if no_new_protein_pos:
+                output.append(f)
+                print(f)
+    return output
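
For context, load_acceptable_dimensions keeps a file only when its occupied voxel bounds fit inside the fixed (17, 22, 15)-(49, 48, 52) box and every atom-position label already appears in the hard-coded atom_pos vocabulary. A self-contained sketch of the same containment test on fabricated bounds (find_bounds and grid2logical are this repo's helpers and are not reimplemented here):

def fits_inside(candidate, target):
    # Each bounds tuple is (x_min, y_min, z_min, x_max, y_max, z_max).
    cx0, cy0, cz0, cx1, cy1, cz1 = candidate
    tx0, ty0, tz0, tx1, ty1, tz1 = target
    return (tx0 <= cx0 and ty0 <= cy0 and tz0 <= cz0
            and cx1 <= tx1 and cy1 <= ty1 and cz1 <= tz1)

target_box = (17, 22, 15, 49, 48, 52)  # the fixed bounds hard-coded above
print(fits_inside((20, 25, 18, 45, 40, 50), target_box))  # True: fully contained
print(fits_inside((10, 25, 18, 45, 40, 50), target_box))  # False: x_min below 17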

# This is almost like sample_gen, except it is a function instead of a generator function. This is used for generating the validation data before training the CNN. It generates the validation samples for all three of the metrics.
def sample_loader(files, feature_set_, atom_type, atom_type_encoder, atom_pos, atom_pos_encoder, energy_scores, x_min, y_min, z_min, x_max, y_max, z_max, fdir='ptndata_10H/'):
@@ -189,19 +207,22 @@ def train_data_loader(files, feature_set, fdir='ptndata_10H/'):
feature_set[q][i][j][k] = [a[x_min + i][y_min + j][z_min + k]] + b[x_min + i][y_min + j][z_min + k].tolist() + c[x_min + i][y_min + j][z_min + k].tolist()

if __name__ == "__main__":
-    fdir='ptn11H_10/'
-    files = os.listdir(fdir)
+    fdir='ptn11H_1000/'
+    output = load_acceptable_dimensions(fdir)
+    print(len(output))
+    print(output)
+    # files = os.listdir(fdir)
    # files.sort()

-    print(load_feature_dimensions(files, fdir))
-    # Initialize the feature set
-    feature_set = None
-    if os.path.isfile(dataset_file+'.npy'):
-        feature_set = np.load(dataset_file+'.npy')
-    else:
-        feature_set = np.zeros(shape=(len(files), z_max-z_min, y_max-y_min, x_max-x_min, 1 + len(atom_type) + len(atom_pos)))
-    train_data_loader(files, feature_set, fdir=fdir)
-    np.save(dataset_file, feature_set)
+    # print(load_feature_dimensions(files, fdir))
+    # # Initialize the feature set
+    # feature_set = None
+    # if os.path.isfile(dataset_file+'.npy'):
+    #     feature_set = np.load(dataset_file+'.npy')
+    # else:
+    #     feature_set = np.zeros(shape=(len(files), z_max-z_min, y_max-y_min, x_max-x_min, 1 + len(atom_type) + len(atom_pos)))
+    # train_data_loader(files, feature_set, fdir=fdir)
+    # np.save(dataset_file, feature_set)
    # feature_set_ = np.array([[[[ [0] * (1 + len(atom_type) + len(atom_pos)) for i in range(x_min, x_max)] for j in range(y_min, y_max)] for k in range(z_min, z_max)] for q in range(validation_samples)])


60 changes: 58 additions & 2 deletions main.py
@@ -3,6 +3,8 @@
from tensorflow import keras
import numpy as np
from data import get_data
+import matplotlib.pyplot as plt
+from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

def main(batch_size, file_dir):
# Prepare the dataset. We use both the training & test MNIST digits.
@@ -12,10 +14,64 @@ def main(batch_size, file_dir):
    gan.compile(
        d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
        g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
-        loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
+        loss_fn=keras.losses.BinaryCrossentropy(from_logits=True)
    )
    # To limit the execution time, we only train on 100 batches. You can train on
    # the entire dataset. You will need about 20 epochs to get nice results.
-    gan.fit(x, epochs=20)
+    print(generator.summary())
+    print(discriminator.summary())
+    history = gan.fit(x, batch_size=batch_size, epochs=20)
+    g_loss, d_loss = history.history['g_loss'], history.history['d_loss']
+    plt.plot(g_loss)
+    plt.plot(d_loss)
+    plt.xticks(np.arange(0, 20, step=1))  # Set label locations.
+    plt.xlabel('epochs')
+    plt.ylabel('loss')
+    plt.title('Protein Structure Generation With DCGAN')
+    # print(xticks(np.arange(0, 20, step=1)))
+    # pred = np.stack(history.history['pred'], axis=0)
+    # labels = np.stack(history.history['label'], axis=0)
+    # accuracies = get_accuracies(pred, labels)
+    # plt.plot(accuracies)
+    plt.legend(['Generator loss', 'Discriminator loss'], loc='upper right')
+    plt.show()
+def get_accuracies(pred, labels, threshold=.5):
+    pred_output = pred.copy()
+    labels_output = labels.copy()
+
+    pred_output[pred_output >= threshold] = 1
+    pred_output[pred_output < threshold] = 0
+
+    labels_output[labels_output >= threshold] = 1
+    labels_output[labels_output < threshold] = 0
+
+    accuracies = []
+    for i in range(pred_output.shape[0]):
+        accuracies.append(accuracy_score(labels_output[i], pred_output[i]))
+    return accuracies
+    # print(classification_report(labels_output, pred_output))
+
+
+# Plot Accuracy and Loss
+def plot_training_loss(history):
+    # Plot training & validation accuracy values
+    plt.plot(history.history['accuracy'])
+    plt.plot(history.history['val_accuracy'])
+    plt.title('Model accuracy')
+    plt.ylabel('Accuracy')
+    plt.xlabel('Epoch')
+    plt.legend(['Train', 'Test'], loc='upper left')
+    plt.show()
+
+    # Plot training & validation loss values
+    plt.plot(history.history['loss'])
+    plt.plot(history.history['val_loss'])
+    plt.title('Model loss')
+    plt.ylabel('Loss')
+    plt.xlabel('Epoch')
+    plt.legend(['Train', 'Test'], loc='upper left')
+    plt.show()
+
+
if __name__ == '__main__':
    main(10, 'ptn11H_10')
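
The new get_accuracies is only referenced from commented-out code above; a quick usage sketch on fabricated predictions (shapes and values here are illustrative, not from this repo):

import numpy as np
from sklearn.metrics import accuracy_score

# Fabricated example: 3 "epochs" of discriminator outputs for 4 samples each.
pred = np.array([[0.9, 0.2, 0.7, 0.6],
                 [0.6, 0.1, 0.8, 0.3],
                 [0.7, 0.4, 0.9, 0.2]])
labels = np.array([[1.0, 0.0, 1.0, 0.0]] * 3)

# Same thresholding as get_accuracies(threshold=.5): binarize, then score per row.
pred_bin = (pred >= 0.5).astype(int)
labels_bin = (labels >= 0.5).astype(int)
accuracies = [accuracy_score(labels_bin[i], pred_bin[i]) for i in range(pred_bin.shape[0])]
print(accuracies)  # [0.75, 1.0, 1.0]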

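One detail the plotting code relies on (standard TF 2.x Keras behavior, not specific to this commit): whatever dict a custom train_step returns becomes the per-epoch series in History.history, which is why main() can read history.history['g_loss'] and history.history['d_loss']. A minimal self-contained illustration with a hypothetical Dummy model:

import numpy as np
import tensorflow as tf
from tensorflow import keras

class Dummy(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean((self.dense(x) - y) ** 2)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Whatever this dict contains becomes History.history's keys.
        return {"d_loss": loss, "g_loss": loss}

model = Dummy()
model.compile(optimizer="adam")
h = model.fit(np.random.rand(8, 2), np.random.rand(8, 1), epochs=2, verbose=0)
print(h.history.keys())  # dict_keys(['d_loss', 'g_loss'])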