Make the inference server exit gracefully in case of errors instead of hanging #33

Merged: 4 commits, Sep 9, 2022. Changes from 3 commits.
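In short, the fix replaces the blocking `Queue.get()` calls in `encode`/`decode` with a shared `try_collect_results` helper that polls the output queue with a timeout and, whenever the timeout fires, checks that all worker processes are still alive. A dead worker now surfaces as a `RuntimeError` instead of an indefinite hang. As a minimal, self-contained sketch of that pattern (everything here, including `crashing_worker`, is illustrative and not code from this PR):

```python
import multiprocessing
import queue


def crashing_worker(output_queue):
    # Dies before ever publishing a result to the queue.
    raise RuntimeError("worker crashed before producing a result")


if __name__ == "__main__":
    output_queue = multiprocessing.Queue()
    worker = multiprocessing.Process(target=crashing_worker, args=(output_queue,))
    worker.start()

    while True:
        try:
            # A bare output_queue.get() would block forever, since the worker
            # died without ever putting anything on the queue.
            result = output_queue.get(timeout=1)
            break
        except queue.Empty:
            # Timed out: check whether the producer is still alive.
            if not worker.is_alive():
                raise RuntimeError("Worker process died") from None
```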
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Support for fine-tuning a pretrained model on new data ([#30](https://github.com/microsoft/molecule-generation/pull/30))
 
+### Fixed
+- Made the inference server exit gracefully in case of errors instead of hanging ([#33](https://github.com/microsoft/molecule-generation/pull/33))
+
 ## [0.2.0] - 2022-07-01
 
 ### Added
70 changes: 42 additions & 28 deletions molecule_generation/utils/moler_inference_server.py
@@ -1,22 +1,23 @@
 import enum
 import os
 import pathlib
+import queue
 from collections import defaultdict
 from itertools import chain
-from multiprocessing import Queue, Process
+from multiprocessing import Process, Queue
 from queue import Empty
-from typing import List, Tuple, Optional, Iterator, Union, Any, DefaultDict
+from typing import Any, DefaultDict, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
-from rdkit import Chem
 from more_itertools import chunked, ichunked
+from rdkit import Chem
 
-from molecule_generation.dataset.in_memory_trace_dataset import InMemoryTraceDataset, DataFold
+from molecule_generation.dataset.in_memory_trace_dataset import DataFold, InMemoryTraceDataset
 from molecule_generation.models.moler_generator import MoLeRGenerator
 from molecule_generation.models.moler_vae import MoLeRVae
-from molecule_generation.utils.moler_decoding_utils import DecoderSamplingMode, MoLeRDecoderState
-from molecule_generation.utils.model_utils import load_vae_model_and_dataset
 from molecule_generation.preprocessing.data_conversion_utils import remove_non_max_frags
+from molecule_generation.utils.model_utils import load_vae_model_and_dataset
+from molecule_generation.utils.moler_decoding_utils import DecoderSamplingMode, MoLeRDecoderState
 
 Pathlike = Union[str, pathlib.Path]

@@ -234,6 +235,26 @@ def cleanup_workers(self, ignore_failures: bool = False):
         self._request_queue.close()
         self._output_queue.close()
 
+    def try_collect_results(self, num_results: int) -> List[Any]:
+        results: List[Any] = [None] * num_results
+
+        # Try to collect the results and put them back in order.
+        for _ in range(num_results):
+            while True:
+                try:
+                    result_id, result = self._output_queue.get(timeout=10)
+                    results[result_id] = result
+                    break
+                except queue.Empty:
+                    # We could not get the next result before the timeout, let us make sure that all
+                    # child processes are still alive.
+                    for worker in self._processes:
+                        if not worker.is_alive():
+                            self.cleanup_workers(ignore_failures=True)
+                            raise RuntimeError("Worker process died")
+
+        return list(chain(*results))
+
     def __del__(self):
         self.cleanup_workers()
 
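Two details in the helper above are easy to miss: results may come back from the workers in any order, so each one is slotted into `results` by its request id; and each element is itself a per-chunk list, so the ordered chunks are flattened with `chain` into one flat list. A toy illustration of that reassembly, with made-up values:

```python
from itertools import chain

num_results = 3
results = [None] * num_results

# Per-chunk results arriving out of order, tagged with their request ids.
for result_id, result in [(2, ["f"]), (0, ["a", "b"]), (1, ["c", "d", "e"])]:
    results[result_id] = result  # slot back into request order

assert list(chain(*results)) == ["a", "b", "c", "d", "e", "f"]
```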
@@ -248,22 +269,18 @@ def __exit__(self, exc_type, exc_value, traceback) -> bool:
     def encode(self, smiles_list: List[str], include_log_variances: bool = False):
         self.init_workers()
 
-        # Issue all requests to the workers, and prepare results array to hold them:
-        # Choose chunk size such that all workers have something to do:
+        # Choose chunk size such that all workers have something to do.
         chunk_size = min(self._max_num_samples_per_chunk, len(smiles_list) // self._num_workers + 1)
-        results: List[Any] = []
+
+        # Issue all requests to the workers.
+        num_results = 0
         for smiles_chunk in chunked(smiles_list, chunk_size):
             self._request_queue.put(
-                (MoLeRRequestType.ENCODE, len(results), (smiles_chunk, include_log_variances))
+                (MoLeRRequestType.ENCODE, num_results, (smiles_chunk, include_log_variances))
             )
-            results.append(None)
+            num_results += 1
 
-        # Collect results and put them back into order, before returning them as one long list:
-        for _ in range(len(results)):
-            result_id, result = self._output_queue.get()
-            results[result_id] = result
-
-        return list(chain(*results))
+        return self.try_collect_results(num_results)
 
     def decode(
         self,
@@ -274,12 +291,12 @@
         sampling_mode: DecoderSamplingMode = DecoderSamplingMode.GREEDY,
     ) -> List[Tuple[str, Optional[np.ndarray]]]:
         self.init_workers()
-        # Issue all requests to the workers, and prepare results array to hold them:
-        # Choose chunk size such that all workers have something to do:
+
+        # Choose chunk size such that all workers have something to do.
         chunk_size = min(
             self._max_num_samples_per_chunk, len(latent_representations) // self._num_workers + 1
         )
-        results: List[Any] = []
+
         if init_mols and len(init_mols) != len(latent_representations):
             raise ValueError(
                 f"Number of graph representations ({len(latent_representations)})"
@@ -289,12 +306,14 @@
         if not init_mols:
             init_mols = [None for _ in range(len(latent_representations))]
 
+        # Issue all requests to the workers.
+        num_results = 0
         init_mol_chunks = ichunked(init_mols, chunk_size)
         for latents_chunk in chunked(latent_representations, chunk_size):
             self._request_queue.put(
                 (
                     MoLeRRequestType.DECODE,
-                    len(results),
+                    num_results,
                     (
                         latents_chunk,
                         include_latent_samples,
@@ -304,11 +323,6 @@
                     ),
                 )
             )
-            results.append(None)
+            num_results += 1
 
-        # Collect results and put them back into order, before returning them as one long list:
-        for _ in range(len(results)):
-            result_id, result = self._output_queue.get()
-            results[result_id] = result
-
-        return list(chain(*results))
+        return self.try_collect_results(num_results)
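For completeness, a sketch of how the change surfaces to a user of the library, following the usage pattern from the repository README (`load_model_from_directory` wraps the inference server); the model directory is a placeholder path:

```python
from molecule_generation import load_model_from_directory

model_dir = "./example_model_directory"  # placeholder for a trained MoLeR checkpoint

with load_model_from_directory(model_dir) as model:
    # Before this change, a crashed worker left encode() blocked forever on the
    # output queue; now the 10-second polling loop notices the dead process,
    # cleans up the remaining workers, and raises RuntimeError("Worker process died").
    embeddings = model.encode(["c1ccccc1"])
```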