Skip to content

Develop shape_consistency #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 14, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions bayesml/_check.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Code Author
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
# Yasushi Esaki <esakiful@gmail.com>
# Jun Nishikawa <jun.b.nishikawa@gmail.com>
import numpy as np

_EPSILON = np.sqrt(np.finfo(np.float64).eps)
Expand Down Expand Up @@ -100,6 +102,22 @@ def pos_def_sym_mat(val,val_name,exception_class):
pass
raise(exception_class(val_name + " must be a positive definite symmetric 2-dimensional numpy.ndarray."))

def sym_mats(val,val_name,exception_class):
if type(val) is np.ndarray:
if val.ndim >= 2 and val.shape[-1] == val.shape[-2]:
if np.allclose(val, np.swapaxes(val,-1,-2)):
return val
raise(exception_class(val_name + " must be a symmetric 2-dimensional numpy.ndarray."))

def pos_def_sym_mats(val,val_name,exception_class):
sym_mats(val,val_name,exception_class)
try:
np.linalg.cholesky(val)
return val
except np.linalg.LinAlgError:
pass
raise(exception_class(val_name + " must be a positive definite symmetric 2-dimensional numpy.ndarray."))

def float_(val,val_name,exception_class):
if np.issubdtype(type(val),np.floating):
return val
Expand Down Expand Up @@ -163,6 +181,14 @@ def float_vec_sum_1(val,val_name,exception_class):
return val
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1."))

def float_vecs_sum_1(val,val_name,exception_class):
if type(val) is np.ndarray:
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
return val.astype(float)
if np.issubdtype(val.dtype,np.floating) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
return val
raise(exception_class(val_name + " must be a numpy.ndarray whose ndim >= 1, and the sum along the last dimension must equal to 1."))

def int_(val,val_name,exception_class):
if np.issubdtype(type(val),np.integer):
return val
Expand All @@ -189,3 +215,9 @@ def onehot_vecs(val,val_name,exception_class):
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(val >= 0) and np.all(val.sum(axis=-1)==1):
return val
raise(exception_class(val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors."))

def shape_consistency(val: int, val_name: str, correct: int, correct_name: str, exception_class):
if val != correct:
message = (f"{val_name} must coincide with {correct_name}: "
+ f"{val_name} = {val}, {correct_name} = {correct}")
raise(exception_class(message))