@@ -62,6 +62,15 @@ py_gemv(sycl::queue q,
62
62
throw std::runtime_error (" Inconsistent shapes." );
63
63
}
64
64
65
+ auto q_ctx = q.get_context ();
66
+ if (q_ctx != matrix.get_queue ().get_context () ||
67
+ q_ctx != vector.get_queue ().get_context () ||
68
+ q_ctx != result.get_queue ().get_context ())
69
+ {
70
+ throw std::runtime_error (
71
+ " USM allocation is not bound to the context in execution queue." );
72
+ }
73
+
65
74
int mat_flags = matrix.get_flags ();
66
75
int v_flags = vector.get_flags ();
67
76
int r_flags = result.get_flags ();
@@ -176,6 +185,14 @@ py_sub(sycl::queue q,
176
185
throw std::runtime_error (" Vectors must have the same length" );
177
186
}
178
187
188
+ if (q.get_context () != in_v1.get_queue ().get_context () ||
189
+ q.get_context () != in_v2.get_queue ().get_context () ||
190
+ q.get_context () != out_r.get_queue ().get_context ())
191
+ {
192
+ throw std::runtime_error (
193
+ " USM allocation is not bound to the context in execution queue" );
194
+ }
195
+
179
196
int in_v1_flags = in_v1.get_flags ();
180
197
int in_v2_flags = in_v2.get_flags ();
181
198
int out_r_flags = out_r.get_flags ();
@@ -277,6 +294,13 @@ py_axpby_inplace(sycl::queue q,
277
294
throw std::runtime_error (" Vectors must have the same length" );
278
295
}
279
296
297
+ if (q.get_context () != x.get_queue ().get_context () ||
298
+ q.get_context () != y.get_queue ().get_context ())
299
+ {
300
+ throw std::runtime_error (
301
+ " USM allocation is not bound to the context in execution queue" );
302
+ }
303
+
280
304
int x_flags = x.get_flags ();
281
305
int y_flags = y.get_flags ();
282
306
@@ -373,6 +397,11 @@ py::object py_norm_squared_blocking(sycl::queue q,
373
397
throw std::runtime_error (" Vector must be contiguous." );
374
398
}
375
399
400
+ if (q.get_context () != r.get_queue ().get_context ()) {
401
+ throw std::runtime_error (
402
+ " USM allocation is not bound to the context in execution queue" );
403
+ }
404
+
376
405
int r_typenum = r.get_typenum ();
377
406
if ((r_typenum != UAR_DOUBLE) && (r_typenum != UAR_FLOAT) &&
378
407
(r_typenum != UAR_CDOUBLE) && (r_typenum != UAR_CFLOAT))
@@ -437,6 +466,13 @@ py::object py_dot_blocking(sycl::queue q,
437
466
throw std::runtime_error (" Vectors must be contiguous." );
438
467
}
439
468
469
+ if (q.get_context () != v1.get_queue ().get_context () ||
470
+ q.get_context () != v2.get_queue ().get_context ())
471
+ {
472
+ throw std::runtime_error (
473
+ " USM allocation is not bound to the context in execution queue" );
474
+ }
475
+
440
476
int v1_typenum = v1.get_typenum ();
441
477
int v2_typenum = v2.get_typenum ();
442
478
@@ -500,6 +536,80 @@ py::object py_dot_blocking(sycl::queue q,
500
536
return res;
501
537
}
502
538
539
+ int py_cg_solve (sycl::queue exec_q,
540
+ dpctl::tensor::usm_ndarray Amat,
541
+ dpctl::tensor::usm_ndarray bvec,
542
+ dpctl::tensor::usm_ndarray xvec,
543
+ double rs_tol,
544
+ const std::vector<sycl::event> &depends = {})
545
+ {
546
+ if (Amat.get_ndim () != 2 || bvec.get_ndim () != 1 || xvec.get_ndim () != 1 ) {
547
+ throw py::value_error (" Expecting a matrix and two vectors" );
548
+ }
549
+
550
+ py::ssize_t n0 = Amat.get_shape (0 );
551
+ py::ssize_t n1 = Amat.get_shape (1 );
552
+
553
+ if (n0 != n1) {
554
+ throw py::value_error (" Matrix must be square." );
555
+ }
556
+
557
+ if (n0 != bvec.get_shape (0 ) || n0 != xvec.get_shape (0 )) {
558
+ throw py::value_error (
559
+ " Dimensions of the matrix and vectors are not consistent." );
560
+ }
561
+
562
+ bool all_contig = (Amat.get_flags () & USM_ARRAY_C_CONTIGUOUS) &&
563
+ (bvec.get_flags () & USM_ARRAY_C_CONTIGUOUS) &&
564
+ (xvec.get_flags () & USM_ARRAY_C_CONTIGUOUS);
565
+ if (!all_contig) {
566
+ throw py::value_error (" All inputs must be C-contiguous" );
567
+ }
568
+
569
+ int A_typenum = Amat.get_typenum ();
570
+ int b_typenum = bvec.get_typenum ();
571
+ int x_typenum = xvec.get_typenum ();
572
+
573
+ if (A_typenum != b_typenum || A_typenum != x_typenum) {
574
+ throw py::value_error (" All arrays must have the same type" );
575
+ }
576
+
577
+ if (exec_q.get_context () != Amat.get_queue ().get_context () ||
578
+ exec_q.get_context () != bvec.get_queue ().get_context () ||
579
+ exec_q.get_context () != xvec.get_queue ().get_context ())
580
+ {
581
+ throw std::runtime_error (
582
+ " USM allocations are not bound to context in execution queue" );
583
+ }
584
+
585
+ const char *A_ch = Amat.get_data ();
586
+ const char *b_ch = bvec.get_data ();
587
+ char *x_ch = xvec.get_data ();
588
+
589
+ if (A_typenum == UAR_DOUBLE) {
590
+ using T = double ;
591
+ int iters = cg_solver::cg_solve<T>(
592
+ exec_q, n0, reinterpret_cast <const T *>(A_ch),
593
+ reinterpret_cast <const T *>(b_ch), reinterpret_cast <T *>(x_ch),
594
+ depends, static_cast <T>(rs_tol));
595
+
596
+ return iters;
597
+ }
598
+ else if (A_typenum == UAR_FLOAT) {
599
+ using T = float ;
600
+ int iters = cg_solver::cg_solve<T>(
601
+ exec_q, n0, reinterpret_cast <const T *>(A_ch),
602
+ reinterpret_cast <const T *>(b_ch), reinterpret_cast <T *>(x_ch),
603
+ depends, static_cast <T>(rs_tol));
604
+
605
+ return iters;
606
+ }
607
+ else {
608
+ throw std::runtime_error (
609
+ " Unsupported data type. Use single or double precision." );
610
+ }
611
+ }
612
+
503
613
PYBIND11_MODULE (_onemkl, m)
504
614
{
505
615
// Import the dpctl extensions
@@ -518,4 +628,10 @@ PYBIND11_MODULE(_onemkl, m)
518
628
py::arg (" exec_queue" ), py::arg (" r" ), py::arg (" depends" ) = py::list ());
519
629
m.def (" dot_blocking" , &py_dot_blocking, " <v1, v2>" , py::arg (" exec_queue" ),
520
630
py::arg (" v1" ), py::arg (" v2" ), py::arg (" depends" ) = py::list ());
631
+
632
+ m.def (" cpp_cg_solve" , &py_cg_solve,
633
+ " Dispatch to call C++ implementation of cg_solve" ,
634
+ py::arg (" exec_queue" ), py::arg (" Amat" ), py::arg (" bvec" ),
635
+ py::arg (" xvec" ), py::arg (" rs_squared_tolerance" ) = py::float_ (1e-20 ),
636
+ py::arg (" depends" ) = py::list ());
521
637
}
0 commit comments