Skip to content

Commit

Permalink
Matrix-Vector Multiplication
Browse files Browse the repository at this point in the history
More readable dense operation.
  • Loading branch information
ax3l committed Oct 11, 2024
1 parent 482ce2e commit df452e2
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 90 deletions.
38 changes: 8 additions & 30 deletions src/particles/elements/RFCavity.H
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,6 @@ namespace RFCavityData
// shift due to alignment errors of the element
shift_in(x, y, px, py);

// intialize output values
amrex::ParticleReal xout = x;
amrex::ParticleReal yout = y;
amrex::ParticleReal tout = t;

// initialize output values of momenta
amrex::ParticleReal pxout = px;
amrex::ParticleReal pyout = py;
amrex::ParticleReal ptout = pt;

// get the linear map
amrex::SmallMatrix<amrex::ParticleReal, 6, 6, amrex::Order::F, 1> const R = refpart.map;

Expand All @@ -226,30 +216,18 @@ namespace RFCavityData
// https://uspas.fnal.gov/materials/09UNM/ComputationalMethods.pdf.
// R denotes the transfer matrix in the basis (x,px,y,py,t,pt),
// so that, e.g., R(3,4) = dyf/dpyi.
amrex::SmallVector<amrex::ParticleReal, 6, 1> const v{x, px, y, py, t, pt};

// push particles using the linear map
// clang-format off
xout = R(1,1)*x + R(1,2)*px + R(1,3)*y
+ R(1,4)*py + R(1,5)*t + R(1,6)*pt;
pxout = R(2,1)*x + R(2,2)*px + R(2,3)*y
+ R(2,4)*py + R(2,5)*t + R(2,6)*pt;
yout = R(3,1)*x + R(3,2)*px + R(3,3)*y
+ R(3,4)*py + R(3,5)*t + R(3,6)*pt;
pyout = R(4,1)*x + R(4,2)*px + R(4,3)*y
+ R(4,4)*py + R(4,5)*t + R(4,6)*pt;
tout = R(5,1)*x + R(5,2)*px + R(5,3)*y
+ R(5,4)*py + R(5,5)*t + R(5,6)*pt;
ptout = R(6,1)*x + R(6,2)*px + R(6,3)*y
+ R(6,4)*py + R(6,5)*t + R(6,6)*pt;
// clang-format on
auto const out = R * v;

// assign updated values
x = xout;
y = yout;
t = tout;
px = pxout;
py = pyout;
pt = ptout;
x = out[1];
px = out[2];
y = out[3];
py = out[4];
t = out[5];
pt = out[6];

// undo shift due to alignment errors of the element
shift_out(x, y, px, py);
Expand Down
38 changes: 8 additions & 30 deletions src/particles/elements/SoftQuad.H
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,6 @@ namespace SoftQuadrupoleData
// shift due to alignment errors of the element
shift_in(x, y, px, py);

// intialize output values
amrex::ParticleReal xout = x;
amrex::ParticleReal yout = y;
amrex::ParticleReal tout = t;

// initialize output values of momenta
amrex::ParticleReal pxout = px;
amrex::ParticleReal pyout = py;
amrex::ParticleReal ptout = pt;

// get the linear map
amrex::SmallMatrix<amrex::ParticleReal, 6, 6, amrex::Order::F, 1> const R = refpart.map;

Expand All @@ -231,30 +221,18 @@ namespace SoftQuadrupoleData
// https://uspas.fnal.gov/materials/09UNM/ComputationalMethods.pdf .
// R denotes the transfer matrix in the basis (x,px,y,py,t,pt),
// so that, e.g., R(3,4) = dyf/dpyi.
amrex::SmallVector<amrex::ParticleReal, 6, 1> const v{x, px, y, py, t, pt};

// push particles using the linear map
// clang-format off
xout = R(1,1)*x + R(1,2)*px + R(1,3)*y
+ R(1,4)*py + R(1,5)*t + R(1,6)*pt;
pxout = R(2,1)*x + R(2,2)*px + R(2,3)*y
+ R(2,4)*py + R(2,5)*t + R(2,6)*pt;
yout = R(3,1)*x + R(3,2)*px + R(3,3)*y
+ R(3,4)*py + R(3,5)*t + R(3,6)*pt;
pyout = R(4,1)*x + R(4,2)*px + R(4,3)*y
+ R(4,4)*py + R(4,5)*t + R(4,6)*pt;
tout = R(5,1)*x + R(5,2)*px + R(5,3)*y
+ R(5,4)*py + R(5,5)*t + R(5,6)*pt;
ptout = R(6,1)*x + R(6,2)*px + R(6,3)*y
+ R(6,4)*py + R(6,5)*t + R(6,6)*pt;
// clang-format on
auto const out = R * v;

// assign updated values
x = xout;
y = yout;
t = tout;
px = pxout;
py = pyout;
pt = ptout;
x = out[1];
px = out[2];
y = out[3];
py = out[4];
t = out[5];
pt = out[6];

// undo shift due to alignment errors of the element
shift_out(x, y, px, py);
Expand Down
38 changes: 8 additions & 30 deletions src/particles/elements/SoftSol.H
Original file line number Diff line number Diff line change
Expand Up @@ -224,16 +224,6 @@ namespace SoftSolenoidData
// shift due to alignment errors of the element
shift_in(x, y, px, py);

// intialize output values
amrex::ParticleReal xout = x;
amrex::ParticleReal yout = y;
amrex::ParticleReal tout = t;

// initialize output values of momenta
amrex::ParticleReal pxout = px;
amrex::ParticleReal pyout = py;
amrex::ParticleReal ptout = pt;

// get the linear map
amrex::SmallMatrix<amrex::ParticleReal, 6, 6, amrex::Order::F, 1> const R = refpart.map;

Expand All @@ -242,30 +232,18 @@ namespace SoftSolenoidData
// https://uspas.fnal.gov/materials/09UNM/ComputationalMethods.pdf.
// R denotes the transfer matrix in the basis (x,px,y,py,t,pt),
// so that, e.g., R(3,4) = dyf/dpyi.
amrex::SmallVector<amrex::ParticleReal, 6, 1> const v{x, px, y, py, t, pt};

// push particles using the linear map
// clang-format off
xout = R(1,1)*x + R(1,2)*px + R(1,3)*y
+ R(1,4)*py + R(1,5)*t + R(1,6)*pt;
pxout = R(2,1)*x + R(2,2)*px + R(2,3)*y
+ R(2,4)*py + R(2,5)*t + R(2,6)*pt;
yout = R(3,1)*x + R(3,2)*px + R(3,3)*y
+ R(3,4)*py + R(3,5)*t + R(3,6)*pt;
pyout = R(4,1)*x + R(4,2)*px + R(4,3)*y
+ R(4,4)*py + R(4,5)*t + R(4,6)*pt;
tout = R(5,1)*x + R(5,2)*px + R(5,3)*y
+ R(5,4)*py + R(5,5)*t + R(5,6)*pt;
ptout = R(6,1)*x + R(6,2)*px + R(6,3)*y
+ R(6,4)*py + R(6,5)*t + R(6,6)*pt;
// clang-format on
auto const out = R * v;

// assign updated values
x = xout;
y = yout;
t = tout;
px = pxout;
py = pyout;
pt = ptout;
x = out[1];
px = out[2];
y = out[3];
py = out[4];
t = out[5];
pt = out[6];

// undo shift due to alignment errors of the element
shift_out(x, y, px, py);
Expand Down

0 comments on commit df452e2

Please sign in to comment.