Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev(hansbug): add register support for treevalue #78

Merged
merged 2 commits into from
Feb 26, 2023
Merged

dev(hansbug): add register support for treevalue #78

merged 2 commits into from
Feb 26, 2023

Conversation

HansBug
Copy link
Member

@HansBug HansBug commented Feb 26, 2023

Description

import jax
import numpy as np

from treevalue import FastTreeValue


@jax.jit
def double(x):
    return x * 2


data = {
    'a': np.random.randint(0, 10, (2, 3)),
    'b': {
        'x': 233,
        'y': np.random.randn(2, 3)
    }
}
t = FastTreeValue(data)

if __name__ == '__main__':
    print(t)
    print(double(t))

Check List

  • merge the latest version source branch/repo, and resolve all the conflicts
  • pass style check
  • pass all the tests

@HansBug HansBug requested a review from PaParaZz1 February 26, 2023 13:40
@HansBug HansBug self-assigned this Feb 26, 2023
@codecov
Copy link

codecov bot commented Feb 26, 2023

Codecov Report

Merging #78 (ee3184d) into main (ef52316) will decrease coverage by 0.07%.
The diff coverage is 94.73%.

@@            Coverage Diff             @@
##             main      #78      +/-   ##
==========================================
- Coverage   99.01%   98.95%   -0.07%     
==========================================
  Files          36       39       +3     
  Lines        2544     2578      +34     
==========================================
+ Hits         2519     2551      +32     
- Misses         25       27       +2     
Flag Coverage Δ
unittests 98.95% <94.73%> (-0.07%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
treevalue/tree/integration/jax.py 92.85% <92.85%> (ø)
treevalue/tree/integration/cjax.pyx 94.44% <94.44%> (ø)
treevalue/tree/__init__.py 100.00% <100.00%> (ø)
treevalue/tree/integration/__init__.py 100.00% <100.00%> (ø)
treevalue/tree/tree/flatten.pyx 100.00% <100.00%> (ø)

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@HansBug HansBug merged commit b61551a into main Feb 26, 2023
@HansBug HansBug deleted the dev/jax branch February 26, 2023 15:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant