146 changes: 94 additions & 52 deletions gensn/distributions.py
@@ -24,6 +24,12 @@ def log_prob(self, *obs, cond=None):
x, y = obs[: self.prior.n_rvs], obs[self.prior.n_rvs :]
return self.prior(*x, cond=cond) + self.conditional(*y, cond=x)

def factorized_log_prob(self, *obs, cond=None):
x, y = obs[: self.prior.n_rvs], obs[self.prior.n_rvs :]
return self.prior.factorized_log_prob(
*x, cond=cond
) + self.conditional.factorized_log_prob(*y, cond=x)

def forward(self, *obs, cond=None):
return self.log_prob(*obs, cond=cond)

@@ -38,23 +44,6 @@ def rsample(self, sample_shape=torch.Size([]), cond=None):
return turn_to_tuple(x_samples) + turn_to_tuple(y_samples)


# class DeltaDistribution(nn.Module):
# def __init__(self, value):
# self.value = value

# def log_prob(self, obs, cond=None):
# # TODO: write to deal with more than one rvs
# return torch.log(self.prob(obs, cond=cond))

# def prob(self, obs, cond=None):
# # TODO: write to deal with more than one rvs
# return torch.where(
# torch.equal(obs, parse_attr(self.value, cond=cond)), 0, 1
# ).to(obs.device)

# def sample(self, sample_shape=torch.size([]), cond=None):


class TrainableDistribution(nn.Module, ABC):
"""
Here we are providing the proper abstract base class for the
@@ -73,19 +62,16 @@ def n_rvs(self):
...

@abstractmethod
def log_prob(self, *obs, cond=None):
...
def log_prob(self, *obs, cond=None): ...

def forward(self, *obs, cond=None):
return self.log_prob(*obs, cond=cond)

@abstractmethod
def sample(self, sample_shape=torch.Size([]), cond=None):
...
def sample(self, sample_shape=torch.Size([]), cond=None): ...

@abstractmethod
def rsample(self, sample_shape=torch.Size([]), cond=None):
...
def rsample(self, sample_shape=torch.Size([]), cond=None): ...


class TrainableDistributionAdapter(nn.Module):
@@ -110,26 +96,6 @@ def __init__(self, distribution_class, *dist_args, _parameters=None, **dist_kwar
if _parameters is not None:
self.parameter_generator = _parameters

# overwrite extra_repr to include the distribution class
# TODO: consider adding the parameters as well
def extra_repr(self):
repr = f"distribution_class={self.distribution_class!r}"

if self.param_counts > 0:
repr += ", " + ", ".join(
f"{getattr(self, f'_arg{pos}')!r}" for pos in range(self.param_counts)
)

if len(self.param_keys) > 0:
repr += ", " + ", ".join(
f"{k}={getattr(self, k)!r}" for k in self.param_keys
)

if hasattr(self, "parameter_generator"):
repr += ", " + f"_parameters={self.parameter_generator!r}"

return repr

def distribution(self, cond=None):
cond = turn_to_tuple(cond)

@@ -162,18 +128,25 @@ def sample(self, sample_shape=torch.Size([]), cond=None):
def rsample(self, sample_shape=torch.Size([]), cond=None):
return self.distribution(cond=cond).rsample(sample_shape=sample_shape)

# overwrite extra_repr to include the distribution class
# TODO: consider adding the parameters as well
def extra_repr(self):
repr = f"distribution_class={self.distribution_class!r}"

# def wrap_with_indep(distribution_class, event_dims=1):
# """
# Wrap the construction of the target distribution `distr_class` with
# D.Independent. The returned function can be used as if it is
# a constructor for an independent version of the distribution
# """
if self.param_counts > 0:
repr += ", " + ", ".join(
f"{getattr(self, f'_arg{pos}')!r}" for pos in range(self.param_counts)
)

# def indep_distr(*args, **kwargs):
# return D.Independent(distribution_class(*args, **kwargs), event_dims)
if len(self.param_keys) > 0:
repr += ", " + ", ".join(
f"{k}={getattr(self, k)!r}" for k in self.param_keys
)

# return indep_distr
if hasattr(self, "parameter_generator"):
repr += ", " + f"_parameters={self.parameter_generator!r}"

return repr


class IndependentTrainableDistributionAdapter(TrainableDistributionAdapter):
@@ -193,6 +166,12 @@ def __init__(
def distribution(self, cond=None):
return D.Independent(super().distribution(cond=cond), self.event_dims)

def factorized_distribution(self, cond=None):
return super().distribution(cond=cond)

def factorized_log_prob(self, *obs, cond=None):
return self.factorized_distribution(cond=cond).log_prob(*obs)

def extra_repr(self):
return super().extra_repr() + f", event_dims={self.event_dims}"

@@ -212,6 +191,9 @@ def forward(self, *obs, cond=None):
def log_prob(self, *obs, cond=None):
return self.trainable_distribution.log_prob(*obs, cond=cond)

def factorized_log_prob(self, *obs, cond=None):
return self.trainable_distribution.factorized_log_prob(*obs, cond=cond)

def sample(self, sample_shape=torch.Size([]), cond=None):
return self.trainable_distribution.sample(sample_shape=sample_shape, cond=cond)

@@ -421,3 +403,63 @@ def __init__(self, loc=None, scale=None, _parameters=None, event_dims=1):
**kwargs,
_parameters=_parameters,
)


class IndependentPoisson(WrappedTrainableDistribution):
"""
A trainable distribution that wraps a D.Independent(D.Poisson) distribution
"""

def __init__(self, rate=None, _parameters=None, event_dims=1):
"""
Args:
rate : torch.Tensor or nn.Parameter or None
The rate parameter of the Poisson distribution. If None, _parameters must be provided.
_parameters : callable or None
A function that takes in the conditioning variable and returns a dictionary of parameters
for the Poisson distribution. If None, rate must be provided.
event_dims : int
The number of dimensions to be considered as the event dimensions
"""
super().__init__()
if rate is None and _parameters is None:
raise ValueError("If rate is unspecified, _parameters must be provided")
kwargs = {}
if rate is not None:
kwargs["rate"] = rate
self.trainable_distribution = IndependentTrainableDistributionAdapter(
D.Poisson,
event_dims=event_dims,
**kwargs,
_parameters=_parameters,
)
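
For reference, a hypothetical usage sketch of the new IndependentPoisson wrapper. The import path and the behavior of passing rate as a plain tensor are assumed from the surrounding diff (the docstring above allows a torch.Tensor), not taken from the package documentation.

import torch
from gensn.distributions import IndependentPoisson  # assumed import path

rates = torch.ones(10) * 2.0            # one rate per event dimension
dist = IndependentPoisson(rate=rates)   # event_dims=1 by default

x = dist.sample()                       # shape (10,)
joint = dist.log_prob(x)                # scalar: event dimension summed via D.Independent
per_dim = dist.factorized_log_prob(x)   # shape (10,): one log-prob term per dimension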


# class DeltaDistribution(nn.Module):
# def __init__(self, value):
# self.value = value

# def log_prob(self, obs, cond=None):
# # TODO: write to deal with more than one rvs
# return torch.log(self.prob(obs, cond=cond))

# def prob(self, obs, cond=None):
# # TODO: write to deal with more than one rvs
# return torch.where(
# torch.equal(obs, parse_attr(self.value, cond=cond)), 0, 1
# ).to(obs.device)

# def sample(self, sample_shape=torch.size([]), cond=None):


# def wrap_with_indep(distribution_class, event_dims=1):
# """
# Wrap the construction of the target distribution `distr_class` with
# D.Independent. The returned function can be used as if it is
# a constructor for an independent version of the distribution
# """

# def indep_distr(*args, **kwargs):
# return D.Independent(distribution_class(*args, **kwargs), event_dims)

# return indep_distr
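
The factorized_log_prob methods added throughout this file return per-dimension log-densities, whereas log_prob goes through D.Independent and sums over the event dimensions. A minimal sketch of that relationship using torch.distributions directly (not the gensn adapters):

import torch
import torch.distributions as D

rate = torch.full((5,), 3.0)
base = D.Poisson(rate)            # "factorized" view: one log-prob per dimension
joint = D.Independent(base, 1)    # joint view: the event dimension is summed out

x = joint.sample()                # shape (5,)
per_dim = base.log_prob(x)        # shape (5,)
summed = joint.log_prob(x)        # shape ()

# The joint log-probability is the sum of the factorized terms.
assert torch.allclose(per_dim.sum(), summed)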
7 changes: 7 additions & 0 deletions gensn/flow.py
@@ -21,6 +21,13 @@ def log_prob(self, *obs, cond=None):
x, logL = self.transform(*obs, cond=cond)
return self.base_distribution.log_prob(*turn_to_tuple(x), cond=cond) + logL

def factorized_log_prob(self, *obs, cond=None):
x, logL = self.transform.factorized_forward(*obs, cond=cond)
return (
self.base_distribution.factorized_log_prob(*turn_to_tuple(x), cond=cond)
+ logL
)

def sample(self, sample_shape=torch.Size([]), cond=None):
samples = self.base_distribution.sample(sample_shape=sample_shape, cond=cond)
y, _ = self.transform.inverse(samples, cond=cond)
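
In the flow case, factorized_log_prob combines the base distribution's per-dimension log-density with a per-dimension log-determinant term from the transform (the factorized_forward call). A standalone sketch of the same change-of-variables arithmetic for an elementwise exp transform, written with plain torch.distributions; the gensn Transform API itself is not reproduced here.

import torch
import torch.distributions as D

base = D.Normal(torch.zeros(4), torch.ones(4))

y = torch.rand(4) + 0.5        # observations under the transform y = exp(x)
x = torch.log(y)               # map back to base space
log_det = -torch.log(y)        # per-dimension log |dx/dy| of the inverse map

factorized = base.log_prob(x) + log_det   # shape (4,): one term per dimension
joint = factorized.sum()                  # the usual scalar log-probability

# Cross-check against torch's built-in transformed distribution.
flow = D.TransformedDistribution(base, [D.ExpTransform()])
assert torch.allclose(joint, D.Independent(flow, 1).log_prob(y))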
75 changes: 70 additions & 5 deletions gensn/parameters.py
@@ -3,6 +3,27 @@


class TransformedParameter(nn.Module):
"""
A module for applying a transformation function to a torch.nn.Parameter.
This can be useful for ensuring that a parameter adheres to certain constraints
(e.g., positivity) after transformation. The forward method applies the transformation
function to the parameter.

Attributes:
parameter (nn.Parameter): The parameter to be transformed.
transform_fn (Callable): The transformation function to be applied to the parameter.
value (torch.Tensor): The transformed parameter value.

Args:
tensor (torch.Tensor): The initial tensor to be wrapped as an nn.Parameter.
transform_fn (Callable, optional): The transformation function to be applied.
If None, the identity function is used, meaning no transformation is applied.

Returns:
torch.Tensor: The transformed parameter as a tensor, with shape dependent on the
transformation function and the initial tensor shape.
"""

def __init__(self, tensor, transform_fn=None):
super().__init__()
self.parameter = nn.Parameter(tensor)
@@ -19,17 +40,43 @@ def forward(self, *args):


class Covariance(nn.Module):
def __init__(self, n_dims, rank=None, eps=1e-16):
"""
A module to represent a covariance matrix as a parameterized entity in a neural network.
This implementation ensures the covariance matrix is positive semi-definite by constructing
it as A @ A.T + epsilon * I, where A is a parameter matrix, and epsilon is a small positive
constant added for numerical stability.

Attributes:
A (nn.Parameter): The parameter matrix used to construct the covariance matrix.
eps (float): A small positive constant added to the diagonal for numerical stability.
value (torch.Tensor): The covariance matrix.

Args:
n_dims (int): The dimensionality of the square covariance matrix.
rank (int, optional): The rank of the matrix A used in constructing the covariance matrix.
If None, it defaults to n_dims, resulting in a full-rank covariance matrix.

Returns:
torch.Tensor: The covariance matrix, with shape (n_dims, n_dims).
"""

def __init__(self, n_dims, rank=None):
super().__init__()
if rank is None:
rank = n_dims
self.n_dims = n_dims
self.rank = rank
self.eps = eps
self.A = nn.Parameter(torch.randn(n_dims, rank))
self.A = nn.Parameter(
torch.randn(n_dims, rank) * 0.1
) # Initialize with small values so the initial covariance is well conditioned
# self.eps = torch.finfo(self.A.dtype).eps
self.eps = 1e-4 # Diagonal jitter that keeps the matrix strictly positive definite

def forward(self, *args):
return self.A @ self.A.T + torch.eye(self.n_dims) * self.eps
return (
self.A @ self.A.T
+ torch.eye(self.n_dims).to(device=self.A.device) * self.eps
)

@property
def value(self):
@@ -38,11 +85,29 @@ def value(self):

# TODO: generalize this so that positiveness can arise from other functions
class PositiveDiagonal(nn.Module):
"""
A module for representing a diagonal matrix with positive diagonal elements. This is achieved
by squaring the elements of a parameter vector D and adding a small positive constant epsilon
to each squared element for numerical stability.

Attributes:
D (nn.Parameter): The parameter vector whose squared elements form the diagonal of the matrix.
eps (float): A small positive constant added to each element of the squared D for numerical stability.
value (torch.Tensor): The resulting diagonal matrix with positive diagonal elements.

Args:
n_dims (int): The dimensionality of the square diagonal matrix.
eps (float, optional): A small positive constant added for numerical stability. Defaults to 1e-16.

Returns:
torch.Tensor: The diagonal matrix with positive diagonal elements, with shape (n_dims, n_dims).
"""

def __init__(self, n_dims, eps=1e-16):
super().__init__()
self.n_dims = n_dims
self.eps = eps
self.D = nn.Parameter(torch.randn(n_dims))
self.eps = torch.finfo(self.D.dtype).eps

def forward(self, *args):
return torch.diag(self.D**2 + self.eps)
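
Both parameterizations in this file guarantee positive (semi-)definiteness by construction. A minimal standalone sketch of the two constructions, independent of the nn.Module wrappers:

import torch

n_dims, rank = 4, 2
A = torch.randn(n_dims, rank) * 0.1
eps = 1e-4

# Covariance-style construction: A @ A.T is positive semi-definite, and the
# eps * I jitter makes the result strictly positive definite.
cov = A @ A.T + eps * torch.eye(n_dims)
torch.linalg.cholesky(cov)  # succeeds only for positive-definite matrices

# PositiveDiagonal-style construction: squaring the raw vector and adding
# machine epsilon keeps every diagonal entry strictly positive.
d_raw = torch.randn(n_dims)
diag = torch.diag(d_raw**2 + torch.finfo(d_raw.dtype).eps)
assert (torch.diagonal(diag) > 0).all()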