12
12
#include < dmlc/type_traits.h>
13
13
#include < dmlc/registry.h>
14
14
#include < vector>
15
+ #include < map>
15
16
#include < string>
16
17
#include < memory>
17
18
#include " ./base.h"
@@ -446,7 +447,10 @@ MXNET_API void SampleGaussian(real_t mu, real_t sigma, NDArray *out);
446
447
/* ! \brief definition of NDArray function */
447
448
typedef std::function<void (NDArray **used_vars,
448
449
real_t *scalars,
449
- NDArray **mutate_vars)> NDArrayAPIFunction;
450
+ NDArray **mutate_vars,
451
+ int num_params,
452
+ char **param_keys,
453
+ char **param_vals)> NDArrayAPIFunction;
450
454
/* ! \brief mask information on how functions can be exposed */
451
455
enum NDArrayFunctionTypeMask {
452
456
/* ! \brief all the use_vars should go before scalar */
@@ -491,7 +495,8 @@ struct NDArrayFunctionReg
491
495
*/
492
496
inline NDArrayFunctionReg &set_function (void (*fsetvalue)(const real_t &rhs,
493
497
NDArray *out)) {
494
- body = [fsetvalue] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) {
498
+ body = [fsetvalue] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
499
+ int num_params, char **param_keys, char **param_vals) {
495
500
(*fsetvalue)(s[0 ], mutate_vars[0 ]);
496
501
};
497
502
num_mutate_vars = 1 ; num_scalars = 1 ;
@@ -507,8 +512,8 @@ struct NDArrayFunctionReg
507
512
inline NDArrayFunctionReg &set_function (void (*fbinary)(const NDArray &lhs,
508
513
const NDArray &rhs,
509
514
NDArray *out)) {
510
- body = [fbinary] (NDArray **used_vars,
511
- real_t *s, NDArray **mutate_vars ) {
515
+ body = [fbinary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
516
+ int num_params, char **param_keys, char **param_vals ) {
512
517
(*fbinary)(*used_vars[0 ], *used_vars[1 ], mutate_vars[0 ]);
513
518
};
514
519
num_use_vars = 2 ; num_mutate_vars = 1 ;
@@ -526,8 +531,8 @@ struct NDArrayFunctionReg
526
531
inline NDArrayFunctionReg &set_function (void (*fscalar)(const NDArray &lhs,
527
532
const real_t &rhs,
528
533
NDArray *out)) {
529
- body = [fscalar] (NDArray **used_vars,
530
- real_t *s, NDArray **mutate_vars ) {
534
+ body = [fscalar] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
535
+ int num_params, char **param_keys, char **param_vals ) {
531
536
(*fscalar)(*used_vars[0 ], s[0 ], mutate_vars[0 ]);
532
537
};
533
538
num_use_vars = 1 ; num_mutate_vars = 1 ; num_scalars = 1 ;
@@ -544,15 +549,36 @@ struct NDArrayFunctionReg
544
549
*/
545
550
inline NDArrayFunctionReg &set_function (void (*funary)(const NDArray &src,
546
551
NDArray *out)) {
547
- body = [funary] (NDArray **used_vars,
548
- real_t *s, NDArray **mutate_vars ) {
552
+ body = [funary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
553
+ int num_params, char **param_keys, char **param_vals ) {
549
554
(*funary)(*used_vars[0 ], mutate_vars[0 ]);
550
555
};
551
556
num_use_vars = 1 ; num_mutate_vars = 1 ;
552
557
type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget ;
553
558
this ->add_argument (" src" , " NDArray" , " Source input to the function." );
554
559
return *this ;
555
560
}
561
+ /* !
562
+ * \brief set the function body to a unary NDArray function
563
+ * this will also auto set the parameters correctly
564
+ * \param funary function body to set
565
+ * \return ref to the registered entry, used to set properties
566
+ */
567
+ inline NDArrayFunctionReg &set_function (
568
+ void (*fgeneric)(NDArray **used_vars,
569
+ real_t *s,
570
+ NDArray **mutate_vars,
571
+ const std::map<std::string, std::string>& param)) {
572
+ body = [fgeneric] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
573
+ int num_params, char **param_keys, char **param_vals) {
574
+ std::map<std::string, std::string> param;
575
+ for (int i = 0 ; i < num_params; ++i) {
576
+ param[param_keys[i]] = param_vals[i];
577
+ }
578
+ fgeneric (used_vars, s, mutate_vars, param);
579
+ };
580
+ return *this ;
581
+ }
556
582
/* !
557
583
* \brief set the number of mutate variables
558
584
* \param n number of mutate variablesx
0 commit comments