-
Notifications
You must be signed in to change notification settings - Fork 6
Open
Description
Hello, this is very smart work. Could you please share the implementation in TD-MPC2?
This is the original MLP.
class NormedLinear(nn.Linear):
    """Linear layer followed by optional dropout, LayerNorm, and an activation.

    Forward pass computes ``act(ln(dropout(x @ W.T + b)))``; dropout is
    skipped entirely when the dropout probability is 0.
    """

    def __init__(self, *args, dropout=0., act=None, **kwargs):
        """
        Args:
            *args, **kwargs: forwarded to ``nn.Linear`` (in/out features, bias, ...).
            dropout: dropout probability applied right after the linear map;
                0 disables the dropout module altogether.
            act: activation module; defaults to ``nn.Mish`` when None.
        """
        super().__init__(*args, **kwargs)
        self.ln = nn.LayerNorm(self.out_features)
        self.act = nn.Mish(inplace=False) if act is None else act
        self.dropout = None
        if dropout:
            self.dropout = nn.Dropout(dropout, inplace=False)

    def forward(self, x):
        """Apply linear -> (optional dropout) -> LayerNorm -> activation."""
        out = super().forward(x)
        if self.dropout is not None:
            out = self.dropout(out)
        return self.act(self.ln(out))
This is my implementation, but I'm not sure whether it's correct.
class NormedResidualLinear(nn.Linear):
    """Linear layer with LayerNorm, activation, and a residual connection.

    Forward pass computes ``res + act(ln(dropout(linear(x))))`` where ``res``
    is the linear output captured BEFORE dropout. NOTE(review): taking the
    residual pre-dropout mirrors the original snippet, but confirm that is
    the intended placement (post-dropout residual is also common).
    """

    def __init__(self, *args, dropout=0., act=None, **kwargs):
        """
        Args:
            *args, **kwargs: forwarded to ``nn.Linear`` (in/out features, bias, ...).
            dropout: dropout probability applied right after the linear map;
                0 disables the dropout module altogether.
            act: activation module; defaults to ``nn.ReLU`` when None.

        Fix: the original accepted ``act`` but silently ignored it and always
        used ReLU. A caller-supplied activation is now honored; the default
        stays ReLU, so existing ``act=None`` callers behave identically.
        """
        super().__init__(*args, **kwargs)
        self.ln = nn.LayerNorm(self.out_features)
        if act is None:
            act = nn.ReLU(inplace=False)
        self.act = act
        self.dropout = nn.Dropout(dropout, inplace=False) if dropout else None

    def forward(self, x):
        """Apply linear, then add act(ln(optionally-dropped-out result)) as residual."""
        x = super().forward(x)
        res = x  # residual branch: pre-dropout linear output
        if self.dropout:
            x = self.dropout(x)
        return res + self.act(self.ln(x))
Metadata
Metadata
Assignees
Labels
No labels