-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlib.rs
107 lines (95 loc) · 3.14 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
use kessler::{run, CollisionEvent, ExplosionEvent, SatKind, Satellite};
use numpy::{PyArray1, PyArray3, ToPyArray};
use pyo3::prelude::*;
// Wrappers for the structs provided by kessler
/// Python wrapper around `kessler::CollisionEvent`, exposed to Python as `CollisionEvent`.
#[pyclass(name = "CollisionEvent")]
#[derive(Debug)]
pub struct CollisionEventPyWrapper(CollisionEvent);

#[pymethods]
impl CollisionEventPyWrapper {
    /// Build a collision event between two satellites.
    ///
    /// Both wrapped satellites are cloned, so the Python `Satellite` objects remain usable
    /// afterwards. `min_characteristic_length` is forwarded unchanged to `CollisionEvent::new`.
    #[new]
    pub fn __new__(
        _py: Python,
        satellite_one: &SatellitePyWrapper,
        satellite_two: &SatellitePyWrapper,
        min_characteristic_length: f32,
    ) -> Self {
        // `CollisionEvent::new` borrows the pair, so build an owned two-element array locally.
        let pair = [satellite_one.0.clone(), satellite_two.0.clone()];
        Self(CollisionEvent::new(&pair, min_characteristic_length))
    }
}
/// Python wrapper around `kessler::ExplosionEvent`, exposed to Python as `ExplosionEvent`.
#[pyclass(name = "ExplosionEvent")]
#[derive(Debug)]
pub struct ExplosionEventPyWrapper(ExplosionEvent);

#[pymethods]
impl ExplosionEventPyWrapper {
    /// Build an explosion event for a single satellite.
    ///
    /// The wrapped satellite is cloned, so the Python `Satellite` object remains usable
    /// afterwards. `min_characteristic_length` is forwarded unchanged to `ExplosionEvent::new`.
    #[new]
    pub fn __new__(
        _py: Python,
        satellite: &SatellitePyWrapper,
        min_characteristic_length: f32,
    ) -> Self {
        Self(ExplosionEvent::new(
            satellite.0.clone(),
            min_characteristic_length,
        ))
    }
}
#[pyclass(name = "Satellite")]
#[derive(Debug)]
pub struct SatellitePyWrapper(Satellite);
#[pymethods]
impl SatellitePyWrapper {
#[new]
pub fn __new__(
_py: Python,
position: &PyArray1<f32>,
velocity: &PyArray1<f32>,
mass: f32,
) -> Self {
let position: [f32; 3] = {
unsafe {
let position_slice = position.as_slice().expect("Failed to get slice from position array");
if position_slice.len() != 3 {
panic!("Position array must have a length of 3");
}
let mut arr = [0.0; 3];
arr.copy_from_slice(position_slice);
arr
}
};
let velocity: [f32; 3] = {
unsafe {
let velocity_slice = velocity.as_slice().expect("Failed to get slice from velocity array");
if velocity_slice.len() != 3 {
panic!("Velocity array must have a length of 3");
}
let mut arr = [0.0; 3];
arr.copy_from_slice(velocity_slice);
arr
}
};
SatellitePyWrapper(
Satellite::new(
position,
velocity,
mass,
SatKind::Rb //TODO
)
)
}
}
/// Entry point of the `kesspy` Python extension module.
///
/// Python looks up the module's init symbol by the name of the compiled library. This function's
/// name therefore has to match that library name (the `[lib] name` in Cargo.toml, which defaults
/// to the package name).
#[pymodule]
fn kesspy(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    /// Run `kessler::run` on the wrapped collision event.
    ///
    /// Returns the output copied into a new 3-D `float32` NumPy array.
    #[pyfn(m)]
    fn run_collision<'py>(py: Python<'py>, event: &CollisionEventPyWrapper) -> &'py PyArray3<f32> {
        run(&event.0).to_pyarray(py)
    }

    /// Run `kessler::run` on the wrapped explosion event.
    ///
    /// Returns the output copied into a new 3-D `float32` NumPy array.
    #[pyfn(m)]
    fn run_explosion<'py>(py: Python<'py>, event: &ExplosionEventPyWrapper) -> &'py PyArray3<f32> {
        run(&event.0).to_pyarray(py)
    }

    // Register the wrapper types so Python code can construct events and satellites.
    m.add_class::<ExplosionEventPyWrapper>()?;
    m.add_class::<CollisionEventPyWrapper>()?;
    m.add_class::<SatellitePyWrapper>()?;
    Ok(())
}