Conversation

@jamespooley
Contributor

@jamespooley jamespooley commented Oct 19, 2025

Description

This is a draft PR being opened to let you know the bug is being worked on.

The code changes here get rid of post-fitting serialization errors like TypeError: Object of type SoftPlusHSGP is not JSON serializable when passing HSGP instances to the time_varying_intercept or time_varying_media parameters of the multidimensional.MMM class.
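
For readers unfamiliar with this failure mode, the following is a minimal sketch of why `json.dumps` rejects an arbitrary object and why converting it to a plain dict first avoids the error. `FakeHSGP` is a hypothetical stand-in, not the real `SoftPlusHSGP`, and its `m` and `dims` fields are assumptions chosen only for illustration.

```python
import json


class FakeHSGP:
    """Hypothetical stand-in for an HSGP-like object (not the pymc-marketing class)."""

    def __init__(self, m: int, dims: tuple[str, ...]):
        self.m = m
        self.dims = dims

    def to_dict(self) -> dict:
        # Only JSON-friendly types (int, list, str) go into the dict.
        return {"m": self.m, "dims": list(self.dims)}


hsgp = FakeHSGP(m=20, dims=("date",))

# json.dumps has no encoder for arbitrary objects, so this raises TypeError.
try:
    json.dumps(hsgp)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Serializing the dict representation instead works.
serialized = json.dumps(hsgp.to_dict())
```

The same idea presumably applies to the real classes: anything stored in the saved attributes has to be reduced to JSON-compatible types before `json.dumps` sees it.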

There’s not an open issue for this yet, but there’s also a related deserialization issue when passing arbitrary HSGP instances. The tl;dr is that after fitting and .save()-ing a model with them, there’s an error when .load()-ing the model.

One (admittedly not great) hotfix to get rid of that error is to replace

time_varying_intercept: Annotated[
    StrictBool | InstanceOf[HSGPBase],

in this block of code (and similarly with time_varying_media) with

time_varying_intercept: Annotated[
    StrictBool | InstanceOf[HSGPBase] | dict,

But even though this gets rid of the .load()-ing error, the time-varying structure for the saved model isn’t recovered after the .load(), and the time-varying components have the default configuration.

So I’m wondering whether the loading logic there needs to be updated to use HSGPKwargs-style logic like that here, or maybe something new like hsgp_from_dict(), similar to how saturation_from_dict() is used.

If so, that’s likely another PR, but I wanted to raise it here since they’re both related to the “pass arbitrary HSGP instances” functionality.

Related Issue

Checklist


📚 Documentation preview 📚: https://pymc-marketing--2016.org.readthedocs.build/en/2016/

@github-actions github-actions bot added the MMM label Oct 19, 2025
@codecov

codecov bot commented Oct 20, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 92.51%. Comparing base (e9b1b45) to head (fe4ea45).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2016       +/-   ##
===========================================
+ Coverage   36.79%   92.51%   +55.72%     
===========================================
  Files          68       68               
  Lines        9373     9396       +23     
===========================================
+ Hits         3449     8693     +5244     
+ Misses       5924      703     -5221     

Contributor

@williambdean williambdean left a comment

Thanks for the PR @jamespooley

Can you take a look to see if the model will load (adding a test would be appreciated)

The loading logic is here:

"time_varying_intercept": json.loads(
    attrs.get("time_varying_intercept", "false")
),

@github-actions github-actions bot added the tests label Oct 25, 2025
@jamespooley
Contributor Author

@williambdean Drafted some tests to show that the loading logic itself...

"time_varying_intercept": json.loads(
    attrs.get("time_varying_intercept", "false")
),

"time_varying_media": json.loads(attrs.get("time_varying_media", "false")),

...works fine (i.e., returns the same information as in the user-specified HSGP instance) with the updated code.

But I also added a temporary test to clarify the remaining issue: calling MMM.load() on a successfully saved MMM that had an HSGP instance passed to either of the time_varying_* parameters raises a ValidationError, since that loading logic returns a dict, which currently isn't converted into an instance of HSGPBase.
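
A rough sketch of that round trip, using only the standard library, may make the mismatch easier to see. `FakeHSGP` and `validate_time_varying` are hypothetical stand-ins: the real model validates with pydantic and raises a ValidationError, whereas this illustration raises a plain TypeError.

```python
import json


class FakeHSGP:
    """Hypothetical stand-in for an HSGPBase subclass."""

    def __init__(self, m: int):
        self.m = m

    def to_dict(self) -> dict:
        return {"m": self.m}


def validate_time_varying(value):
    """Mimic a strict 'bool or HSGP instance' check (illustrative only)."""
    if isinstance(value, (bool, FakeHSGP)):
        return value
    raise TypeError(f"expected bool or FakeHSGP, got {type(value).__name__}")


# Saving: the instance is stored as a JSON string in the attrs mapping.
attrs = {"time_varying_intercept": json.dumps(FakeHSGP(m=20).to_dict())}

# Loading: json.loads gives back a plain dict, not a FakeHSGP.
loaded = json.loads(attrs.get("time_varying_intercept", "false"))

try:
    validate_time_varying(loaded)
    load_error = None
except TypeError as exc:
    load_error = str(exc)

# With no stored value, the "false" default decodes to the bool False, which passes.
default = validate_time_varying(json.loads({}.get("time_varying_intercept", "false")))
```

In this sketch the save step succeeds and only the load step fails, which matches the behaviour described above: the stored JSON is fine, but nothing turns the decoded dict back into an instance.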

As mentioned, that error will go away by adding a | dict to the relevant portions of the following:

time_varying_intercept: Annotated[
    StrictBool | InstanceOf[HSGPBase],
    Field(
        description=(
            "Whether to use a time-varying intercept, or pass an HSGP instance "
            "(e.g., SoftPlusHSGP) specifying dims and priors."
        ),
    ),
] = False,
time_varying_media: Annotated[
    StrictBool | InstanceOf[HSGPBase],

But it seems like the proper solution would be something along these lines, just for the time_varying_* parameters that can accept HSGPs:

"adstock": adstock_from_dict(json.loads(attrs["adstock"])),
"saturation": saturation_from_dict(json.loads(attrs["saturation"])),

Any thoughts here?

@williambdean
Contributor

But it seems like the proper solution would be something along these lines, just for the time_varying_* parameters that can accept HSGPs:

"adstock": adstock_from_dict(json.loads(attrs["adstock"])),
"saturation": saturation_from_dict(json.loads(attrs["saturation"])),

Any thoughts here?

Good point. Let's just start with the most obvious cases for now. We can add a function like that, hsgp_from_dict, which will just wrap these cases for the time being. Let's continue with further implementation in additional PRs.

@jamespooley
Contributor Author

jamespooley commented Oct 29, 2025

@williambdean Sounds good to me.

The following wouldn’t be pretty and I wouldn’t be super proud of it, but because there isn’t currently an equivalent of {"lookup_name": "logistic", ...} for HSGPs after .to_dict()-ing like there is for SaturationTransformation, modifying the HSGP serialization to be something like this…

attrs["time_varying_media"] = json.dumps(
    self.time_varying_media
    if not isinstance(self.time_varying_media, HSGPBase)
    else {
        **self.time_varying_media.to_dict(),
        "hsgp_class": self.time_varying_media.__class__.__name__,
    }
)

…and then a dead simple implementation of hsgp_from_dict like the following would likely hotfix the MMM.load() issue.

from pymc_marketing.mmm.hsgp import HSGP, HSGPPeriodic, SoftPlusHSGP

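def hsgp_from_dict(data: dict | bool):
    # Plain booleans (time-varying on/off) pass through unchanged.
    if isinstance(data, bool):
        return data

    # Map the class name stored under "hsgp_class" back to the class itself.
    HSGP_CLASSES = {
        "HSGP": HSGP,
        "SoftPlusHSGP": SoftPlusHSGP,
        "HSGPPeriodic": HSGPPeriodic,
    }

    # Copy so the caller's dict is not mutated by pop().
    data = data.copy()
    cls = HSGP_CLASSES[data.pop("hsgp_class")]

    return cls.from_dict(data)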
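
To check the idea end to end without installing anything, here is a self-contained sketch of the same tag-and-registry round trip. The `Fake*` classes, `dump_time_varying`, and `load_time_varying` are hypothetical names invented for this illustration; only the "hsgp_class" key and the overall shape come from the snippets above.

```python
import json


class FakeHSGP:
    """Hypothetical stand-in for HSGP; not the pymc-marketing class."""

    def __init__(self, m: int):
        self.m = m

    def to_dict(self) -> dict:
        return {"m": self.m}

    @classmethod
    def from_dict(cls, data: dict):
        return cls(**data)


class FakeSoftPlusHSGP(FakeHSGP):
    """Hypothetical stand-in for SoftPlusHSGP."""


# Registry from the stored class name back to the class.
FAKE_HSGP_CLASSES = {cls.__name__: cls for cls in (FakeHSGP, FakeSoftPlusHSGP)}


def dump_time_varying(value) -> str:
    """Serialize a bool as-is, or an instance as its dict plus a class tag."""
    if isinstance(value, FakeHSGP):
        value = {**value.to_dict(), "hsgp_class": type(value).__name__}
    return json.dumps(value)


def load_time_varying(raw: str):
    """Invert dump_time_varying: bools pass through, dicts are rebuilt."""
    data = json.loads(raw)
    if isinstance(data, bool):
        return data
    data = dict(data)  # avoid mutating the decoded dict
    cls = FAKE_HSGP_CLASSES[data.pop("hsgp_class")]
    return cls.from_dict(data)


restored = load_time_varying(dump_time_varying(FakeSoftPlusHSGP(m=20)))
flag = load_time_varying(dump_time_varying(True))
```

One thing the sketch makes visible is that the subclass survives the round trip only because of the class tag; without it, the loader would have no way to choose between the registered classes.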

It's certainly hacky, and I don’t know whether you think an approach along these lines is too hacky even for a first pass at this. But it could work until a proper implementation is tackled in future PRs.

If you're aligned on something like ☝️ as a "good enough for now" approach, I can clean up and test the implementation, update the code/tests, and then finalize this PR.

Contributor

@williambdean williambdean left a comment

Looks good. Thanks for these alterations!

@williambdean
Contributor

I see this is still a draft PR, @jamespooley. I will merge and we can continue working in a separate PR if needed.

@williambdean williambdean marked this pull request as ready for review November 3, 2025 12:19
@williambdean williambdean added the enhancement label Nov 3, 2025
@williambdean williambdean changed the title from "Add .to_dict() logic to fix serialization errors when passing HSGP instances" to "Pass arbitrary HSGP instance to time_varying_intercept/media in MMM" Nov 3, 2025
@williambdean williambdean merged commit f320c44 into pymc-labs:main Nov 3, 2025
38 checks passed
@williambdean williambdean added the bug label and removed the enhancement label Nov 3, 2025
@williambdean williambdean changed the title from "Pass arbitrary HSGP instance to time_varying_intercept/media in MMM" to "fix: serialization for arbitrary HSGP instance to time_varying_intercept/media in MMM" Nov 3, 2025
@jamespooley
Contributor Author

Sounds good, @williambdean. Glad I could make a tiny contribution here, and I'd happily pitch in on future PRs related to this.

@williambdean
Contributor

This is no tiny PR! Thanks @jamespooley 🚀

Labels

bug, MMM, tests

2 participants