-
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
75070d7
commit 64bb363
Showing
14 changed files
with
211 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,5 @@ | ||
__all__ = ["fit_eblup", "predict_eblup"] | ||
from samplics.apis.fit import fit | ||
from samplics.apis.predict import predict | ||
|
||
|
||
__all__ = ["fit", "predict"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,14 @@ | ||
from samplics.apis.base.fitting import _fit | ||
from samplics.types.containers import Array, DirectEst, FitMethod | ||
|
||
|
||
def fit(y: DirectEst | Array, x, method): | ||
match y: | ||
case DirectEst(): | ||
y_vec = y.est | ||
return _fit(y=y_vec, x=x, method=FitMethod.fh) | ||
case Array(): | ||
return _fit(y=y, x=x, method=FitMethod.fh) | ||
case _: | ||
raise TypeError("The type of `y` is not supported!") | ||
from samplics.types import AuxVars, DirectEst, FitMethod | ||
|
||
|
||
def fit(y: DirectEst, x: AuxVars, method: FitMethod): | ||
# breakpoint() | ||
return _fit(y=y, x=x, method=method) | ||
|
||
|
||
def test2(y): | ||
if isinstance(y, DirectEst): | ||
return 2 | ||
else: | ||
return 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from samplics.apis.sae.area_eblup import _predict_eblup | ||
from samplics.types import AuxVars, DictStrNum, DirectEst, FitStats, Number | ||
|
||
|
||
def predict( | ||
x: AuxVars, | ||
fit_stats: FitStats, | ||
y: DirectEst, | ||
intercept: bool = True, # if True, it adds an intercept of 1 | ||
b_const: DictStrNum | Number = 1.0, | ||
): | ||
return _predict_eblup(x=x, fit_eblup=fit_stats, y=y, intercept=intercept, b_const=b_const) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import numpy as np | ||
import polars as pl | ||
|
||
from samplics.apis import fit | ||
from samplics.types import AuxVars, DirectEst, FitMethod, FitStats | ||
|
||
|
||
def test_fit(): | ||
area = np.linspace(start=1, stop=50, num=50).tolist() | ||
ssize_arr = np.unique(area, return_counts=True) | ||
ssize = dict(zip(ssize_arr[0], ssize_arr[1])) | ||
y_stderr = np.random.uniform(low=1, high=10, size=50) | ||
|
||
x_df = pl.DataFrame( | ||
{ | ||
"x1": np.random.choice(a=[1, 2, 7, 5], size=50).tolist(), | ||
"x2": (13 * np.random.normal(size=50)).tolist(), | ||
} | ||
) | ||
x = AuxVars(x=x_df, domain=area) | ||
|
||
# y = 150 * np.random.normal(size=250) | ||
y = ( | ||
2 * np.random.choice(a=[1, 2, 7, 5], size=50) | ||
+ 3 * (13 * np.random.normal(size=50)) | ||
+ np.random.uniform(low=1, high=5, size=50) | ||
) | ||
|
||
y_hat = DirectEst(est=y, stderr=y_stderr, domain=area, ssize=ssize) | ||
|
||
y_fit = fit(y=y_hat, x=x, method=FitMethod.reml) | ||
|
||
isinstance(y_fit, FitStats) | ||
|
||
|
||
from typing import Protocol | ||
|
||
|
||
class Animal(Protocol): | ||
def speak(): | ||
... | ||
|
||
|
||
class Dog: | ||
def speak(self): | ||
print("I am a dog") | ||
|
||
|
||
class Cat: | ||
def speak(self): | ||
print("I am a cat") | ||
|
||
|
||
def use(animal: Animal): | ||
match animal: | ||
case Dog(): | ||
animal.speak() | ||
case Cat(): | ||
animal.speak() | ||
case _: | ||
print("No animal") | ||
|
||
|
||
miaou = Cat() | ||
|
||
breakpoint() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.