Add ConstScaleLayer, MLP for flows, ClampExp (deepchem#3964)
* add flows file

* add flows to init

* add test file for flows

* fix doctest of realnvp, fixed deprecation notice for torch.nn.init.normal

* add masked affine flow class

* add tests for masked affine flow

* add maskedaffineflow to init

* replace | with union

* added docs to flows

* add flows to layer docs

* remove Affine from layers

* add Flows to init

* fix yapf

* add to docs

* fixed gan docs rendering

* add actnorm layer and more docs for Flow

* add tests for actnorm layer

* fix init yapf

* Add ClampExp and ConstScale Layer

* add MLP Flow
shreyasvinaya authored May 1, 2024
1 parent 867eece commit 9f70cfa
Showing 4 changed files with 251 additions and 2 deletions.
2 changes: 1 addition & 1 deletion deepchem/models/torch_models/__init__.py
@@ -30,7 +30,7 @@
from deepchem.models.torch_models.acnn import AtomConvModel
from deepchem.models.torch_models.progressive_multitask import ProgressiveMultitask, ProgressiveMultitaskModel
from deepchem.models.torch_models.text_cnn import TextCNNModel
-from deepchem.models.torch_models.flows import Flow, Affine, MaskedAffineFlow, ActNorm
+from deepchem.models.torch_models.flows import Flow, Affine, MaskedAffineFlow, ActNorm, ClampExp, ConstScaleLayer, MLP_flow
from deepchem.models.torch_models.unet import UNet, UNetModel
try:
    from deepchem.models.torch_models.dmpnn import DMPNN, DMPNNModel
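With this change the new layers are re-exported from the package namespace. A quick sanity check (an editor's sketch, not part of the commit) could be:

# Editor's sketch, not part of the commit: verify the re-exported symbols resolve.
from deepchem.models.torch_models import ClampExp, ConstScaleLayer, MLP_flow

print(ClampExp, ConstScaleLayer, MLP_flow)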
186 changes: 186 additions & 0 deletions deepchem/models/torch_models/flows.py
@@ -440,3 +440,189 @@ def inverse(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            self.shift.data = z.mean(dim=self.batch_dims, keepdim=True).data
            self.data_dep_init_done = torch.tensor(1.0)
        return super().inverse(z)


class ClampExp(nn.Module):
    r"""
    A nonlinearity layer that clamps the input tensor by taking the elementwise
    minimum of the exponential of the input scaled by a lambda parameter and 1.

    .. math:: f(x) = \min(\exp(\lambda x), 1)

    Example
    -------
    >>> import torch
    >>> from deepchem.models.torch_models.flows import ClampExp
    >>> lambda_param = 1.0
    >>> clamp_exp = ClampExp(lambda_param)
    >>> input = torch.tensor([-1, 0.5, 0.6, 0.7])
    >>> clamp_exp(input)
    tensor([0.3679, 1.0000, 1.0000, 1.0000])
    """

    def __init__(self, lambda_param: float = 1.0) -> None:
        """
        Initializes the ClampExp layer

        Parameters
        ----------
        lambda_param : float
            Lambda parameter for the ClampExp layer
        """
        super(ClampExp, self).__init__()
        self.lambda_param = lambda_param

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ClampExp layer

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Transformed tensor with the same shape as `x`.
        """
        one = torch.tensor(1.0, device=x.device, dtype=x.dtype)
        return torch.min(torch.exp(self.lambda_param * x), one)


class ConstScaleLayer(nn.Module):
    """
    This layer scales the input tensor by a fixed factor.

    Example
    -------
    >>> import torch
    >>> from deepchem.models.torch_models.flows import ConstScaleLayer
    >>> scale = 2.0
    >>> const_scale = ConstScaleLayer(scale)
    >>> input = torch.tensor([1, 2, 3])
    >>> const_scale(input)
    tensor([2., 4., 6.])
    """

    def __init__(self, scale: float = 1.0):
        """
        Initializes the ConstScaleLayer

        Parameters
        ----------
        scale : float
            Scaling factor
        """
        super().__init__()
        self.scale = torch.tensor(scale)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward pass of the ConstScaleLayer

        Parameters
        ----------
        input : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Scaled tensor
        """
        return input * self.scale


class MLP_flow(nn.Module):
    """
    A multi-layer perceptron (MLP) used as a building block of normalizing
    flow models. It is a modified version of the MLP model in
    `deepchem/deepchem/models/torch_models/layers.py` that handles an
    arbitrary number of layers.

    Example
    -------
    >>> import torch
    >>> from deepchem.models.torch_models.flows import MLP_flow
    >>> layers = [2, 4, 4, 2]
    >>> mlp_flow = MLP_flow(layers)
    >>> input = torch.tensor([1., 2.])
    >>> output = mlp_flow(input)
    >>> output.shape
    torch.Size([2])
    """

    def __init__(
        self,
        layers: list,
        leaky: float = 0.0,
        score_scale: Optional[float] = None,
        output_fn=None,
        output_scale: Optional[float] = None,
        init_zeros: bool = False,
        dropout: Optional[float] = None,
    ):
        """
        Initializes the MLP_flow model

        Parameters
        ----------
        layers : list
            List of layer sizes from start to end
        leaky : float, optional (default 0.0)
            Slope of the leaky part of the ReLU; if 0.0, standard ReLU is used
        score_scale : float, optional
            Factor to apply to the scores, i.e. the output before output_fn
        output_fn : str, optional
            Function applied to the output: None, "sigmoid", "relu", "tanh", or "clampexp"
        output_scale : float, optional
            Rescale outputs if output_fn is specified, i.e. scale * output_fn(out / scale)
        init_zeros : bool, optional
            If True, the weights and biases of the last layer are initialized with zeros
            (helpful for deep models, see arXiv:1807.03039)
        dropout : float, optional
            If specified, dropout is applied before the last layer; if None, no dropout is used
        """
        super().__init__()
        net = nn.ModuleList([])
        for k in range(len(layers) - 2):
            net.append(nn.Linear(layers[k], layers[k + 1]))
            net.append(nn.LeakyReLU(leaky))
        if dropout is not None:
            net.append(nn.Dropout(p=dropout))
        net.append(nn.Linear(layers[-2], layers[-1]))
        if init_zeros:
            nn.init.zeros_(net[-1].weight)
            nn.init.zeros_(net[-1].bias)
        if output_fn is not None:
            if score_scale is not None:
                net.append(ConstScaleLayer(score_scale))
            if output_fn == "sigmoid":
                net.append(nn.Sigmoid())
            elif output_fn == "relu":
                net.append(nn.ReLU())
            elif output_fn == "tanh":
                net.append(nn.Tanh())
            elif output_fn == "clampexp":
                net.append(ClampExp())
            else:
                raise NotImplementedError(
                    "This output function is not implemented.")
            if output_scale is not None:
                net.append(ConstScaleLayer(output_scale))
        self.net = nn.Sequential(*net)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP_flow model

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Transformed tensor whose last dimension has size `layers[-1]`
        """
        return self.net(x)
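For context, here is a minimal usage sketch of the three new layers (an editor's illustration, not part of the commit's diff); the layer sizes and parameter values are arbitrary:

import torch
from deepchem.models.torch_models.flows import ClampExp, ConstScaleLayer, MLP_flow

# An MLP mapping 2-d inputs to 2-d outputs with LeakyReLU hidden activations
# and a sigmoid output rescaled by output_scale.
net = MLP_flow([2, 8, 8, 2], leaky=0.01, output_fn="sigmoid", output_scale=2.0)
x = torch.randn(5, 2)
y = net(x)  # shape (5, 2); values lie in (0, 2) after the final ConstScaleLayer

clamp = ClampExp(lambda_param=2.0)
z = clamp(x)  # elementwise min(exp(2 * x), 1), same shape as x

half = ConstScaleLayer(scale=0.5)
print(y.shape, z.shape, half(y).shape)

In a flow model, MLP_flow typically serves as the small conditioner network that produces the per-dimension scale and shift parameters, with ClampExp or a scaled sigmoid keeping those scales bounded.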
56 changes: 55 additions & 1 deletion deepchem/models/torch_models/tests/test_flows.py
@@ -11,7 +11,7 @@
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.distributions import MultivariateNormal
-    from deepchem.models.torch_models.flows import Affine, MaskedAffineFlow, ActNorm
+    from deepchem.models.torch_models.flows import Affine, MaskedAffineFlow, ActNorm, ClampExp, ConstScaleLayer, MLP_flow
    has_torch = True
except:
    has_torch = False
@@ -101,3 +101,57 @@ def test_actnorm():

    assert np.array_equal(log_det_jacobian, value)
    assert np.any(inverse_log_det_jacobian)


@unittest.skipIf(not has_torch, 'torch is not installed')
@pytest.mark.torch
def test_clampexp():
    """
    This test evaluates the clampexp function.
    """
    lambda_param_list = [0.1, 0.5, 1, 2, 5, 10]

    tensor = torch.tensor([-1, 0.5, 0.6, 0.7])
    outputs = {
        0.1: [0.9048, 1.0000, 1.0000, 1.0000],
        0.5: [0.6065, 1.0000, 1.0000, 1.0000],
        1: [0.3679, 1.0000, 1.0000, 1.0000],
        2: [0.1353, 1.0000, 1.0000, 1.0000],
        5: [0.0067, 1.0000, 1.0000, 1.0000],
        10: [0., 1.0000, 1.0000, 1.0000]
    }
    for lambda_param in lambda_param_list:
        clamp_exp = ClampExp(lambda_param)
        tensor_out = clamp_exp(tensor)
        assert torch.allclose(tensor_out,
                              torch.Tensor(outputs[lambda_param]),
                              atol=1e-4)


@unittest.skipIf(not has_torch, 'torch is not installed')
@pytest.mark.torch
def test_constscalelayer():
    """
    This test evaluates the ConstScaleLayer.
    """
    scale = 2
    const_scale_layer = ConstScaleLayer(scale)
    tensor = torch.tensor([1, 2, 3, 4])
    tensor_out = const_scale_layer(tensor)
    assert torch.allclose(tensor_out, tensor * scale)


@unittest.skipIf(not has_torch, 'torch is not installed')
@pytest.mark.torch
def test_mlp_flow():
    """
    This test evaluates the MLP_flow.
    """
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    layers = [2, 4, 4, 2]
    mlp_flow = MLP_flow(layers)
    input_tensor = torch.randn(1, 2)
    output_tensor = mlp_flow(input_tensor)
    assert output_tensor.shape == torch.Size([1, 2])
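One further check that could be added here (an editor's sketch, not included in the commit): with init_zeros=True and no output function, the final linear layer starts with zero weights and biases, so the model's initial output is exactly zero.

@unittest.skipIf(not has_torch, 'torch is not installed')
@pytest.mark.torch
def test_mlp_flow_init_zeros():
    """Sketch: MLP_flow with init_zeros=True should output zeros at initialization."""
    mlp_flow = MLP_flow([2, 4, 4, 2], init_zeros=True)
    output_tensor = mlp_flow(torch.randn(3, 2))
    assert torch.allclose(output_tensor, torch.zeros(3, 2))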
9 changes: 9 additions & 0 deletions docs/source/api_reference/layers.rst
@@ -276,6 +276,12 @@ Torch Layers
.. autoclass:: deepchem.models.torch_models.layers.HighwayLayer
  :members:

.. autoclass:: deepchem.models.torch_models.flows.ClampExp
  :members:

.. autoclass:: deepchem.models.torch_models.flows.ConstScaleLayer
  :members:

Flow Layers
^^^^^^^^^^^

@@ -291,6 +297,9 @@
.. autoclass:: deepchem.models.torch_models.flows.ActNorm
  :members:

.. autoclass:: deepchem.models.torch_models.flows.MLP_flow
  :members:

Grover Layers
^^^^^^^^^^^^^

