Skip to content

Commit

Permalink
Fix several errors when strs are passed directly to endarray/etc
Browse files Browse the repository at this point in the history
  • Loading branch information
cgevans committed Oct 4, 2024
1 parent a62e5c0 commit f0bdd28
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/stickydesign/endclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ class EndArray(np.ndarray):

endtype: EndTypes

def __new__(cls, array: Union[Sequence[str], np.ndarray], endtype: EndTypes):
def __new__(cls, array: "Union[Sequence[str], np.ndarray, str, EndArray]", endtype: EndTypes = "S"):
if isinstance(array, EndArray):
return array
if isinstance(array, str):
array = [array]
if isinstance(array[0], str):
array = np.array(
[[nt[x] for x in y] for y in array], dtype=np.uint8)
Expand Down Expand Up @@ -165,16 +169,17 @@ def tolist(self) -> List[str]: # noqa: UP006


class Energetics(ABC):
@abstractproperty
def info(self) -> Dict[str, Any]: # noqa: F821
@property
@abstractmethod
def info(self) -> Dict[str, Any]:
...

@abstractmethod
def matching_uniform(self, seqs: Union[EndArray, np.ndarray]) -> np.ndarray:
def matching_uniform(self, seqs: Union[EndArray, np.ndarray, str]) -> np.ndarray:
...

@abstractmethod
def uniform(self, seqs1: Union[EndArray, np.ndarray], seqs2: Union[EndArray, np.ndarray]) -> np.ndarray:
def uniform(self, seqs1: Union[EndArray, np.ndarray, str], seqs2: Union[EndArray, np.ndarray, str]) -> np.ndarray:
...


Expand Down
3 changes: 3 additions & 0 deletions src/stickydesign/energetics_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def _setup_params(self, temperature=37):
+ j * 4 + (3 - k)]

def matching_uniform(self, seqs):
seqs = EndArray(seqs)
assert seqs.endtype == 'S'
ps = PairSeqA(seqs)

Expand All @@ -187,6 +188,8 @@ def matching_uniform(self, seqs):
return -(np.sum(self.nndG[ps], axis=1) + self.initdG)

def uniform(self, seqs1, seqs2, debug=False):
seqs1 = EndArray(seqs1)
seqs2 = EndArray(seqs2)
assert seqs1.endtype == seqs2.endtype
assert seqs1.endtype == 'S'
if seqs1.shape != seqs2.shape:
Expand Down
5 changes: 5 additions & 0 deletions src/stickydesign/energetics_basic_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@ def __init__(self, mismatchtype='max'):
"Mismatchtype {0} is not supported.".format(mismatchtype))

def matching_uniform(self, seqs):
seqs = EndArray(seqs)
return np.sum(self.nndG[tops(seqs)], axis=1) - self.initdG

def uniform_loopmismatch(self, seqs1, seqs2):
seqs1 = EndArray(seqs1)
seqs2 = EndArray(seqs2)
if seqs1.shape != seqs2.shape:
if seqs1.ndim == 1:
seqs1 = EndArray(
Expand Down Expand Up @@ -105,6 +108,8 @@ def uniform_loopmismatch(self, seqs1, seqs2):
return np.amax(en, 1) - self.initdG

def uniform_danglemismatch(self, seqs1, seqs2, fast=True):
seqs1 = EndArray(seqs1)
seqs2 = EndArray(seqs2)
if seqs1.shape != seqs2.shape:
if seqs1.ndim == 1:
seqs1 = EndArray(
Expand Down
3 changes: 3 additions & 0 deletions src/stickydesign/energetics_daoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def _setup_params(self, temperature=37):
+ j * 4 + (3 - k)]

def matching_uniform(self, seqs):
seqs = EndArray(seqs)
ps = PairSeqA(seqs)

# In both cases here, the energy we want is the NN binding energy of
Expand All @@ -148,6 +149,8 @@ def matching_uniform(self, seqs):
return -(np.sum(self.nndG[ps], axis=1) + self.initdG + dcorr)

def uniform(self, seqs1, seqs2, debug=False):
seqs1 = EndArray(seqs1)
seqs2 = EndArray(seqs2)
if seqs1.shape != seqs2.shape:
if seqs1.ndim == 1:
seqs1 = EndArray(
Expand Down

0 comments on commit f0bdd28

Please sign in to comment.