Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions pybind/unitensor_py.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <format>
#include <vector>
#include <map>
#include <random>
#include <string>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
Expand All @@ -9,10 +11,9 @@
#include <pybind11/numpy.h>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
#include <pybind11/warnings.h>

#include "cytnx.hpp"
// #include "../include/cytnx_error.hpp"
#include "complex.h"

namespace py = pybind11;
using namespace pybind11::literals;
Expand Down Expand Up @@ -66,6 +67,33 @@ void f_UniTensor_setelem_scal_int(UniTensor &self, const cytnx_uint64 &locator,
self.set_elem(tmp, rc);
}

// Parse UniTensor.get_blocks_ function's silent argument.
//
// This function should be replaced with `py::arg("silent") = false` after stopping
// support for the deprecated typo argument "slient".
inline bool parse_get_blocks_silent_arg(const py::args &args, const py::kwargs &kwargs) {
bool silent = false;
if (args.size() + kwargs.size() > 1) {
throw py::type_error("get_blocks_() takes at most 1 argument");
}
if (args.size() == 1) {
silent = py::cast<bool>(args[0]);
} else if (kwargs.contains("slient")) {
py::warnings::warn(
"Keyword 'slient' is deprecated and will be removed in v2.0.0; use 'silent' instead.",
PyExc_FutureWarning, 2);
silent = kwargs["slient"].cast<bool>();
} else if (kwargs.contains("silent")) {
silent = kwargs["silent"].cast<bool>();
} else if (kwargs.size() == 1) {
// The case that kwargs.size() > 1 has been caught above.
std::string kwarg_name = py::str(kwargs.begin()->first);
throw py::type_error(
std::format("'{}' is an invalid keyword argument for get_blocks_()", kwarg_name));
}
return silent;
}

void unitensor_binding(py::module &m) {
py::class_<cHclass>(m, "Helpclass")
.def("exists", &cHclass::exists)
Expand Down Expand Up @@ -700,11 +728,18 @@ void unitensor_binding(py::module &m) {
.def("get_blocks", [](const UniTensor &self) { return self.get_blocks(); })
.def(
"get_blocks_",
[](const UniTensor &self, const bool &silent) { return self.get_blocks_(silent); },
py::arg("silent") = false)
[](const UniTensor& self, py::args args, py::kwargs kwargs) {
return self.get_blocks_(parse_get_blocks_silent_arg(args, kwargs));
}
// ,py::arg("silent") = false // Uncmment this line after removing the deprecated argument.
)
.def(
"get_blocks_", [](UniTensor &self, const bool &silent) { return self.get_blocks_(silent); },
py::arg("silent") = false)
"get_blocks_",
[](UniTensor &self, py::args args, py::kwargs kwargs) {
return self.get_blocks_(parse_get_blocks_silent_arg(args, kwargs));
}
// ,py::arg("silent") = false // Uncmment this line after removing the deprecated argument.
)
.def(
"put_block",
[](UniTensor &self, const cytnx::Tensor &in, const cytnx_uint64 &idx) {
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["scikit-build-core >=0.11", "pybind11 >=2.6"]
requires = ["scikit-build-core >=0.11", "pybind11 >=3.0"]
build-backend = "scikit_build_core.build"

[project]
Expand Down
94 changes: 94 additions & 0 deletions pytests/unitensor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import warnings

import pytest

import cytnx


def test_get_blocks_deprecated_slient_warning():
"""Test that using 'slient' parameter triggers FutureWarning"""
# Create a BlockTensor for testing
bond = cytnx.Bond(cytnx.BD_IN,
[cytnx.Qs(1) >> 1, cytnx.Qs(-1) >> 1],
[cytnx.Symmetry.U1()])
unitensor = cytnx.UniTensor([bond])

# Test that using deprecated 'slient' parameter raises FutureWarning
with pytest.warns(
FutureWarning,
match=
"Keyword 'slient' is deprecated and will be removed in v2.0.0; use 'silent' instead."
):
unitensor.get_blocks_(slient=True)

# Test that the method still works with deprecated parameter
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
result_deprecated = unitensor.get_blocks_(slient=True)
result_new = unitensor.get_blocks_(silent=True)
# Both should return the same result
assert len(result_deprecated) == len(result_new)


def test_get_blocks_new_silent_parameter():
"""Test that new 'silent' parameter works without warnings"""
bond = cytnx.Bond(cytnx.BD_IN,
[cytnx.Qs(1) >> 1, cytnx.Qs(-1) >> 1],
[cytnx.Symmetry.U1()])
unitensor = cytnx.UniTensor([bond])

# Test that new 'silent' parameter doesn't trigger warnings
with warnings.catch_warnings():
warnings.simplefilter("error") # Turn warnings into exceptions
result_silent_true = unitensor.get_blocks_(silent=True)
result_silent_false = unitensor.get_blocks_(silent=False)
result_default = unitensor.get_blocks_()

# All should work without warnings
assert isinstance(result_silent_true, list)
assert isinstance(result_silent_false, list)
assert isinstance(result_default, list)


def test_get_blocks_positional_argument():
"""Test that positional argument still works"""
bond = cytnx.Bond(cytnx.BD_IN,
[cytnx.Qs(1) >> 1, cytnx.Qs(-1) >> 1],
[cytnx.Symmetry.U1()])
unitensor = cytnx.UniTensor([bond])

# Test positional argument
result_pos_true = unitensor.get_blocks_(True)
result_pos_false = unitensor.get_blocks_(False)

assert isinstance(result_pos_true, list)
assert isinstance(result_pos_false, list)


def test_get_blocks_argument_validation():
"""Test argument validation for get_blocks_"""
bond = cytnx.Bond(cytnx.BD_IN,
[cytnx.Qs(1) >> 1, cytnx.Qs(-1) >> 1],
[cytnx.Symmetry.U1()])
unitensor = cytnx.UniTensor([bond])

# Test too many arguments
with pytest.raises(
TypeError, match="get_blocks_\\(\\) takes at most 1 argument"):
unitensor.get_blocks_(True, False)

with pytest.raises(
TypeError, match="get_blocks_\\(\\) takes at most 1 argument"):
unitensor.get_blocks_(silent=True, slient=False)

with pytest.raises(
TypeError, match="get_blocks_\\(\\) takes at most 1 argument"):
unitensor.get_blocks_(True, silent=False)

# Test invalid keyword argument
with pytest.raises(
TypeError,
match=
"'invalid_arg' is an invalid keyword argument for get_blocks_\\(\\)"
):
unitensor.get_blocks_(invalid_arg=True)
Loading