Skip to content

Commit

Permalink
added predict
Browse files Browse the repository at this point in the history
  • Loading branch information
MamadouSDiallo committed Jan 8, 2024
1 parent 75070d7 commit 64bb363
Show file tree
Hide file tree
Showing 14 changed files with 211 additions and 55 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ addopts = "--ignore=tests/sae --ignore=tests/types"

[tool.ruff]
src = ["src", "tests"]
# extend-exclude = ["tests"]

# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
select = ["D", "E", "F", "W", "I001"]
Expand Down
6 changes: 5 additions & 1 deletion src/samplics/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
__all__ = ["fit_eblup", "predict_eblup"]
from samplics.apis.fit import fit
from samplics.apis.predict import predict


__all__ = ["fit", "predict"]
14 changes: 7 additions & 7 deletions src/samplics/apis/base/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import numpy as np
import polars as pl

from samplics.types import Array, AuxVars, DirectEst, FitMethod, FitStats, Number
from samplics.types import AuxVars, DepVars, FitMethod, FitStats, Number


# Fitting an EBLUP model
def _fit(
y: DirectEst | Array,
y: DepVars,
x: AuxVars,
method: FitMethod,
intercept: bool = True, # if True, it adds an intercept of 1
Expand Down Expand Up @@ -91,7 +91,7 @@ def _fit(
def _iterative_fisher_scoring(
method: FitMethod,
# area: np.ndarray,
y: DirectEst,
y: DepVars,
x: AuxVars,
sig_e: dict,
# b_const: np.ndarray,
Expand Down Expand Up @@ -146,7 +146,7 @@ def _iterative_fisher_scoring(

def _partial_derivatives_fh(
method: FitMethod,
y: DirectEst,
y: DepVars,
x: AuxVars,
sig2_e: dict,
sig2_v: Number,
Expand Down Expand Up @@ -180,7 +180,7 @@ def _partial_derivatives_fh(

def _partial_derivatives_ml(
method: FitMethod,
y: DirectEst,
y: DepVars,
x: AuxVars,
sig2_e: dict,
sig2_v: Number,
Expand Down Expand Up @@ -217,7 +217,7 @@ def _partial_derivatives_ml(

def _partial_derivatives_reml(
method: FitMethod,
y: DirectEst,
y: DepVars,
x: AuxVars,
sig2_e: dict,
sig2_v: Number,
Expand Down Expand Up @@ -256,7 +256,7 @@ def _partial_derivatives_reml(

def _partial_derivatives(
method: FitMethod,
y: DirectEst,
y: DepVars,
x: AuxVars,
sig2_e: dict,
sig2_v: Number,
Expand Down
25 changes: 13 additions & 12 deletions src/samplics/apis/fit.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from samplics.apis.base.fitting import _fit
from samplics.types.containers import Array, DirectEst, FitMethod


def fit(y: DirectEst | Array, x, method):
match y:
case DirectEst():
y_vec = y.est
return _fit(y=y_vec, x=x, method=FitMethod.fh)
case Array():
return _fit(y=y, x=x, method=FitMethod.fh)
case _:
raise TypeError("The type of `y` is not supported!")
from samplics.types import AuxVars, DirectEst, FitMethod


def fit(y: DirectEst, x: AuxVars, method: FitMethod):
    """Public entry point for model fitting; a thin wrapper around `_fit`.

    Args:
        y: Direct estimates, forwarded as `_fit`'s `y`.
        x: Auxiliary variables, forwarded as `_fit`'s `x`.
        method: Fitting method, forwarded as `_fit`'s `method`.

    Returns:
        The value produced by `_fit`, unchanged.
    """
    fit_result = _fit(y=y, x=x, method=method)
    return fit_result


def test2(y):
    """Return 2 when `y` is a `DirectEst` instance, otherwise 3."""
    return 2 if isinstance(y, DirectEst) else 3
12 changes: 12 additions & 0 deletions src/samplics/apis/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from samplics.apis.sae.area_eblup import _predict_eblup
from samplics.types import AuxVars, DictStrNum, DirectEst, FitStats, Number


def predict(
    x: AuxVars,
    fit_stats: FitStats,
    y: DirectEst,
    intercept: bool = True,  # if True, it adds an intercept of 1
    b_const: DictStrNum | Number = 1.0,
):
    """Public prediction entry point; a thin wrapper around `_predict_eblup`.

    Args:
        x: Auxiliary variables, forwarded as `x`.
        fit_stats: Fitted-model statistics, forwarded under the name
            `fit_eblup` expected by `_predict_eblup`.
        y: Direct estimates, forwarded as `y`.
        intercept: Forwarded as `intercept`.
        b_const: Forwarded as `b_const`.

    Returns:
        The value produced by `_predict_eblup`, unchanged.
    """
    prediction = _predict_eblup(
        x=x,
        fit_eblup=fit_stats,
        y=y,
        intercept=intercept,
        b_const=b_const,
    )
    return prediction
4 changes: 2 additions & 2 deletions src/samplics/apis/sae/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from samplics.apis.sae.area_eblup import _log_likelihood
# from samplics.apis.sae.area_eblup import _log_likelihood
from samplics.types.errors import (
CertaintyError,
DimensionError,
Expand All @@ -23,5 +23,5 @@
"MethodError",
"fit_eblup",
"predict_eblup",
"_log_likelihood",
# "_log_likelihood",
]
2 changes: 0 additions & 2 deletions src/samplics/apis/sae/area_eblup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
EblupEst,
FitMethod,
FitStats,
Mse,
Number,
)

Expand All @@ -22,7 +21,6 @@ def _predict_eblup(
x: AuxVars,
fit_eblup: FitStats,
y: DirectEst,
mse: Mse | list[Mse] | None = None,
intercept: bool = True, # if True, it adds an intercept of 1
b_const: DictStrNum | Number = 1.0,
) -> EblupEst:
Expand Down
17 changes: 8 additions & 9 deletions src/samplics/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,29 @@
FitStats,
)
from samplics.types.options import FitMethod, Mse


# from samplics.types.protocols import GlmmFitStats
from samplics.types.protocols import DepVars, IndepVars


__all__ = [
"DF",
"Array",
"Series",
"Number",
"StringNumber",
"AuxVars",
"DepVars",
"DF",
"DictStrNum",
"DictStrInt",
"DictStrFloat",
"DictStrBool",
"DirectEst",
"AuxVars",
"EblupEst",
"EblupFit",
"EbUnitModel",
"EbEst",
"EbFit",
"FitStats",
"FitMethod",
# "GlmmFitStats",
"IndepVars",
"Mse",
"Number",
"Series",
"StringNumber",
]
10 changes: 5 additions & 5 deletions src/samplics/types/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class AuxVars:
# TODO: Add missing values functionality
x: dict
nrecords: dict
nrecords: dict | Number
record_id: dict | None
domains: list | None
uid: int = int(dt.datetime.now(tz=dt.timezone.utc).strftime("%Y%m%d%H%M%S") + str(int(1e16 * rand.random())))
Expand Down Expand Up @@ -78,12 +78,12 @@ def __init__(
auxdata_dict[d] = auxdata_dict[d].drop("__domain")
record_id_dict[d] = record_id_dict[d]["__record_id"].to_list()
nrecords[d] = auxdata_dict[d].shape[0]

auxdata_dict = {k: auxdata_dict[k].to_dict(as_series=False) for k in auxdata_dict}
else:
auxdata_dict = x.insert_at_idx(0, pl.Series(record_id).alias("__record_id"))
record_id_dict = x.insert_at_idx(0, pl.Series(record_id).alias("__record_id"))
nrecords = x.shape[0]
auxdata_dict = x.to_dict()
record_id_dict = pl.DataFrame([__record_id], schema=["__record_id"]).to_dict()

auxdata_dict = {k: auxdata_dict[k].to_dict(as_series=False) for k in auxdata_dict}
# record_id_dict = {
# k: record_id_dict[k].to_dict(as_series=False) for k in record_id_dict
# }
Expand Down
53 changes: 42 additions & 11 deletions src/samplics/types/protocols.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,72 @@
from typing import Protocol
from typing import Protocol, runtime_checkable


# NumpyArray = TypeVar("np.ndarray", bound=np.ndarray)
# PandasDF = TypeVar("pd.DataFrame", bound=pd.DataFrame)
# PolarsDF = TypeVar("pl.DataFrame", bound=pl.DataFrame)


@runtime_checkable
class Missing(Protocol):
    """Runtime-checkable protocol that declares no members.

    NOTE(review): presumably a placeholder for missing-value handling -- confirm.
    """


@runtime_checkable
class DepVars(Protocol):
# TODO: Add missing values functionality
y: dict

@property
def nrecords():
...

# Design matrix X (fixed effect) in GLMM
class IndepVarsX(Protocol):
@property
def nvars():
...

def to_numpy():
...

def to_polars():
...

def to_pandas():
...


@runtime_checkable
class IndepVars(Protocol):
# TODO: Add missing values functionality
x: dict

@property
def nrecords():
...

# Design matrix Z (random effect) in GLMM
class IndepVarsZ(Protocol):
# TODO: Add missing values functionality
z: dict
@property
def nvars():
...

def to_numpy():
...

def to_polars():
...

def to_pandas():
...


@runtime_checkable
class ToDataFrame(Protocol):
    """Structural type for objects exposing `to_numpy`, `to_polars` and `to_pandas`.

    Being runtime-checkable, `isinstance(obj, ToDataFrame)` only verifies that
    the three methods exist on `obj`; it does not check their signatures.
    """

    # Each method takes `self` so the protocol describes ordinary instance
    # methods, which is what implementing classes actually define.
    def to_numpy(self):
        """Convert to a NumPy representation (return type not specified here)."""
        ...

    def to_polars(self):
        """Convert to a Polars representation (return type not specified here)."""
        ...

    def to_pandas(self):
        """Convert to a pandas representation (return type not specified here)."""
        ...


# class FitStats(Protocol):
Expand Down
66 changes: 66 additions & 0 deletions tests/apis/core/test_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np
import polars as pl

from samplics.apis import fit
from samplics.types import AuxVars, DirectEst, FitMethod, FitStats


def test_fit():
    """Check that `fit` returns a `FitStats` for randomly generated area-level data.

    Builds 50 single-record areas with two auxiliary variables and random
    direct estimates, fits with REML, and asserts on the result type.
    """
    area = np.linspace(start=1, stop=50, num=50).tolist()
    ssize_arr = np.unique(area, return_counts=True)
    # Map each area label to its number of records.
    ssize = dict(zip(ssize_arr[0], ssize_arr[1]))
    y_stderr = np.random.uniform(low=1, high=10, size=50)

    x_df = pl.DataFrame(
        {
            "x1": np.random.choice(a=[1, 2, 7, 5], size=50).tolist(),
            "x2": (13 * np.random.normal(size=50)).tolist(),
        }
    )
    x = AuxVars(x=x_df, domain=area)

    y = (
        2 * np.random.choice(a=[1, 2, 7, 5], size=50)
        + 3 * (13 * np.random.normal(size=50))
        + np.random.uniform(low=1, high=5, size=50)
    )

    y_hat = DirectEst(est=y, stderr=y_stderr, domain=area, ssize=ssize)

    y_fit = fit(y=y_hat, x=x, method=FitMethod.reml)

    # The bare `isinstance(...)` expression was a no-op; assert so that a wrong
    # return type actually fails the test.
    assert isinstance(y_fit, FitStats)


from typing import Protocol


class Animal(Protocol):
    """Structural type for anything that provides a `speak()` instance method."""

    # `self` is required so the protocol matches the instance methods that
    # implementers such as `Dog` and `Cat` define.
    def speak(self):
        """Make the animal speak."""
        ...


class Dog:
    """Example class providing a `speak` method."""

    def speak(self):
        """Print the dog's greeting to standard output."""
        greeting = "I am a dog"
        print(greeting)


class Cat:
    """Example class providing a `speak` method."""

    def speak(self):
        """Print the cat's greeting to standard output."""
        greeting = "I am a cat"
        print(greeting)


def use(animal: Animal):
    """Call `speak()` on a `Dog` or `Cat`; print "No animal" for anything else."""
    # Guard clause: both known animal types are handled identically.
    if isinstance(animal, (Dog, Cat)):
        animal.speak()
        return
    print("No animal")


# Example instance used for manual experimentation. The stray module-level
# `breakpoint()` was removed: it would stop at the debugger (or fail under
# captured stdin) as soon as the test module is imported/collected.
miaou = Cat()
11 changes: 6 additions & 5 deletions tests/apis/sae/test_area_eblup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import polars as pl
import pytest

from samplics.apis.sae import _fit_eblup, _predict_eblup
from samplics.apis import fit
from samplics.apis.sae.area_eblup import _predict_eblup

# from samplics.apis.sae import _log_likelihood, fit_eblup, predict_eblup
from samplics.types import AuxVars, DirectEst, FitMethod, Mse
Expand Down Expand Up @@ -49,17 +50,17 @@
auxvars = AuxVars(x=x, domain=area)

# Fit the linear mixed model
fit_ml = _fit_eblup(y=yhat, x=auxvars, method=FitMethod.ml)
fit_reml = _fit_eblup(y=yhat, x=auxvars, method=FitMethod.reml)
fit_ml = fit(y=yhat, x=auxvars, method=FitMethod.ml)
fit_reml = fit(y=yhat, x=auxvars, method=FitMethod.reml)
# fit_fh = fit_eblup(y=yhat, x=auxvars, method=FitMethod.fh)
# breakpoint()

# Predict the small area estimates
est_milk_reml = _predict_eblup(x=auxvars, fit_eblup=fit_reml, y=yhat, mse=Mse.taylor)
est_milk_reml = _predict_eblup(x=auxvars, fit_eblup=fit_reml, y=yhat)

# est_milk_reml.fit_stats.log_llike

breakpoint()
# breakpoint()


@pytest.mark.skipif(sys.platform == "linux", reason="Skip dev version on Github (Linux)")
Expand Down
Loading

0 comments on commit 64bb363

Please sign in to comment.