
Commit

Update utils.py
iProzd committed Feb 10, 2025
1 parent 12aa429 commit 14c5459
Showing 1 changed file with 164 additions and 2 deletions.
166 changes: 164 additions & 2 deletions deepmd/pt/utils/utils.py
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import os
from typing import (
    Optional,
    Union,
@@ -18,7 +19,115 @@
from .env import PRECISION_DICT as PT_PRECISION_DICT


class CustomSilu(torch.nn.Module):
class CustomSiluJit(torch.nn.Module):
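    """SiLU activation that switches to a shifted ``tanh`` above ``threshold``.

    Forward, first-order, and second-order gradients are evaluated with
    ``torch.cuda.jiterator`` kernels, so inputs are expected to be CUDA tensors.
    """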
    def __init__(self, threshold=3.0):
        super().__init__()
        self.threshold = threshold

        # Precompute parameters for the tanh replacement: slope and offset are
        # chosen so the tanh branch matches SiLU's value and first derivative at the threshold
        sigmoid_threshold = 1 / (1 + np.exp(-threshold))
        self.slope = float(
            sigmoid_threshold + threshold * sigmoid_threshold * (1 - sigmoid_threshold)
        )
        self.const = float(threshold * sigmoid_threshold)

        # Generate and compile Jiterator kernels
        self._generate_jiterator_code()
        self._compile_jiterator_kernels()
        self._define_autograd_functions()

    def _generate_jiterator_code(self):
        # Forward kernel
        self.forward_code = f"""
        template <typename T>
        T custom_silu_forward(T x) {{
            const T threshold = {self.threshold};
            const T slope = {self.slope};
            const T const_val = {self.const};
            T sig = 1.0 / (1.0 + exp(-x));
            T silu = x * sig;
            T tanh_part = tanh(slope * (x - threshold)) + const_val;
            return (x > threshold) ? tanh_part : silu;
        }}
        """

        # First-order gradient kernel
        self.backward_code = f"""
        template <typename T>
        T custom_silu_backward(T x, T grad_output) {{
            const T threshold = {self.threshold};
            const T slope = {self.slope};
            T sig = 1.0 / (1.0 + exp(-x));
            T grad_silu = sig * (1 + x * (1 - sig));
            T tanh_term = tanh(slope * (x - threshold));
            T grad_tanh = slope * (1 - tanh_term * tanh_term);
            T grad = (x > threshold) ? grad_tanh : grad_silu;
            return grad * grad_output;
        }}
        """

        # Second-order gradient kernel (coefficient corrected in this commit)
        self.double_backward_code = f"""
        template <typename T>
        T custom_silu_double_backward(T x, T grad_grad_output) {{
            const T threshold = {self.threshold};
            const T slope = {self.slope};
            T grad_grad;
            if (x > threshold) {{
                // d^2/dx^2 [tanh(slope*(x - threshold)) + const] = -2*slope^2*tanh*(1 - tanh^2)
                T tanh_term = tanh(slope * (x - threshold));
                grad_grad = -2 * slope * slope * tanh_term * (1 - tanh_term * tanh_term);
            }} else {{
                // corrected coefficient: d^2/dx^2 [x*sigmoid(x)] = sigmoid'(x) * (2 + x*(1 - 2*sigmoid(x)))
                T sig = 1.0 / (1.0 + exp(-x));
                T sig_prime = sig * (1 - sig);
                grad_grad = sig_prime * (2 + x * (1 - 2 * sig));
            }}
            return grad_grad * grad_grad_output;
        }}
        """

    def _compile_jiterator_kernels(self):
        self.jitted_forward = torch.cuda.jiterator._create_jit_fn(self.forward_code)
        self.jitted_backward = torch.cuda.jiterator._create_jit_fn(self.backward_code)
        self.jitted_double_backward = torch.cuda.jiterator._create_jit_fn(
            self.double_backward_code
        )

    def _define_autograd_functions(self):
        class CustomSiluForward(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return self.jitted_forward(x)

            @staticmethod
            def backward(ctx, grad_output):
                (x,) = ctx.saved_tensors
                return CustomSiluBackward.apply(x, grad_output)

        class CustomSiluBackward(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, grad_output):
                ctx.save_for_backward(x, grad_output)
                return self.jitted_backward(x, grad_output)

            @staticmethod
            def backward(ctx, grad_grad_output):
                x, grad_output = ctx.saved_tensors
                # Gradients of f'(x) * grad_output with respect to x and grad_output
                return (
                    self.jitted_double_backward(x, grad_output * grad_grad_output),
                    self.jitted_backward(x, grad_grad_output),
                )

        self.CustomSiluForward = CustomSiluForward

    def forward(self, x):
        return self.CustomSiluForward.apply(x)


class CustomSiluOp(torch.nn.Module):
def __init__(self, threshold=3.0):
super().__init__()

@@ -59,6 +168,50 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return result


class CustomSilu(torch.nn.Module):
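    """Pure-PyTorch SiLU that switches to a shifted ``tanh`` above ``threshold``."""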
    def __init__(self, threshold=3.0):
        super().__init__()

        def sigmoid(x):
            return 1 / (1 + np.exp(-x))

        def silu(x):
            return x * sigmoid(x)

        def silu_grad(x):
            sig = sigmoid(x)
            return sig + x * sig * (1 - sig)

        self.threshold = threshold
        self.slope = float(silu_grad(threshold))
        self.const = float(silu(threshold))

        if not hasattr(torch.ops.deepmd, "thsilu"):

            def thsilu(
                argument0: torch.Tensor,
                argument1: float,
                argument2: float,
                argument3: float,
            ) -> list[torch.Tensor]:
                raise NotImplementedError(
                    "thsilu is not available because the customized PyTorch OP library was not built when freezing the model. "
                    "See the model compression documentation for details."
                )

            # Note: this hack cannot actually save a model that can be run using LAMMPS.
            torch.ops.deepmd.thsilu = thsilu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        silu_part = F.silu(x)
        mask = x > self.threshold
        if torch.any(mask):
            tanh_part = torch.tanh(self.slope * (x - self.threshold)) + self.const
            return torch.where(x < self.threshold, silu_part, tanh_part)
        else:
            return silu_part


class CustomDSilu(torch.nn.Module):
    def __init__(self, threshold=3.0, sig_s=0.0):
        super().__init__()
@@ -97,7 +250,16 @@ def __init__(self, activation: Optional[str]) -> None:
            threshold = (
                float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
            )
            self.custom_silu = CustomSilu(threshold=threshold)
            # Select the SiLU implementation from the environment
            SILU_OP = os.environ.get("SILU_OP", "default")
            if SILU_OP == "default":
                self.custom_silu = CustomSilu(threshold=threshold)
            elif SILU_OP == "op":
                self.custom_silu = CustomSiluOp(threshold=threshold)
            elif SILU_OP == "jit":
                self.custom_silu = CustomSiluJit(threshold=threshold)
            else:
                raise ValueError(f"Undefined SILU_OP: {SILU_OP}!")
        else:
            self.custom_silu = None

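For reference, a minimal usage sketch of the default pure-PyTorch implementation added above. This is illustrative only: it assumes deepmd-kit is installed so that deepmd.pt.utils.utils is importable; CustomSiluJit requires a CUDA device for the Jiterator kernels, and CustomSiluOp presumably requires the compiled custom-op library, so only the default variant is exercised here. The SILU_OP environment variable ("default", "op", or "jit") selects among the three implementations when ActivationFn builds the activation.

import torch
import torch.nn.functional as F

from deepmd.pt.utils.utils import CustomSilu

act = CustomSilu(threshold=3.0)
x = torch.linspace(-6.0, 6.0, steps=25)
y = act(x)

# Below the threshold the activation is exactly SiLU ...
assert torch.allclose(y[x < 3.0], F.silu(x[x < 3.0]))
# ... and above it the output saturates via tanh(slope * (x - threshold)) + const.
print(y[-1].item())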

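As a sanity check on the closed-form derivatives hard-coded in the Jiterator kernels, the formulas can be verified symbolically. The sketch below uses SymPy, which is not a dependency of the code in this commit.

import sympy as sp

x, s, t = sp.symbols("x s t", real=True)
sig = 1 / (1 + sp.exp(-x))
silu = x * sig

# First derivative used in custom_silu_backward: sig * (1 + x * (1 - sig))
assert sp.diff(silu, x).equals(sig * (1 + x * (1 - sig)))

# Second derivative used in custom_silu_double_backward: sig' * (2 + x * (1 - 2*sig))
sig_prime = sig * (1 - sig)
assert sp.diff(silu, x, 2).equals(sig_prime * (2 + x * (1 - 2 * sig)))

# Tanh branch: d^2/dx^2 [tanh(s*(x - t)) + c] = -2*s^2*tanh(.)*(1 - tanh(.)^2)
th = sp.tanh(s * (x - t))
assert sp.diff(th, x, 2).equals(-2 * s**2 * th * (1 - th**2))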