|
5 | 5 | import numpy as np |
6 | 6 |
|
7 | 7 | from data.loader import expand_sh |
| 8 | +from data.process_label import parse_pdbfile |
8 | 9 | from feature import create_basic_features, get_base2d_feature |
9 | 10 | from data.process_alphafold import process_alphafold_target_ensemble, process_alphafold_model |
10 | 11 | from network.resEGNN import resEGNN, resEGNN_with_mask, resEGNN_with_ne |
|
22 | 23 | help='Path to alphafold prediction results.') |
23 | 24 | parser.add_argument('--alphafold_feature_cache', type=str, required=False, default='') |
24 | 25 | parser.add_argument('--af2_pdb', type=str, required=False, default='', |
25 | | - help='Optional. PDBs from AlphaFold2 predcition for index correction with input pdb') |
| 26 | + help='Optional. PDBs from AlphaFold2 predcition for index correction with input pdb. Must contain all residues in input pdb.') |
26 | 27 |
|
27 | 28 | args = parser.parse_args() |
28 | 29 | if args.alphafold_feature_cache == '': |
|
65 | 66 | if args.alphafold_feature_cache is not None: |
66 | 67 | pickle.dump({'plddt': plddt, 'cmap': cmap, 'dict_2d': dict_2d}, |
67 | 68 | open(args.alphafold_prediction_cache, 'wb')) |
| 69 | + if args.af2_pdb != '': |
| 70 | + pose_input = parse_pdbfile(args.input) |
| 71 | + input_idx = np.array([i['rindex'] for i in pose_input]) |
| 72 | + pose_af2 = parse_pdbfile(args.af2_pdb) |
| 73 | + af2_idx = np.array([i['rindex'] for i in pose_af2]) |
| 74 | + mask = af2_idx in input_idx |
| 75 | + af2_qa = af2_qa[:, mask] |
| 76 | + plddt = plddt[:, mask] |
| 77 | + cmap = cmap[:, mask][mask, :] |
| 78 | + for f2d_type in dict_2d.keys(): |
| 79 | + dict_2d[f2d_type] = dict_2d[f2d_type][:, :, mask][:, mask, :] |
68 | 80 | else: |
69 | 81 | dict_2d['f2d_dan'] = get_base2d_feature(args.input, args.output) |
70 | 82 | with torch.no_grad(): |
|
0 commit comments