Source code for towhee.models.layers.time2vec

# Original paper:
#
#     "Time2Vec: Learning a Vector Representation of Time" https://arxiv.org/abs/1907.05321
#
# Implemented with PyTorch by Zilliz.
#
# Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch import nn


class Time2Vec(nn.Module):
    """
    Time2Vec implementation in PyTorch.

    Args:
        seq_len (`int`):
            The length of the input sequence.
        activation (`str`):
            Activation function used for the periodic time embedding
            (only "sin" and "cos" are supported).

    Return:
        Embedding with the same shape as the input.

    Example:
        >>> from towhee.models.layers.time2vec import Time2Vec
        >>> import torch
        >>>
        >>> x = torch.randn(3, 64)
        >>> model = Time2Vec(seq_len=64, activation="sin")
        >>> model(x).shape
        torch.Size([3, 64])
    """
    def __init__(self, seq_len: int, activation: str, **kwargs):
        super().__init__(**kwargs)
        # Learnable weights and biases for the linear (w0, b0)
        # and periodic (w, b) time embeddings
        self.w0 = nn.Parameter(torch.randn(seq_len, 1))
        self.b0 = nn.Parameter(torch.randn(1))
        self.w = nn.Parameter(torch.randn(seq_len, 1))
        self.b = nn.Parameter(torch.randn(1))
        if activation == 'sin':
            self.f = torch.sin
        elif activation == 'cos':
            self.f = torch.cos
        else:
            raise ValueError(f'Activation {activation} is not supported yet.')
        # Projects the concatenated (periodic, linear) pair back to seq_len
        self.fc = nn.Linear(2, seq_len)
    def forward(self, x):
        periodic = self.f(torch.matmul(x, self.w) + self.b)  # Periodic time embedding
        linear = torch.matmul(x, self.w0) + self.b0  # Linear time embedding
        out = torch.cat([periodic, linear], -1)  # Concat time embeddings
        out = self.fc(out)  # Convert embedding dimension back to sequence length
        return out
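
# A minimal usage sketch (not part of the original module), mirroring the
# docstring example above: both supported activations yield an embedding
# with the same shape as the input.
if __name__ == '__main__':
    x = torch.randn(3, 64)  # (batch_size, seq_len)
    for act in ('sin', 'cos'):
        model = Time2Vec(seq_len=64, activation=act)
        out = model(x)
        assert out.shape == x.shape  # torch.Size([3, 64])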