Skip to content

feat: reproduce from primaries #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions .github/workflows/build-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ on:

pull_request:
push:
branches:
[main]
branches: [main]

## Paste this snippet into the workflow file to enable tmate debugging
## Paste this snippet into the workflow file to enable tmate debugging
# - name: Setup tmate session
# uses: mxschmitt/action-tmate@v3
# if:
Expand Down Expand Up @@ -208,6 +207,10 @@ jobs:
command: |
python -c "import geant4_python_application as g4; g4.install_datasets()"

- name: Run tests debug
run: |
python -m pytest -s -vv tests/test_analysis.py::test_sensitive

- name: Run tests
run: |
python -m pytest -vv --reruns 3 --reruns-delay 30 --only-rerun "(?i)http|timeout|connection|socket|resolve"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ classifiers = [
dependencies = [
"awkward",
"vector",
"hepunits",
"fsspec",
"pyarrow",
"awkward-pandas",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Application {
void SetupAction();

void Initialize();
py::list Run(const py::object& primaries);
std::vector<py::object> Run(const py::object& primaries);

bool IsSetup() const;
bool IsInitialized() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ void InsertEvent(const G4Event* event, Builders& builder);
void InsertTrack(const G4Track* track, Builders& builder);
void InsertStep(const G4Step* step, Builders& builder);

py::object SnapshotBuilder(Builders& builder);
py::object BuilderToObject(std::unique_ptr<Builders> builders);

namespace units {
static constexpr auto energy = CLHEP::keV;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class EventAction : public G4UserEventAction {
void BeginOfEventAction(const G4Event*) override;

void EndOfEventAction(const G4Event*) override;

static double sensitiveEnergy;
};

}// namespace geant4_app
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

#pragma once

#include <G4Run.hh>
#include <G4RunManager.hh>
#include <G4UserRunAction.hh>

Expand All @@ -20,12 +21,12 @@ class RunAction : public G4UserRunAction {

/// Only one instance of RunAction is created for each thread.
static data::Builders& GetBuilder();
static std::unique_ptr<py::list> GetContainer();
static std::unique_ptr<std::vector<py::object>> GetContainer();

private:
std::unique_ptr<data::Builders> builder = nullptr;
std::mutex mutex;
static std::unique_ptr<py::list> container;
static std::unique_ptr<std::vector<py::object>> container;
static std::vector<std::unique_ptr<data::Builders>> buildersToSnapshot;
};

Expand Down
11 changes: 7 additions & 4 deletions src/geant4_application/src/Application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ Application::Application() {
Application::~Application() = default;

void Application::SetupRandomEngine() {
G4Random::setTheEngine(new CLHEP::RanecuEngine);
G4Random::setTheEngine(new CLHEP::MTwistEngine);
if (randomSeed == 0) {
randomSeed = std::random_device()();
}
cout << "seed set to: " << randomSeed << endl;
G4Random::setTheSeed(randomSeed);
}

Expand Down Expand Up @@ -88,6 +89,10 @@ void Application::SetupManager(unsigned short nThreads) {
delete G4VSteppingVerbose::GetInstance();
SteppingVerbose::SetInstance(new SteppingVerbose);

// https://geant4-forum.web.cern.ch/t/different-random-seeds-but-same-results/324/5
// seed needs to be setup before the run manager is created
SetupRandomEngine();

const auto runManagerType = nThreads > 0 ? G4RunManagerType::MTOnly : G4RunManagerType::SerialOnly;
runManager = unique_ptr<G4RunManager>(G4RunManagerFactory::CreateRunManager(runManagerType));
if (nThreads > 0) {
Expand All @@ -110,13 +115,11 @@ void Application::Initialize() {
throw runtime_error("Application is already initialized");
}

SetupRandomEngine();

runManager->Initialize();
isInitialized = true;
}

py::list Application::Run(const py::object& primaries) {
vector<py::object> Application::Run(const py::object& primaries) {
if (!IsInitialized()) {
Initialize();
}
Expand Down
103 changes: 52 additions & 51 deletions src/geant4_application/src/DataModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,82 +429,83 @@ py::object snapshot_builder(const T& builder) {
return from_buffers(builder.form(), builder.length(), container);
}

py::object SnapshotBuilder(Builders& builder) {
py::object BuilderToObject(unique_ptr<Builders> builder) {
// this will automatically clear the builders after the pointer goes out of scope
py::dict snapshot;
if (builder.fields.contains("run")) {
snapshot["run"] = snapshot_builder(builder.run);
if (builder->fields.contains("run")) {
snapshot["run"] = snapshot_builder(builder->run);
}
if (builder.fields.contains("id")) {
snapshot["id"] = snapshot_builder(builder.id);
if (builder->fields.contains("id")) {
snapshot["id"] = snapshot_builder(builder->id);
}
if (builder.fields.contains("primaries")) {
snapshot["primaries"] = snapshot_builder(builder.primaries);
if (builder->fields.contains("primaries")) {
snapshot["primaries"] = snapshot_builder(builder->primaries);
}
if (builder.fields.contains("track_id")) {
snapshot["track_id"] = snapshot_builder(builder.track_id);
if (builder->fields.contains("track_id")) {
snapshot["track_id"] = snapshot_builder(builder->track_id);
}
if (builder.fields.contains("track_parent_id")) {
snapshot["track_parent_id"] = snapshot_builder(builder.track_parent_id);
if (builder->fields.contains("track_parent_id")) {
snapshot["track_parent_id"] = snapshot_builder(builder->track_parent_id);
}
if (builder.fields.contains("track_initial_energy")) {
snapshot["track_initial_energy"] = snapshot_builder(builder.track_initial_energy);
if (builder->fields.contains("track_initial_energy")) {
snapshot["track_initial_energy"] = snapshot_builder(builder->track_initial_energy);
}
if (builder.fields.contains("track_initial_time")) {
snapshot["track_initial_time"] = snapshot_builder(builder.track_initial_time);
if (builder->fields.contains("track_initial_time")) {
snapshot["track_initial_time"] = snapshot_builder(builder->track_initial_time);
}
if (builder.fields.contains("track_weight")) {
snapshot["track_weight"] = snapshot_builder(builder.track_weight);
if (builder->fields.contains("track_weight")) {
snapshot["track_weight"] = snapshot_builder(builder->track_weight);
}
if (builder.fields.contains("track_initial_position")) {
snapshot["track_initial_position"] = snapshot_builder(builder.track_initial_position);
if (builder->fields.contains("track_initial_position")) {
snapshot["track_initial_position"] = snapshot_builder(builder->track_initial_position);
}
if (builder.fields.contains("track_initial_momentum")) {
snapshot["track_initial_momentum"] = snapshot_builder(builder.track_initial_momentum);
if (builder->fields.contains("track_initial_momentum")) {
snapshot["track_initial_momentum"] = snapshot_builder(builder->track_initial_momentum);
}
if (builder.fields.contains("track_particle")) {
snapshot["track_particle"] = snapshot_builder(builder.track_particle);
if (builder->fields.contains("track_particle")) {
snapshot["track_particle"] = snapshot_builder(builder->track_particle);
}
if (builder.fields.contains("track_particle_type")) {
snapshot["track_particle_type"] = snapshot_builder(builder.track_particle_type);
if (builder->fields.contains("track_particle_type")) {
snapshot["track_particle_type"] = snapshot_builder(builder->track_particle_type);
}
if (builder.fields.contains("track_creator_process")) {
snapshot["track_creator_process"] = snapshot_builder(builder.track_creator_process);
if (builder->fields.contains("track_creator_process")) {
snapshot["track_creator_process"] = snapshot_builder(builder->track_creator_process);
}
if (builder.fields.contains("track_creator_process_type")) {
snapshot["track_creator_process_type"] = snapshot_builder(builder.track_creator_process_type);
if (builder->fields.contains("track_creator_process_type")) {
snapshot["track_creator_process_type"] = snapshot_builder(builder->track_creator_process_type);
}
if (builder.fields.contains("track_children_ids")) {
snapshot["track_children_ids"] = snapshot_builder(builder.track_children_ids);
if (builder->fields.contains("track_children_ids")) {
snapshot["track_children_ids"] = snapshot_builder(builder->track_children_ids);
}
if (builder.fields.contains("step_energy")) {
snapshot["step_energy"] = snapshot_builder(builder.step_energy);
if (builder->fields.contains("step_energy")) {
snapshot["step_energy"] = snapshot_builder(builder->step_energy);
}
if (builder.fields.contains("step_time")) {
snapshot["step_time"] = snapshot_builder(builder.step_time);
if (builder->fields.contains("step_time")) {
snapshot["step_time"] = snapshot_builder(builder->step_time);
}
if (builder.fields.contains("step_track_kinetic_energy")) {
snapshot["step_track_kinetic_energy"] = snapshot_builder(builder.step_track_kinetic_energy);
if (builder->fields.contains("step_track_kinetic_energy")) {
snapshot["step_track_kinetic_energy"] = snapshot_builder(builder->step_track_kinetic_energy);
}
if (builder.fields.contains("step_process")) {
snapshot["step_process"] = snapshot_builder(builder.step_process);
if (builder->fields.contains("step_process")) {
snapshot["step_process"] = snapshot_builder(builder->step_process);
}
if (builder.fields.contains("step_process_type")) {
snapshot["step_process_type"] = snapshot_builder(builder.step_process_type);
if (builder->fields.contains("step_process_type")) {
snapshot["step_process_type"] = snapshot_builder(builder->step_process_type);
}
if (builder.fields.contains("step_volume")) {
snapshot["step_volume"] = snapshot_builder(builder.step_volume);
if (builder->fields.contains("step_volume")) {
snapshot["step_volume"] = snapshot_builder(builder->step_volume);
}
if (builder.fields.contains("step_volume_post")) {
snapshot["step_volume_post"] = snapshot_builder(builder.step_volume_post);
if (builder->fields.contains("step_volume_post")) {
snapshot["step_volume_post"] = snapshot_builder(builder->step_volume_post);
}
if (builder.fields.contains("step_nucleus")) {
snapshot["step_nucleus"] = snapshot_builder(builder.step_nucleus);
if (builder->fields.contains("step_nucleus")) {
snapshot["step_nucleus"] = snapshot_builder(builder->step_nucleus);
}
if (builder.fields.contains("step_position")) {
snapshot["step_position"] = snapshot_builder(builder.step_position);
if (builder->fields.contains("step_position")) {
snapshot["step_position"] = snapshot_builder(builder->step_position);
}
if (builder.fields.contains("step_momentum")) {
snapshot["step_momentum"] = snapshot_builder(builder.step_momentum);
if (builder->fields.contains("step_momentum")) {
snapshot["step_momentum"] = snapshot_builder(builder->step_momentum);
}
return snapshot;
}
Expand Down
5 changes: 5 additions & 0 deletions src/geant4_application/src/EventAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@ using namespace geant4_app;
EventAction::EventAction() : G4UserEventAction() {}

void EventAction::BeginOfEventAction(const G4Event* event) {
sensitiveEnergy = 0;
data::InsertEventBegin(event, RunAction::GetBuilder());
data::InsertEvent(event, RunAction::GetBuilder());
}

void EventAction::EndOfEventAction(const G4Event* event) {
data::InsertEventEnd(event, RunAction::GetBuilder());

cout << "END OF EVENT: event id: " << event->GetEventID() << " Sensitive energy: " << sensitiveEnergy << endl;
}

double EventAction::sensitiveEnergy = 0.0;
28 changes: 1 addition & 27 deletions src/geant4_application/src/PrimaryGeneratorAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,8 @@ using namespace geant4_app;
PrimaryGeneratorAction::PrimaryGeneratorAction() : G4VUserPrimaryGeneratorAction() {}

void PrimaryGeneratorAction::GeneratePrimaries(G4Event* event) {
if (!awkwardPrimaryEnergies.empty()) {
const double energy = awkwardPrimaryEnergies[event->GetEventID()];
gun.SetParticleEnergy(energy * keV);
}
if (!awkwardPrimaryPositions.empty()) {
const auto& position = awkwardPrimaryPositions[event->GetEventID()];
gun.SetParticlePosition(G4ThreeVector(position[0] * cm, position[1] * cm, position[2] * cm));
}
if (!awkwardPrimaryDirections.empty()) {
const auto& direction = awkwardPrimaryDirections[event->GetEventID()];
gun.SetParticleMomentumDirection(G4ThreeVector(direction[0], direction[1], direction[2]));
}
if (!awkwardPrimaryParticles.empty()) {
const auto& particleAwkward = awkwardPrimaryParticles[event->GetEventID()];
auto* particle = G4ParticleTable::GetParticleTable()->FindParticle(particleAwkward);
if (particle == nullptr) {
throw runtime_error("PrimaryGeneratorAction::GeneratePrimaries - particle '" + particleAwkward + "' not found");
}
gun.SetParticleDefinition(particle);
}

if (generatorType == "gun") {
gun.GeneratePrimaryVertex(event);
} else if (generatorType == "gps") {
gps.GeneratePrimaryVertex(event);
} else {
throw runtime_error("PrimaryGeneratorAction::GeneratePrimaries - generatorType must be 'gun', 'gps'");
}
gun.GeneratePrimaryVertex(event);
}

void PrimaryGeneratorAction::SetGeneratorType(const string& type) {
Expand Down
26 changes: 18 additions & 8 deletions src/geant4_application/src/RunAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
#include "geant4_application/RunAction.h"
#include "geant4_application/SteppingVerbose.h"

#include <G4Threading.hh>

#include <iostream>

using namespace std;
using namespace geant4_app;

RunAction::RunAction() : G4UserRunAction() {}

void RunAction::BeginOfRunAction(const G4Run*) {
void RunAction::BeginOfRunAction(const G4Run* run) {
lock_guard<std::mutex> lock(mutex);

builder = make_unique<data::Builders>(Application::GetEventFields());
Expand All @@ -18,35 +20,43 @@ void RunAction::BeginOfRunAction(const G4Run*) {
steppingVerbose->Initialize();

if (IsMaster()) {
container = make_unique<py::list>();
container = make_unique<vector<py::object>>();
}

cout << "RUN ID: " << run->GetRunID() << " RANDOM SEED: " << G4Random::getTheSeed() << endl;

for (int i = 0; i < 10; ++i) {
cout << "RANDOM NUMBER: " << G4UniformRand() << endl;
}
}

void RunAction::EndOfRunAction(const G4Run*) {
lock_guard<std::mutex> lock(mutex);

cout << "END OF RUN. Thread: " << G4Threading::G4GetThreadId() << endl;
if (!isMaster || !G4Threading::IsMultithreadedApplication()) {
buildersToSnapshot.push_back(std::move(builder));
cout << "LENGTH: " << buildersToSnapshot.back()->run.length() << endl;
}

if (isMaster) {
for (auto& builderToSnapshot: buildersToSnapshot) {
auto data = SnapshotBuilder(*builderToSnapshot);
container->append(data);
builderToSnapshot = nullptr;
for (auto& builders: buildersToSnapshot) {
container->push_back(BuilderToObject(std::move(builders)));
}
buildersToSnapshot.clear();
}

cout << "- END OF RUN. Thread: " << G4Threading::G4GetThreadId() << endl;
}

data::Builders& RunAction::GetBuilder() {
auto runAction = dynamic_cast<RunAction*>(const_cast<G4UserRunAction*>(G4RunManager::GetRunManager()->GetUserRunAction()));
return *runAction->builder;
}

unique_ptr<py::list> RunAction::container = nullptr;
unique_ptr<vector<py::object>> RunAction::container = nullptr;

unique_ptr<py::list> RunAction::GetContainer() {
unique_ptr<vector<py::object>> RunAction::GetContainer() {
return std::move(RunAction::container);
}

Expand Down
6 changes: 5 additions & 1 deletion src/geant4_application/src/SensitiveDetector.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@

#include "geant4_application/SensitiveDetector.h"
#include "geant4_application/EventAction.h"

using namespace std;
using namespace geant4_app;

SensitiveDetector::SensitiveDetector(const string& name) : G4VSensitiveDetector(name) {}

G4bool SensitiveDetector::ProcessHits(G4Step* step, G4TouchableHistory*) { return true; }
G4bool SensitiveDetector::ProcessHits(G4Step* step, G4TouchableHistory*) {
EventAction::sensitiveEnergy += step->GetTotalEnergyDeposit() / CLHEP::keV;
return true;
}

void SensitiveDetector::Initialize(G4HCofThisEvent*) {}
Loading