Skip to content

Conversation

@HansBug
Copy link
Member

@HansBug HansBug commented Oct 13, 2023

Description

import torch.nn

from treevalue import FastTreeValue


class MyModule(torch.nn.Module):
    def __init__(self, p):
        torch.nn.Module.__init__(self)
        self.relu = torch.nn.ReLU()
        self.p = torch.tensor(p)

    def forward(self, x):
        return self.relu(x + self.p)


class FullModule(torch.nn.Module):
    def __init__(self, **kwargs):
        torch.nn.Module.__init__(self)
        self._module_dict = torch.nn.ModuleDict({
            key: MyModule(value)
            for key, value in kwargs.items()
        })
        self._module_tv = FastTreeValue(self._module_dict)

    def forward(self, x):
        return self._module_tv(x)


model = FullModule(a=1, b=2)
print(model)

input_ = FastTreeValue({
    'a': torch.randn(3, 4),
    'b': torch.randn(2, 3),
})
print(model(input_))

TODO

  • Try to reduce the lines of ModuleDict&TreeValue usage

Check List

  • merge the latest version source branch/repo, and resolve all the conflicts
  • pass style check
  • pass all the tests

@HansBug HansBug added the enhancement New feature or request label Oct 13, 2023
@HansBug HansBug self-assigned this Oct 13, 2023
@codecov
Copy link

codecov bot commented Oct 13, 2023

Codecov Report

Merging #88 (ab350d6) into main (bd171e4) will decrease coverage by 0.09%.
The diff coverage is 91.89%.

@@            Coverage Diff             @@
##             main      #88      +/-   ##
==========================================
- Coverage   98.88%   98.80%   -0.09%     
==========================================
  Files          43       43              
  Lines        2792     2837      +45     
==========================================
+ Hits         2761     2803      +42     
- Misses         31       34       +3     
Flag Coverage Δ
unittests 98.80% <91.89%> (-0.09%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
treevalue/tree/tree/__init__.py 100.00% <100.00%> (ø)
treevalue/tree/tree/tree.pyx 97.97% <96.66%> (-0.15%) ⬇️
treevalue/tree/integration/torch.py 90.47% <66.66%> (+23.80%) ⬆️

... and 6 files with indirect coverage changes

@HansBug HansBug requested a review from PaParaZz1 October 20, 2023 08:12
@HansBug HansBug merged commit 83bca17 into main Oct 22, 2023
@HansBug HansBug deleted the dev/dict branch October 22, 2023 14:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants