Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for speedier Python-Julia interaction #32

Merged
merged 20 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: Force close and join at exit
  • Loading branch information
kshyatt-aws committed Aug 26, 2024
commit 05870ba88ba870bb8e26f113d8288149696024d5
139 changes: 35 additions & 104 deletions src/braket/simulator_v2/base_simulator_v2.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import sys
import time
from collections.abc import Sequence
from multiprocessing.pool import Pool
from typing import List, Optional, Union
import atexit
import json
import numpy as np
from braket.default_simulator.simulator import BaseLocalSimulator
from braket.ir.jaqcd import DensityMatrix, Probability, StateVector
from braket.ir.openqasm import Program as OpenQASMProgram
from braket.task_result import GateModelTaskResult

from braket.simulator_v2.julia_workers import translate_and_run, translate_and_run_multiple, _handle_julia_error

from multiprocessing.pool import Pool
import atexit

__JULIA_POOL__ = None

def setup_julia():
import os
import sys
import json
import warnings

import os
# don't reimport if we don't have to
if "juliacall" in sys.modules:
os.environ["PYTHON_JULIACALL_HANDLE_SIGNALS"] = "yes"
Expand All @@ -30,110 +32,37 @@ def setup_julia():
):
os.environ[k] = os.environ.get(k, default)
# install Julia and any packages as needed
os.environ["PYTHON_JULIAPKG_OFFLINE"] = "yes"
import juliacall

jl = juliacall.Main
jl.seval("using JSON3, BraketSimulator")
jl_yield = getattr(jl, "yield")
jl_yield()
# don't waste time looking for packages
# which should already be present after this
os.environ["PYTHON_JULIAPKG_OFFLINE"] = "no"

stock_oq3 = """
OPENQASM 3.0;
qubit[2] q;
h q[0];
cnot q;
#pragma braket result probability
"""
jl.BraketSimulator.simulate("braket_sv_v2", stock_oq3, '{}', 0)
jl.BraketSimulator.simulate("braket_dm_v2", stock_oq3, '{}', 0)
return jl

def exit_julia():
import sys
jl = sys.modules["juliacall"].Main
jl_yield = getattr(jl, "yield")
jl_yield()
def setup_pool():
global __JULIA_POOL__
__JULIA_POOL__ = Pool(processes=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming this will be used by the batched executions, should this take in a param for setup_pool called processes which has a default of 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because Julia handles the batches itself behind the scenes. We only need one process to chat back and forth with Julia.

__JULIA_POOL__.apply(setup_julia)
atexit.register(__JULIA_POOL__.join)
atexit.register(__JULIA_POOL__.close)
return


def _handle_julia_error(error):
# we don't import `JuliaError` explicitly here to avoid
# having to import juliacall on the main thread. we need
# to call *this* function on that thread in case getting
# the result from the submitted Future raises an exception
if type(error).__name__ == "JuliaError":
python_exception = getattr(error.exception, "alternate_type", None)
if python_exception is None:
py_error = error
else:
class_val = getattr(sys.modules["builtins"], str(python_exception))
py_error = class_val(str(error.exception.message))
raise py_error
else:
raise error


def translate_and_run(
device_id: str, openqasm_ir: OpenQASMProgram, shots: int = 0
) -> str:
jl = setup_julia()
jl_shots = shots
jl_inputs = json.dumps(openqasm_ir.inputs) if openqasm_ir.inputs else json.dumps({})
py_result = ""
try:
result = jl.BraketSimulator.simulate(
device_id,
openqasm_ir.source,
jl_inputs,
jl_shots,
)
py_result = str(result)
except Exception as e:
_handle_julia_error(e)

return py_result


def translate_and_run_multiple(
device_id: str,
programs: Sequence[OpenQASMProgram],
shots: Optional[int] = 0,
inputs: Optional[Union[dict, Sequence[dict]]] = {},
) -> List[str]:
jl = setup_julia()
irs = [program.source for program in programs]
is_single_input = isinstance(inputs, dict) or len(inputs) == 1
py_inputs = {}
if (is_single_input and isinstance(inputs, dict)) or not is_single_input:
py_inputs = [inputs.copy() for _ in range(len(programs))]
elif is_single_input and not isinstance(inputs, dict):
py_inputs = [inputs[0].copy() for _ in range(len(programs))]
else:
py_inputs = inputs
full_inputs = []
for p_ix, program in enumerate(programs):
if program.inputs:
full_inputs.append(program.inputs | py_inputs[p_ix])
else:
full_inputs.append(py_inputs[p_ix])

jl_inputs = json.dumps(full_inputs)

try:
results = jl.BraketSimulator.simulate(
device_id,
irs,
jl_inputs,
shots,
)
py_results = [str(result) for result in results]
except Exception as e:
_handle_julia_error(e)
return py_results


class BaseLocalSimulatorV2(BaseLocalSimulator):
def __init__(self, device: str):
global __JULIA_POOL__
if __JULIA_POOL__ is None:
setup_pool()
self._device = device
pool = Pool(processes=1)
pool.apply(setup_julia)
self._executor = pool
atexit.register(self._executor.close)


def initialize_simulation(self, **kwargs):
return

Expand All @@ -158,15 +87,16 @@ def run_openqasm(
as a result type when shots=0. Or, if StateVector and Amplitude result types
are requested when shots>0.
"""
global __JULIA_POOL__
try:
jl_result = self._executor.apply(
jl_result = __JULIA_POOL__.apply(
translate_and_run,
[self._device, openqasm_ir, shots],
)
except Exception as e:
_handle_julia_error(e)

result = GateModelTaskResult.parse_raw_schema(jl_result)
result = GateModelTaskResult(**json.loads(jl_result))
jl_result = None
result.additionalMetadata.action = openqasm_ir

Expand Down Expand Up @@ -196,16 +126,17 @@ def run_multiple(
list[GateModelTaskResult]: A list of result objects, with the ith object being
the result of the ith program.
"""
global __JULIA_POOL__
try:
jl_results = self._executor.apply(
jl_results = __JULIA_POOL__.apply(
translate_and_run_multiple,
[self._device, programs, shots, inputs],
)
except Exception as e:
_handle_julia_error(e)

results = [
GateModelTaskResult.parse_raw_schema(jl_result) for jl_result in jl_results
GateModelTaskResult(**json.loads(jl_result)) for jl_result in jl_results
]
jl_results = None
for p_ix, program in enumerate(programs):
Expand Down
Empty file.
78 changes: 78 additions & 0 deletions src/braket/simulator_v2/julia_workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from braket.ir.openqasm import Program as OpenQASMProgram
from collections.abc import Sequence
import json
import sys
from typing import List, Optional, Union

def _handle_julia_error(error):
import sys
if isinstance(error, sys.modules["juliacall"].JuliaError):
python_exception = getattr(error.exception, "alternate_type", None)
if python_exception is None:
py_error = error
else:
class_val = getattr(sys.modules["builtins"], str(python_exception))
py_error = class_val(str(error.exception.message))
raise py_error
else:
raise error
return


def translate_and_run(
device_id: str, openqasm_ir: OpenQASMProgram, shots: int = 0
) -> str:
jl = sys.modules["juliacall"].Main
jl_shots = shots
jl_inputs = json.dumps(openqasm_ir.inputs) if openqasm_ir.inputs else '{}'
py_result = ""
try:
result = jl.BraketSimulator.simulate(
device_id,
openqasm_ir.source,
jl_inputs,
jl_shots,
)
py_result = str(result)
except Exception as e:
_handle_julia_error(e)

return py_result


def translate_and_run_multiple(
device_id: str,
programs: Sequence[OpenQASMProgram],
shots: Optional[int] = 0,
inputs: Optional[Union[dict, Sequence[dict]]] = {},
) -> List[str]:
jl = sys.modules["juliacall"].Main
irs = [program.source for program in programs]
is_single_input = isinstance(inputs, dict) or len(inputs) == 1
py_inputs = {}
if (is_single_input and isinstance(inputs, dict)) or not is_single_input:
py_inputs = [inputs.copy() for _ in range(len(programs))]
elif is_single_input and not isinstance(inputs, dict):
py_inputs = [inputs[0].copy() for _ in range(len(programs))]
else:
py_inputs = inputs
full_inputs = []
for p_ix, program in enumerate(programs):
if program.inputs:
full_inputs.append(program.inputs | py_inputs[p_ix])
else:
full_inputs.append(py_inputs[p_ix])

jl_inputs = json.dumps(full_inputs)

try:
results = jl.BraketSimulator.simulate(
device_id,
irs,
jl_inputs,
shots,
)
py_results = [str(result) for result in results]
except Exception as e:
_handle_julia_error(e)
return py_results