Adding eye, eye_like, identity, identity_like functions. Feedback needed. #113
Closed
Description
Hello!
I'm about to submit a pull request adding `eye` and `identity` functions. I've tried to follow the style of the library, so my code deviates from the NumPy API in two ways:
1. The `eye`/`eye_like` and `identity`/`identity_like` functions are separate from each other, instead of `like` being a function parameter.
2. The `eye` array is constructed from a `shape` parameter instead of the `n` and `m` parameters defined in NumPy.
Are these changes okay?
/**
 * Return a 2-D array of the given shape ({rows, cols}) with ones on the k-th
 * diagonal and zeros everywhere else. k = 0 selects the main diagonal, k > 0
 * a diagonal above it, k < 0 a diagonal below it.
 * @throws std::invalid_argument if `shape` is not 2-dimensional.
 */
array eye(const std::vector<int>& shape, int k, Dtype dtype, StreamOrDevice s = {});
/** Like eye(), but takes the shape and dtype from the 2-D array `a`. */
array eye_like(const array& a, int k, StreamOrDevice s = {});
/**
 * Return a 2-D array of the given shape with ones on the main diagonal and
 * zeros everywhere else (eye() with k = 0).
 */
array identity(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
/** Like identity(), but takes the shape and dtype from the 2-D array `a`. */
array identity_like(const array& a, StreamOrDevice s = {});
/**
 * Build a 2-D array of the given shape with ones on the k-th diagonal and
 * zeros everywhere else.
 *
 * @param shape Two-element shape {n, m}: n rows and m columns.
 * @param k     Diagonal offset. Element (i, j) is set to one iff j - i == k,
 *              so 0 is the main diagonal, k > 0 lies above it and k < 0 below.
 *              An offset that misses the array entirely yields all zeros.
 * @param dtype Element type of the result.
 * @param s     Stream or device forwarded to every op used here.
 * @return      The {n, m} array described above.
 * @throws std::invalid_argument if `shape` is not 2-D or has a negative dim.
 */
array eye(
    const std::vector<int>& shape,
    int k,
    Dtype dtype,
    StreamOrDevice s /* = {} */) {
  if (shape.size() != 2) {
    throw std::invalid_argument(
        "Shape must be 2-dimensional for eye function.");
  }
  const int n = shape[0];
  const int m = shape[1];
  if (n < 0 || m < 0) {
    throw std::invalid_argument(
        "Shape dimensions must be non-negative for eye function.");
  }
  array result = zeros(shape, dtype, s);

  // Diagonal entirely outside the array. Written as `k <= -n` rather than
  // `-k >= n` so that k == INT_MIN does not overflow.
  if (k >= m || k <= -n) {
    return result;
  }

  // The k-th diagonal starts at (row0, col0) and runs until it leaves either
  // the rows or the columns. For non-square shapes its length is NOT
  // min(n, m) - |k| (e.g. n=2, m=5, k=1 has length 2).
  const int row0 = std::max(0, -k);
  const int col0 = std::max(0, k);
  const int length = std::min(n - row0, m - col0);
  if (length <= 0) {
    // Zero-sized array: nothing to set.
    return result;
  }

  // Row and column coordinates of every diagonal element. The int overload
  // of arange is used so the indices stay integral whatever `dtype` is.
  std::vector<array> indices = {
      arange(row0, row0 + length, s), arange(col0, col0 + length, s)};
  // Each index pair updates a 1x1 slice of `result`, so the updates shape is
  // the index shape {length} followed by the slice shape {1, 1}.
  array updates = full({length, 1, 1}, 1, dtype, s);
  std::vector<int> axes = {0, 1};
  return scatter(result, indices, updates, axes, s);
}
/**
 * Build an eye() array that matches `a`: same shape and dtype, with ones on
 * the k-th diagonal. Because eye() rejects non-2-D shapes, `a` must be 2-D.
 */
array eye_like(const array& a, int k, StreamOrDevice s /* = {} */) {
  const auto target_dtype = a.dtype();
  const auto& target_shape = a.shape();
  return eye(target_shape, k, target_dtype, to_stream(s));
}
/** Main-diagonal special case of eye(): ones exactly where row == column. */
array identity(
    const std::vector<int>& shape,
    Dtype dtype,
    StreamOrDevice s /* = {} */) {
  constexpr int main_diagonal = 0;
  return eye(shape, main_diagonal, dtype, to_stream(s));
}
/** identity() array with the same shape and dtype as the 2-D array `a`. */
array identity_like(const array& a, StreamOrDevice s /* = {} */) {
  const auto target_stream = to_stream(s);
  return identity(a.shape(), a.dtype(), target_stream);
}