# Copyright 2021 Ross Wightman . All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is modified by Zilliz.
from typing import Union, Callable, Type
import torch
from torch import nn
from torch.nn import functional as F
from towhee.models.layers.activations import (GELU,
HardSigmoid,
HardSwish,
HardMish,
Mish,
PReLU,
Sigmoid,
Swish,
Tanh)
from towhee.models.layers.activations import (gelu,
hard_sigmoid,
hard_swish,
hard_mish,
mish,
sigmoid,
swish,
tanh)
# Feature-detect native activation ops in torch.nn.functional. PyTorch provides an
# optimized native 'silu' (a.k.a. 'swish') as of PyTorch 1.7; newer releases also
# provide 'hardsigmoid', 'hardswish' and 'mish'. When a native op is present, the
# tables below select it instead of the custom towhee implementation. The intent is
# to eventually drop the custom SiLU, Mish and Hard* layers and rely only on the
# native variants.
_has_silu = hasattr(F, 'silu')
_has_hardswish = hasattr(F, 'hardswish')
_has_hardsigmoid = hasattr(F, 'hardsigmoid')
_has_mish = hasattr(F, 'mish')
# Name -> activation *function* table. Where torch.nn.functional offers a native op
# (according to the _has_* feature flags), it is used; otherwise the custom towhee
# implementation imported at the top of the file is the fallback.
_ACT_FN_DEFAULT = {
    'silu': F.silu if _has_silu else swish,
    'swish': F.silu if _has_silu else swish,  # 'swish' is an alias of 'silu'
    'mish': F.mish if _has_mish else mish,
    'relu': F.relu,
    'relu6': F.relu6,
    'leaky_relu': F.leaky_relu,
    'elu': F.elu,
    'celu': F.celu,
    'selu': F.selu,
    'gelu': gelu,
    'sigmoid': sigmoid,
    'tanh': tanh,
    'hard_sigmoid': F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
    'hard_swish': F.hardswish if _has_hardswish else hard_swish,
    'hard_mish': hard_mish,
}
# Collection of all function tables; the alias-registration loop iterates over it.
_ACT_FNS = (_ACT_FN_DEFAULT,)
# Name -> activation *layer class* table. Native torch.nn modules are chosen when
# the corresponding _has_* feature flag is set; otherwise the custom towhee layer
# imported at the top of the file is the fallback.
_ACT_LAYER_DEFAULT = {
    'silu': nn.SiLU if _has_silu else Swish,
    'swish': nn.SiLU if _has_silu else Swish,  # 'swish' is an alias of 'silu'
    'mish': nn.Mish if _has_mish else Mish,
    'relu': nn.ReLU,
    'relu6': nn.ReLU6,
    'leaky_relu': nn.LeakyReLU,
    'elu': nn.ELU,
    'prelu': PReLU,
    'celu': nn.CELU,
    'selu': nn.SELU,
    'gelu': GELU,
    'sigmoid': Sigmoid,
    'tanh': Tanh,
    'hard_sigmoid': nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
    'hard_swish': nn.Hardswish if _has_hardswish else HardSwish,
    'hard_mish': HardMish,
}
# Collection of all layer tables.
_ACT_LAYERS = (_ACT_LAYER_DEFAULT,)
# Register the underscore-free spellings ('hardsigmoid', 'hardswish'), which match
# the native torch names, as aliases of 'hard_sigmoid' / 'hard_swish'. This is done
# in BOTH the function tables and the layer tables, so get_act_fn('hardswish') and
# get_act_layer('hardswish') resolve alike. Previously only the function tables got
# the aliases, so the layer lookup raised KeyError. setdefault never overwrites an
# entry that is already present under the alias name.
for a in _ACT_FNS + _ACT_LAYERS:
    a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
    a.setdefault('hardswish', a.get('hard_swish'))
def get_act_fn(name: Union[Callable, str, None] = 'relu') -> Union[Callable, None]:
    """
    Activation Function Factory.

    Resolve ``name`` to an activation function.

    Args:
        name: a falsy value (``None``, ``''``), a callable, or the string key of an
            entry in the module-level ``_ACT_FN_DEFAULT`` table (e.g. ``'relu'``,
            ``'gelu'``, ``'silu'``).

    Returns:
        ``None`` if ``name`` is falsy, ``name`` itself if it is callable, otherwise
        the function registered under ``name``.

    Raises:
        KeyError: if ``name`` is a string with no registered activation function.
    """
    if not name:
        return None
    if callable(name):
        # Already a function: pass it through untouched.
        return name
    try:
        return _ACT_FN_DEFAULT[name]
    except KeyError:
        # Re-raise with the list of valid names instead of a bare key echo.
        raise KeyError(
            f'Unknown activation function {name!r}; '
            f'expected one of {sorted(_ACT_FN_DEFAULT)}'
        ) from None
def get_act_layer(name: Union[Type[nn.Module], Callable[..., nn.Module], str, None] = 'relu'):
    """
    Activation Layer Factory.

    Resolve ``name`` to an activation layer class (or a callable that builds one).

    Args:
        name: a falsy value (``None``, ``''``), an ``nn.Module`` subclass or other
            callable that constructs a layer (e.g.
            ``functools.partial(nn.ReLU, inplace=True)``), or the string key of an
            entry in the module-level ``_ACT_LAYER_DEFAULT`` table (e.g. ``'relu'``).

    Returns:
        ``None`` if ``name`` is falsy, ``name`` itself if it is callable, otherwise
        the layer class registered under ``name``.

    Raises:
        KeyError: if ``name`` is a string with no registered activation layer.
    """
    if not name:
        return None
    # Classes are callable, so this covers nn.Module subclasses as before. It also
    # accepts layer factories such as functools.partial, mirroring get_act_fn.
    if callable(name):
        return name
    try:
        return _ACT_LAYER_DEFAULT[name]
    except KeyError:
        # Re-raise with the list of valid names instead of a bare key echo.
        raise KeyError(
            f'Unknown activation layer {name!r}; '
            f'expected one of {sorted(_ACT_LAYER_DEFAULT)}'
        ) from None
def create_act_layer(name: Union[Type[nn.Module], str, None], inplace=None, **kwargs):
    """
    Instantiate an activation layer.

    Args:
        name: anything accepted by ``get_act_layer``: a falsy value, a layer class
            or factory, or a registered string key such as ``'relu'``.
        inplace: forwarded to the layer constructor as ``inplace=...`` only when it is
            not ``None``; when ``None`` the constructor's own default is used.
        **kwargs: additional constructor arguments passed through unchanged.

    Returns:
        The constructed layer, or ``None`` if ``name`` resolves to ``None``.

    Raises:
        KeyError: propagated from ``get_act_layer`` for an unknown string name.
    """
    act_layer = get_act_layer(name)
    if act_layer is None:
        return None
    if inplace is None:
        # Do not force an ``inplace`` argument; let the layer use its default.
        return act_layer(**kwargs)
    return act_layer(inplace=inplace, **kwargs)