# Implementation of "Neural Audio Fingerprint for High-specific Audio Retrieval based on Contrasive Learning."
# https://arxiv.org/abs/2010.11910
#
# Inspired by https://github.com/stdio2016/pfann
#
# Additions & Modifications by Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from towhee.models.layers.conv2d_separable import SeparableConv2d
[docs]class FrontConv(nn.Module):
"""
Front Convolutional Layers
Args:
dim (`int`):
Dimension of features
h (`int`):
Height of input.
in_f (`int`):
Padding parameter.
in_t (`int`):
Padding parameter.
fuller (`bool=False`):
Whether to use group in conv2 layer.
activation (`str`):
Activation layer.
strides (`list of tuple`):
A list of stride tuples.
relu_after_bn (`bool`):
Whether to add ReLU after batch norm.
"""
[docs] def __init__(self, dim, h, in_f, in_t,
fuller=False, activation='relu',
strides=None, relu_after_bn=True):
super().__init__()
channels = [1, dim, dim, 2 * dim, 2 * dim, 4 * dim, 4 * dim, h, h]
convs = []
if activation == 'relu':
activation = nn.ReLU()
elif activation == 'elu':
activation = nn.ELU()
else:
raise ValueError(f'Invalid activation "{activation}". Accept "relu" or "elu" only.')
for i in range(8):
kernel_size = 3
stride = (2, 2)
if strides is not None:
stride = strides[i][0][1], strides[i][1][0]
sep_conv = SeparableConv2d(
in_c=channels[i],
out_c=channels[i + 1],
kernel_size=kernel_size,
stride=stride,
in_f=in_f, in_t=in_t,
fuller=fuller, activation=activation, relu_after_bn=relu_after_bn
)
convs.append(sep_conv)
in_f = (in_f - 1) // stride[1] + 1
in_t = (in_t - 1) // stride[0] + 1
assert in_f == in_t == 1, 'output must be 1x1'
self.convs = nn.ModuleList(convs)
def hack(self):
for conv in self.convs:
conv.hack()
[docs] def forward(self, x):
x = x.unsqueeze(1)
for _, conv in enumerate(self.convs):
x = conv(x)
return x
[docs]class DivEncoder(nn.Module):
"""
Divider & encoder
dim (`int`):
Dimension of features
h (`int`):
Height of input.
u (`int`):
Parameter to multiple linear dimension.
"""
[docs] def __init__(self, dim, h, u):
super().__init__()
assert h % dim == 0, f'h ({h}) must be divisible by d ({dim})'
v = h // dim
self.d = dim
self.h = h
self.u = u
self.v = v
self.linear1 = nn.Conv1d(dim * v, dim * u, kernel_size=(1,), groups=dim)
self.elu = nn.ELU()
self.linear2 = nn.Conv1d(dim * u, dim, kernel_size=(1,), groups=dim)
[docs] def forward(self, x, norm=True):
x = x.reshape([-1, self.h, 1])
x = self.linear1(x)
x = self.elu(x)
x = self.linear2(x)
x = x.reshape([-1, self.d])
if norm:
x = torch.nn.functional.normalize(x, p=2.0)
return x
[docs]class NNFp(nn.Module):
"""
Neural network Fingerprinter
Args:
dim (`int`):
Dimension of features
h (`int`):
Height of input.
u (`int`):
Parameter to multiple linear dimension.
in_f (`int`):
Padding parameter.
in_t (`int`):
Padding parameter.
fuller (`bool=False`):
Whether to use group in conv2 layer.
activation (`str`):
Activation layer.
strides (`list of tuple`):
A list of stride tuples.
relu_after_bn (`bool`):
Whether to add ReLU after batch norm.
"""
[docs] def __init__(self, dim, h, u, in_f, in_t, fuller=False, activation='relu', strides=None, relu_after_bn=True):
super().__init__()
self.front = FrontConv(
dim, h, in_f, in_t,
fuller=fuller,
activation=activation,
strides=strides,
relu_after_bn=relu_after_bn
)
self.div_encoder = DivEncoder(dim, h, u)
self.hacked = False
def hack(self):
self.hacked = not self.hacked
self.front.hack()
[docs] def forward(self, x, norm=True):
if self.hacked:
x = x.flip([1, 2])
x = self.front(x)
x = self.div_encoder(x, norm=norm)
return x