Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove allocations for better performance from Python, cleanup #48

Merged
merged 13 commits into from
Aug 27, 2024
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BraketSimulator"
uuid = "76d27892-9a0b-406c-98e4-7c178e9b3dff"
authors = ["Katharine Hyatt <hyatkath@amazon.com> and contributors"]
version = "0.0.3"
version = "0.0.4"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
42 changes: 32 additions & 10 deletions ext/BraketSimulatorPythonExt/BraketSimulatorPythonExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,45 @@
using PrecompileTools

@recompile_invalidations begin
using BraketSimulator, PythonCall, JSON3
using BraketSimulator, BraketSimulator.Quasar, PythonCall, JSON3
end

using BraketSimulator: simulate
using BraketSimulator: simulate, AbstractProgramResult

function BraketSimulator.simulate(simulator, task_spec::String, inputs::Dict{String, Any}, shots::Int; kwargs...)
jl_specs = BraketSimulator.OpenQasmProgram(BraketSimulator.braketSchemaHeader("braket.ir.openqasm.program", "1"), task_spec, inputs)
jl_results = simulate(simulator, jl_specs, shots; kwargs...)
json = JSON3.write(jl_results)
return json
function BraketSimulator.simulate(simulator_id::String, task_spec::String, py_inputs::String, shots::Int; kwargs...)
inputs = JSON3.read(py_inputs, Dict{String, Any})
jl_spec = BraketSimulator.OpenQasmProgram(BraketSimulator.braketSchemaHeader("braket.ir.openqasm.program", "1"), task_spec, inputs)
simulator = if simulator_id == "braket_sv_v2"
StateVectorSimulator(0, shots)
elseif simulator_id == "braket_dm_v2"
DensityMatrixSimulator(0, shots)
end
jl_results = simulate(simulator, jl_spec, shots; kwargs...)
# this is expensive due to allocations
py_results = JSON3.write(jl_results)
simulator = nothing
inputs = nothing
jl_spec = nothing
jl_results = nothing
return py_results
end
function BraketSimulator.simulate(simulator, task_specs::Vector{String}, inputs::Vector{Dict{String, Any}}, shots::Int; kwargs...)
jl_specs = map(zip(task_specs, inputs)) do (task_spec, input)
BraketSimulator.OpenQasmProgram(BraketSimulator.braketSchemaHeader("braket.ir.openqasm.program", "1"), task_spec, input)
function BraketSimulator.simulate(simulator_id::String, task_specs::PyList, py_inputs::String, shots::Int; kwargs...)
inputs = JSON3.read(py_inputs, Vector{Dict{String, Any}})
jl_specs = map(zip(task_specs, inputs)) do (task_spec, input)
jl_spec = task_spec isa Py ? pyconvert(String, task_spec) : task_spec
BraketSimulator.OpenQasmProgram(BraketSimulator.braketSchemaHeader("braket.ir.openqasm.program", "1"), jl_spec, input)
end
simulator = if simulator_id == "braket_sv_v2"
StateVectorSimulator(0, shots)
elseif simulator_id == "braket_dm_v2"

Check warning on line 36 in ext/BraketSimulatorPythonExt/BraketSimulatorPythonExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BraketSimulatorPythonExt/BraketSimulatorPythonExt.jl#L36

Added line #L36 was not covered by tests
DensityMatrixSimulator(0, shots)
end
jl_results = simulate(simulator, jl_specs, shots; kwargs...)
jsons = [JSON3.write(r) for r in jl_results]
simulator = nothing
jl_results = nothing
inputs = nothing
jl_specs = nothing
return jsons
end

Expand Down
109 changes: 61 additions & 48 deletions src/BraketSimulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ include("noise_kernels.jl")
include("Quasar.jl")
using .Quasar

const CHUNK_SIZE = 2^10
const LOG2_CHUNK_SIZE = 10
const CHUNK_SIZE = 2^LOG2_CHUNK_SIZE

function _index_to_endian_bits(ix::Int, qubit_count::Int)
bits = Vector{Int}(undef, qubit_count)
Expand Down Expand Up @@ -139,53 +140,52 @@ function _bundle_results(
end

function _generate_results(
results::Vector{<:AbstractProgramResult},
result_types::Vector,
result_types,
simulator::D,
) where {D<:AbstractSimulator}
result_values = map(result_type -> calculate(result_type, simulator), result_types)
final_results = Vector{ResultTypeValue}(undef, length(result_values))
for r_ix in 1:length(final_results)
final_results[r_ix] = ResultTypeValue(results[r_ix],
complex_matrix_to_ir(result_values[r_ix]))
ir_results = map(StructTypes.lower, result_types)
results = map(zip(ir_results, result_values)) do (ir, val)
ir_val = complex_matrix_to_ir(val)
return ResultTypeValue(ir, ir_val)
end
return final_results
return results
end

_translate_result_type(r::IR.Amplitude, qc::Int) = Amplitude(r.states)
_translate_result_type(r::IR.StateVector, qc::Int) = StateVector()
# The IR result types support `nothing` as a valid option for the `targets` field,
# however `Result`s represent this with an empty `QubitSet` for type
# stability reasons. Here we take a `nothing` value for `targets` and translate it
# to apply to all qubits.
_translate_result_type(r::IR.DensityMatrix, qc::Int) = isnothing(r.targets) ? DensityMatrix(collect(0:qc-1)) : DensityMatrix(r.targets)
_translate_result_type(r::IR.Probability, qc::Int) = isnothing(r.targets) ? Probability(collect(0:qc-1)) : Probability(r.targets)
_translate_result_type(r::IR.Amplitude) = Amplitude(r.states)
_translate_result_type(r::IR.StateVector) = StateVector()
_translate_result_type(r::IR.DensityMatrix) = DensityMatrix(r.targets)
_translate_result_type(r::IR.Probability) = Probability(r.targets)
for (RT, IRT) in ((:Expectation, :(IR.Expectation)), (:Variance, :(IR.Variance)), (:Sample, :(IR.Sample)))
@eval begin
function _translate_result_type(r::$IRT, qc::Int)
targets = isnothing(r.targets) ? collect(0:qc-1) : r.targets
obs = StructTypes.constructfrom(Observables.Observable, r.observable)
$RT(obs, QubitSet(targets))
function _translate_result_type(r::$IRT)
obs = StructTypes.constructfrom(Observables.Observable, r.observable)
$RT(obs, QubitSet(r.targets))
end
end
end
_translate_result_types(results::Vector{AbstractProgramResult}, qubit_count::Int) = map(result->_translate_result_type(result, qubit_count), results)
_translate_result_types(results::Vector{AbstractProgramResult}) = map(_translate_result_type, results)

function _compute_exact_results(d::AbstractSimulator, program::Program, qubit_count::Int)
result_types = _translate_result_types(program.results, qubit_count)
result_types = _translate_result_types(program.results)
_validate_result_types_qubits_exist(result_types, qubit_count)
return _generate_results(program.results, result_types, d)
return _generate_results(result_types, d)
end

function _compute_exact_results(d::AbstractSimulator, program::Circuit, qubit_count::Int)
_validate_result_types_qubits_exist(program.result_types, qubit_count)
return _generate_results(program.result_types, d)
end

"""
_get_measured_qubits(program::Program, qubit_count::Int) -> Vector{Int}
_get_measured_qubits(program, qubit_count::Int) -> Vector{Int}

Get the qubits measured by the program. If [`Measure`](@ref)
instructions are present in the program's instruction list,
their targets are used to form the list of measured qubits.
If not, all qubits from 0 to `qubit_count-1` are measured.
"""
function _get_measured_qubits(program::Program, qubit_count::Int)
function _get_measured_qubits(program, qubit_count::Int)
measure_inds = findall(ix->ix.operator isa Measure, program.instructions)
isempty(measure_inds) && return collect(0:qubit_count-1)
measure_ixs = program.instructions[measure_inds]
Expand All @@ -208,29 +208,29 @@ function _prepare_program(circuit_ir::OpenQasmProgram, inputs::Dict{String, <:An
_verify_openqasm_shots_observables(circuit, n_qubits)
basis_rotation_instructions!(circuit)
end
return Program(circuit), n_qubits
return circuit, n_qubits
end
"""
_prepare_program(circuit_ir::Program, inputs::Dict{String, <:Any}, shots::Int) -> (Program, Int)

Apply any `inputs` provided for the simulation. Return the `Program`
(with bound parameters) and the qubit count of the circuit.
"""
# nosemgrep
function _prepare_program(circuit_ir::Program, inputs::Dict{String, <:Any}, shots::Int)
function _prepare_program(circuit_ir::Program, inputs::Dict{String, <:Any}, shots::Int) # nosemgrep
operations::Vector{Instruction} = circuit_ir.instructions
symbol_inputs = Dict(Symbol(k) => v for (k, v) in inputs)
operations = [bind_value!(operation, symbol_inputs) for operation in operations]
qc = qubit_count(circuit_ir)
bound_program = Program(circuit_ir.braketSchemaHeader, operations, circuit_ir.results, circuit_ir.basis_rotation_instructions)
return bound_program, qubit_count(circuit_ir)
return bound_program, qc
end
"""
_combine_operations(program::Program, shots::Int) -> Program
_combine_operations(program, shots::Int) -> Program

Combine explicit instructions and basis rotation instructions (if necessary).
Validate that all operations are performed on qubits within `qubit_count`.
"""
function _combine_operations(program::Program, shots::Int)
function _combine_operations(program, shots::Int)
operations = program.instructions
if shots > 0 && !isempty(program.basis_rotation_instructions)
operations = vcat(operations, program.basis_rotation_instructions)
Expand All @@ -248,17 +248,19 @@ Compute the results once `simulator` has finished applying all the instructions.
the results array is populated with the parsed result types (to help the Braket SDK compute them from the sampled measurements)
and a placeholder zero value.
"""
function _compute_results(::Type{OpenQasmProgram}, simulator, program, n_qubits, shots) # nosemgrep
analytic_results = shots == 0 && !isnothing(program.results) && !isempty(program.results)
function _compute_results(simulator, program::Circuit, n_qubits, shots)
results = program.result_types
has_no_results = isnothing(results) || isempty(results)
analytic_results = shots == 0 && !has_no_results
if analytic_results
return _compute_exact_results(simulator, program, n_qubits)
elseif isnothing(program.results) || isempty(program.results)
elseif has_no_results
return ResultTypeValue[]
else
return ResultTypeValue[ResultTypeValue(result_type, 0.0) for result_type in program.results]
return ResultTypeValue[ResultTypeValue(StructTypes.lower(result_type), 0.0) for result_type in results]
end
end
function _compute_results(::Type{Program}, simulator, program, n_qubits, shots) # nosemgrep
function _compute_results(simulator, program::Program, n_qubits, shots)
analytic_results = shots == 0 && !isnothing(program.results) && !isempty(program.results)
if analytic_results
return _compute_exact_results(simulator, program, n_qubits)
Expand All @@ -272,6 +274,12 @@ function _validate_circuit_ir(simulator, circuit_ir::Program, qubit_count::Int,
_validate_shots_and_ir_results(shots, circuit_ir.results, qubit_count)
return
end
function _validate_circuit_ir(simulator, circuit_ir::Circuit, qubit_count::Int, shots::Int)
_validate_ir_results_compatibility(simulator, circuit_ir.result_types, Val(:JAQCD))
_validate_ir_instructions_compatibility(simulator, circuit_ir, Val(:JAQCD))
_validate_shots_and_ir_results(shots, circuit_ir.result_types, qubit_count)
return
end

"""
simulate(simulator::AbstractSimulator, circuit_ir::Union{OpenQasmProgram, Program}, shots::Int; kwargs...) -> GateModelTaskResult
Expand All @@ -296,7 +304,7 @@ function simulate(
reinit!(simulator, n_qubits, shots)
simulator = evolve!(simulator, operations)
measured_qubits = _get_measured_qubits(program, n_qubits)
results = _compute_results(T, simulator, program, n_qubits, shots)
results = _compute_results(simulator, program, n_qubits, shots)
return _bundle_results(results, circuit_ir, simulator, measured_qubits)
end

Expand Down Expand Up @@ -623,6 +631,7 @@ include("dm_simulator.jl")
"""
all_gates_qasm = """
OPENQASM 3.0;
input float theta;
bit[3] b;
qubit[3] q;
rx(0.1) q[0];
Expand Down Expand Up @@ -653,11 +662,11 @@ include("dm_simulator.jl")
swap q[0], q[1];
iswap q[0], q[1];

xx(6.249142469550989) q[0], q[1];
yy(6.249142469550989) q[0], q[1];
xy(6.249142469550989) q[0], q[1];
zz(6.249142469550989) q[0], q[1];
pswap(6.249142469550989) q[0], q[1];
xx(theta) q[0], q[1];
yy(theta) q[0], q[1];
xy(theta) q[0], q[1];
zz(theta) q[0], q[1];
pswap(theta) q[0], q[1];
ms(0.1, 0.2, 0.3) q[0], q[1];

cphaseshift(6.249142469550989) q[0], q[1];
Expand Down Expand Up @@ -701,7 +710,7 @@ include("dm_simulator.jl")
#pragma braket result sample x(q[0]) @ y(q[1])
"""
@compile_workload begin
using BraketSimulator, BraketSimulator.Quasar
using BraketSimulator, BraketSimulator.Quasar, BraketSimulator.StructTypes
simulator = StateVectorSimulator(5, 0)
oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), custom_qasm, nothing)
simulate(simulator, oq3_program, 100)
Expand Down Expand Up @@ -730,19 +739,23 @@ include("dm_simulator.jl")

sv_simulator = StateVectorSimulator(3, 0)
dm_simulator = DensityMatrixSimulator(3, 0)
oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), all_gates_qasm, nothing)
oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), all_gates_qasm, Dict("theta"=>0.665))
simulate(sv_simulator, oq3_program, 100)
simulate(dm_simulator, oq3_program, 100)

sv_simulator = StateVectorSimulator(2, 0)
dm_simulator = DensityMatrixSimulator(2, 0)
sv_oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), sv_exact_results_qasm, nothing)
dm_oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), dm_exact_results_qasm, nothing)
simulate(sv_simulator, sv_oq3_program, 0)
simulate(dm_simulator, dm_oq3_program, 0)
results = simulate(sv_simulator, sv_oq3_program, 0)
map(StructTypes.lower, results.resultTypes)
results = simulate(dm_simulator, dm_oq3_program, 0)
map(StructTypes.lower, results.resultTypes)
oq3_program = OpenQasmProgram(braketSchemaHeader("braket.ir.openqasm.program", "1"), shots_results_qasm, nothing)
simulate(sv_simulator, oq3_program, 10)
simulate(dm_simulator, oq3_program, 10)
results = simulate(sv_simulator, oq3_program, 10)
map(StructTypes.lower, results.resultTypes)
results = simulate(dm_simulator, oq3_program, 10)
map(StructTypes.lower, results.resultTypes)
end
end
end # module BraketSimulator
6 changes: 5 additions & 1 deletion src/circuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ julia> qubit_count(c)
qubit_count(c::Circuit) = length(qubits(c))
qubit_count(p::Program) = length(qubits(p))

Base.convert(::Type{Program}, c::Circuit) = (basis_rotation_instructions!(c); return Program(braketSchemaHeader("braket.ir.jaqcd.program" ,"1"), c.instructions, ir.(c.result_types, Val(:JAQCD)), c.basis_rotation_instructions))
function Base.convert(::Type{Program}, c::Circuit) # nosemgrep
lowered_rts = map(StructTypes.lower, c.result_types)
header = braketSchemaHeader("braket.ir.jaqcd.program" ,"1")
return Program(header, c.instructions, lowered_rts, c.basis_rotation_instructions)
end
Program(c::Circuit) = convert(Program, c)

extract_observable(rt::ObservableResult) = rt.observable
Expand Down
29 changes: 11 additions & 18 deletions src/custom_gates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ mutable struct DoubleExcitation <: AngledGate{1}
new(angle, Float64(pow_exponent))
end
qubit_count(::Type{DoubleExcitation}) = 4
function matrix_rep_raw(g::DoubleExcitation)
cosϕ = cos(g.angle[1] / 2.0)
sinϕ = sin(g.angle[1] / 2.0)

mat = diagm(ones(ComplexF64, 16))
mat[4, 4] = cosϕ
mat[13, 13] = cosϕ
mat[4, 13] = -sinϕ
mat[13, 4] = sinϕ
function matrix_rep_raw(::DoubleExcitation, ϕ) # nosemgrep
sθ, cθ = sincos(ϕ/2.0)
mat = diagm(ones(ComplexF64, 16))
mat[4, 4] = cθ
mat[13, 13] = cθ
mat[4, 13] = -sθ
mat[13, 4] = sθ
return SMatrix{16,16,ComplexF64}(mat)
end

Expand All @@ -24,11 +22,7 @@ mutable struct SingleExcitation <: AngledGate{1}
new(angle, Float64(pow_exponent))
end
qubit_count(::Type{SingleExcitation}) = 2
function matrix_rep_raw(g::SingleExcitation)
cosϕ = cos(g.angle[1] / 2.0)
sinϕ = sin(g.angle[1] / 2.0)
return SMatrix{4,4,ComplexF64}([1.0 0 0 0; 0 cosϕ sinϕ 0; 0 -sinϕ cosϕ 0; 0 0 0 1.0])
end
matrix_rep_raw(::SingleExcitation, ϕ) = ((sθ, cθ) = sincos(ϕ/2.0); return SMatrix{4,4,ComplexF64}(complex(1.0), 0, 0, 0, 0, cθ, -sθ, 0, 0, sθ, cθ, 0, 0, 0, 0, complex(1.0)))
"""
MultiRz(angle)

Expand Down Expand Up @@ -95,13 +89,12 @@ function apply_gate!(
) where {T<:Complex}
n_amps, endian_ts = get_amps_and_qubits(state_vec, t1, t2, t3, t4)
ordered_ts = sort(collect(endian_ts))
cosϕ = cos(g.angle[1] / 2.0)
sinϕ = sin(g.angle[1] / 2.0)
sinϕ, cosϕ = sincos(g.angle[1] * g.pow_exponent / 2.0)
e_t1, e_t2, e_t3, e_t4 = endian_ts
Threads.@threads for ix = 0:div(n_amps, 2^4)-1
padded_ix = pad_bits(ix, ordered_ts)
i0011 = flip_bits(padded_ix, (e_t3, e_t4)) + 1
i1100 = flip_bits(padded_ix, (e_t1, e_t2)) + 1
i0011 = flip_bits(padded_ix, (e_t3, e_t4)) + 1
i1100 = flip_bits(padded_ix, (e_t1, e_t2)) + 1
@inbounds begin
amp0011 = state_vec[i0011]
amp1100 = state_vec[i1100]
Expand Down
Loading
Loading