How PyTorch nn.Module Saves Submodules
I have a question about how PyTorch's nn.Module works:

    import torch
    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
Solution 1:
I will try to keep it simple.
Every time you assign a new attribute on a Net instance, for instance self.sub_module = nn.Linear(10, 5), Python calls the __setattr__ method of its parent class, in this case nn.Module. Inside __setattr__, the value is stored in the dict it belongs to. In this case, since nn.Linear is a module, it is stored in the _modules dict rather than as an ordinary instance attribute.
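A minimal sketch to check this yourself, assuming a recent PyTorch install. `sub_module` follows the answer; `hidden` is a hypothetical attribute added only for contrast, showing that a plain Python list is not a Module and therefore is not registered. `_modules` is an internal attribute, so the public `named_children()` is shown as well.

```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Module.__setattr__ sees an nn.Module value and files it in _modules
        self.sub_module = nn.Linear(10, 5)
        # Hypothetical contrast: a list is not a Module, so it becomes a plain attribute
        self.hidden = [nn.Linear(5, 5)]

net = Net()

print("sub_module" in net._modules)   # True
print("sub_module" in net.__dict__)   # False: net.sub_module is resolved by nn.Module.__getattr__
print("hidden" in net._modules)       # False
print([name for name, _ in net.named_children()])  # ['sub_module']
print(list(net.state_dict().keys()))  # ['sub_module.weight', 'sub_module.bias']
```

Because only registered submodules are visited, the Linear inside `hidden` is missing from `state_dict()` and `parameters()`; this is why containers such as nn.ModuleList exist.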
Here is the code inside the Module class that does this: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L389
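To make the mechanism concrete without reading the full source, here is a heavily simplified, pure-Python imitation of the routing idea. `TinyModule`, `TinyParameter`, and `TinyLinear` are hypothetical names; this is not PyTorch's actual implementation, and it omits buffers, hooks, and name-clash handling. It shows the two halves: `__setattr__` files values into per-kind dicts, and `__getattr__` (only called when normal lookup fails) finds them again.

```python
from collections import OrderedDict

class TinyParameter:
    """Hypothetical stand-in for nn.Parameter."""
    def __init__(self, value):
        self.value = value

class TinyModule:
    """Simplified, hypothetical imitation of nn.Module's attribute routing."""
    def __init__(self):
        # Use object.__setattr__ so creating the registries skips our own routing
        object.__setattr__(self, "_parameters", OrderedDict())
        object.__setattr__(self, "_modules", OrderedDict())

    def __setattr__(self, name, value):
        # Route by type: parameters and submodules go into their own dicts
        if isinstance(value, TinyParameter):
            self._parameters[name] = value
        elif isinstance(value, TinyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when the instance __dict__ and the class lack `name`
        for registry in ("_parameters", "_modules"):
            store = self.__dict__.get(registry, {})
            if name in store:
                return store[name]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

    def named_parameters(self, prefix=""):
        # Own parameters first, then recurse into registered submodules
        for name, p in self._parameters.items():
            yield prefix + name, p
        for name, m in self._modules.items():
            yield from m.named_parameters(prefix + name + ".")

class TinyLinear(TinyModule):
    def __init__(self):
        super().__init__()
        self.weight = TinyParameter([[0.0]])
        self.bias = TinyParameter([0.0])

class Net(TinyModule):
    def __init__(self):
        super().__init__()
        self.sub_module = TinyLinear()  # routed into _modules
        self.name = "net"               # plain value, stored in __dict__

net = Net()
print(list(net._modules))                            # ['sub_module']
print([n for n, _ in net.named_parameters()])        # ['sub_module.weight', 'sub_module.bias']
print(net.sub_module is net._modules["sub_module"])  # True
```

The recursive walk over `_modules` is what lets a parent collect its children's parameters with dotted names, mirroring the `state_dict()` keys seen with the real nn.Module.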