
Adding eye, eye_like, identity, identity_like functions. Feedback needed. #113

Closed
@cyrilzakka

Description

Hello!
I'm about to submit a pull request adding `eye` and `identity` functions. I've tried to follow the library's existing style, so my code deviates from the NumPy API in two ways:
1 - The `eye`/`eye_like` and `identity`/`identity_like` variants are separate functions, instead of `like` being a function parameter.
2 - The `eye` array is constructed from a `shape` parameter instead of the `N` and `M` parameters that NumPy defines.
Are these changes okay?
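For point 2, NumPy's `eye(N, M=None, k=0, ...)` defaults the column count `M` to `N`. A NumPy-style call therefore maps onto the proposed `shape` argument as in this small sketch. `numpy_eye_shape` is a hypothetical helper named here only for illustration. It is not part of MLX or of this PR.

```cpp
#include <optional>
#include <vector>

// Hypothetical helper, not an MLX function: converts NumPy-style eye(N, M)
// arguments into the {rows, cols} shape taken by the proposed
// eye(shape, k, dtype). As in NumPy, an omitted M defaults to N (square).
std::vector<int> numpy_eye_shape(int n, std::optional<int> m = std::nullopt) {
  return {n, m.value_or(n)};
}
```

Under this mapping, NumPy's `eye(3)` would correspond to `eye({3, 3}, 0, dtype)`, and `eye(2, 4, k=1)` would correspond to `eye({2, 4}, 1, dtype)` with the proposed signature.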

Proposed declarations:

```cpp
/** Fill an array of the given shape with ones on the k-th diagonal and zeros everywhere else. */
array eye(const std::vector<int>& shape, int k, Dtype dtype, StreamOrDevice s = {});
/** Same as eye, using the shape and dtype of a. */
array eye_like(const array& a, int k, StreamOrDevice s = {});

/** Fill an array of the given shape with ones on the main diagonal and zeros everywhere else. */
array identity(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
/** Same as identity, using the shape and dtype of a. */
array identity_like(const array& a, StreamOrDevice s = {});
```
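For reference, the entries that `eye(shape, k, dtype)` sets to one can be written out explicitly. This assumes NumPy's convention that `k > 0` selects a diagonal above the main one and `k < 0` one below, with `shape = {n, m}`. The symbols $r_0$, $c_0$ and $L$ are introduced here for the derivation only.

```latex
E_{ij} = \begin{cases} 1 & \text{if } j - i = k \\ 0 & \text{otherwise} \end{cases}
\qquad (0 \le i < n,\ 0 \le j < m)

\text{nonzero entries: } (r_0 + t,\ c_0 + t),\quad 0 \le t < L,\qquad
r_0 = \max(0, -k),\quad c_0 = \max(0, k)

L = \max\bigl(0,\ \min(n - r_0,\ m - c_0)\bigr)
  = \begin{cases} \min(n,\ m - k) & k \ge 0 \\ \min(n + k,\ m) & k < 0 \end{cases}
  \quad (\text{for } -n < k < m)

\text{row-major offset: } (r_0 + t)\,m + (c_0 + t) = r_0 m + c_0 + t\,(m + 1)
```

In a flattened row-major array, the diagonal is therefore a run with stride $m + 1$ starting at offset $r_0 m + c_0$. For non-square shapes, its length is not always $\min(n, m) - |k|$: for $n = 2$, $m = 5$, $k = 1$ the formula above gives $L = 2$.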
Proposed implementation:

```cpp
#include <algorithm>
#include <stdexcept>
#include <vector>

// Assumes this sits alongside the other ops (inside namespace mlx::core),
// where array, Dtype, int32, zeros, ones, arange and scatter are available.

array eye(
    const std::vector<int>& shape,
    int k,
    Dtype dtype,
    StreamOrDevice s /* = {} */) {
  if (shape.size() != 2) {
    throw std::invalid_argument(
        "Shape must be 2-dimensional for eye function.");
  }

  int n = shape[0];
  int m = shape[1];
  array result = zeros(shape, dtype, s);

  // The k-th diagonal is empty if it lies entirely outside the matrix.
  if (k >= m || -k >= n) {
    return result;
  }

  // Length of the k-th diagonal of an n x m matrix.
  int length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);

  // The diagonal entries are (r0 + i, c0 + i) for i in [0, length).
  // Indices are int32 regardless of the output dtype.
  int r0 = std::max(0, -k);
  int c0 = std::max(0, k);
  std::vector<array> indices = {
      arange(r0, r0 + length, int32, s), arange(c0, c0 + length, int32, s)};

  // scatter needs updates of rank indices.ndim() + result.ndim(), i.e.
  // shape {length, 1, 1}: one 1 x 1 slice written at each (row, col) pair.
  array updates = ones({length, 1, 1}, dtype, s);
  return scatter(result, indices, updates, {0, 1}, s);
}

array eye_like(const array& a, int k, StreamOrDevice s /* = {} */) {
  return eye(a.shape(), k, a.dtype(), s);
}

array identity(
    const std::vector<int>& shape,
    Dtype dtype,
    StreamOrDevice s /* = {} */) {
  return eye(shape, 0, dtype, s);
}

array identity_like(const array& a, StreamOrDevice s /* = {} */) {
  return identity(a.shape(), a.dtype(), s);
}
```
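A dependency-free reference implementation could serve as a test oracle for the MLX version. The idea is to compute the expected row-major `n x m` buffer for a few `(shape, k)` pairs and compare it against the flattened output of `eye`. `eye_reference` is a hypothetical name used only for this sketch. It works on a plain `std::vector<float>` rather than MLX arrays and ignores dtype and stream.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical test oracle, not MLX code: returns a row-major n x m buffer
// with ones on the k-th diagonal (k > 0 above the main diagonal, k < 0 below).
std::vector<float> eye_reference(int n, int m, int k) {
  if (n < 0 || m < 0) {
    throw std::invalid_argument("eye_reference: n and m must be non-negative");
  }
  std::vector<float> out(static_cast<std::size_t>(n) * m, 0.0f);
  if (k >= m || -k >= n) {
    return out;  // the k-th diagonal lies entirely outside the matrix
  }
  // Diagonal length and starting (row, col), as in the eye implementation.
  int length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
  int r0 = std::max(0, -k);
  int c0 = std::max(0, k);
  for (int i = 0; i < length; ++i) {
    out[static_cast<std::size_t>(r0 + i) * m + (c0 + i)] = 1.0f;
  }
  return out;
}
```

The 2 x 5, `k = 1` case is worth including in such a test. That diagonal has two entries, (0, 1) and (1, 2), so a length of `min(n, m) - |k|` would undercount it.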

Labels: enhancement (New feature or request)
