Source code for towhee.models.retina_face.retinaface

# Copyright 2021 biubug6 . All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is modified by Zilliz.

#adapted from https://github.com/biubug6/Pytorch_Retinaface
#from collections import OrderedDict
from typing import Tuple, Dict

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import models

from towhee.models.retina_face.ssh import SSH
from towhee.models.retina_face.mobilenet_v1 import MobileNetV1
from towhee.models.retina_face.retinaface_fpn import RetinaFaceFPN
from towhee.models.retina_face.heads import ClassHead, BboxHead, LandmarkHead
from towhee.models.retina_face.prior_box import PriorBox
from towhee.models.retina_face.utils import decode, decode_landm, IntermediateLayerGetter

class RetinaFace(nn.Module):
    """
    RetinaFace: Single-stage Dense Face Localisation in the Wild.

    Described in https://arxiv.org/abs/1905.00641.

    Args:
        cfg (`Dict`):
            Network related settings. The constructor reads the keys `name`
            ('mobilenet0.25' or 'Resnet50'), `return_layers`, `in_channel`,
            `out_channel`, `max_size` and `target_size`. `inference` additionally
            reads `mean`, `variance`, `nms_threshold` and `confidence_threshold`,
            and hands the whole dict to `PriorBox`.
        phase (`str`):
            'train' returns raw classification outputs from `forward`; any other
            value makes `forward` apply a softmax over the last dimension.

    Raises:
        TypeError: if `cfg` is None.
        ValueError: if `cfg['name']` is not a supported backbone name.
    """

    def __init__(self, cfg: Dict = None, phase: str = 'train'):
        super().__init__()
        if cfg is None:
            raise TypeError('cfg must be a dict of network settings, got None')
        self.phase = phase
        self.cfg = cfg

        # The backbone is built without pretrained weights; loading a
        # checkpoint is left to the caller.
        backbone_name = cfg['name']
        if backbone_name == 'mobilenet0.25':
            backbone = MobileNetV1()
        elif backbone_name == 'Resnet50':
            backbone = models.resnet50()
        else:
            # Fail fast instead of passing None to IntermediateLayerGetter.
            raise ValueError(
                f'Unsupported backbone name {backbone_name!r}; '
                'expected \'mobilenet0.25\' or \'Resnet50\'.'
            )

        self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])

        # Channel widths handed to the FPN are 2x, 4x and 8x the configured
        # base width.
        base_channels = cfg['in_channel']
        in_channels_list = [base_channels * 2, base_channels * 4, base_channels * 8]
        out_channels = cfg['out_channel']

        self.max_size = cfg['max_size']
        self.target_size = cfg['target_size']

        self.fpn = RetinaFaceFPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        # One head of each kind per FPN level.
        self.class_head = self._make_class_head(fpn_num=3, inchannels=out_channels)
        self.bbox_head = self._make_bbox_head(fpn_num=3, inchannels=out_channels)
        self.landmark_head = self._make_landmark_head(fpn_num=3, inchannels=out_channels)

    def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build `fpn_num` independent `ClassHead` modules."""
        return nn.ModuleList([ClassHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build `fpn_num` independent `BboxHead` modules."""
        return nn.ModuleList([BboxHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_landmark_head(self, fpn_num: int = 3, inchannels: int = 64, anchor_num: int = 2):
        """Build `fpn_num` independent `LandmarkHead` modules."""
        return nn.ModuleList([LandmarkHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def forward(self, inputs: torch.FloatTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run backbone, FPN, SSH modules and the three prediction heads.

        Args:
            inputs (`torch.FloatTensor`):
                Input passed directly to the backbone feature extractor.

        Returns:
            (`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`):
                `(bbox_regressions, classifications, ldm_regressions)`, each the
                per-level head outputs concatenated along dim 1. When
                `self.phase` is not 'train', `classifications` has a softmax
                applied over its last dimension.
        """
        out = self.body(inputs)
        fpn = self.fpn(out)

        # One SSH context module per pyramid level.
        features = [self.ssh1(fpn[0]), self.ssh2(fpn[1]), self.ssh3(fpn[2])]

        bbox_regressions = torch.cat(
            [head(feature) for head, feature in zip(self.bbox_head, features)], dim=1)
        classifications = torch.cat(
            [head(feature) for head, feature in zip(self.class_head, features)], dim=1)
        ldm_regressions = torch.cat(
            [head(feature) for head, feature in zip(self.landmark_head, features)], dim=1)

        if self.phase != 'train':
            classifications = F.softmax(classifications, dim=-1)
        return bbox_regressions, classifications, ldm_regressions

    def inference(self, img: torch.Tensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        Detect faces in a single image.

        Args:
            img (`torch.Tensor`):
                One image as a tensor in (height, width, channel) layout; it is
                mean-subtracted, permuted to channel-first and given a batch
                dimension of 1 before the forward pass.

        Returns:
            (`Tuple[torch.FloatTensor, torch.FloatTensor]`):
                `(dets, landms)`. `dets` has five columns: four decoded box
                values multiplied by (width, height, width, height), followed by
                the score. `landms` holds the decoded landmarks for the same
                detections. Both are filtered by NMS and then by
                `cfg['confidence_threshold']`.

        Note:
            The score is column 1 of the classification output, so it is a
            softmax probability only when the model was built with a phase
            other than 'train'.
        """
        im_height, im_width, _ = img.shape
        device = img.device

        # Multiplier applied to decoded boxes: (w, h, w, h).
        scale = torch.Tensor([im_width, im_height, im_width, im_height]).to(device)

        mean = torch.FloatTensor(self.cfg['mean']).to(device)
        img = (img - mean).permute(2, 0, 1).unsqueeze(0).to(device)

        loc, conf, landms = self.forward(img)

        priors = PriorBox(self.cfg, image_size=(im_height, im_width)).forward().to(device)
        prior_data = priors.data

        boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) * scale
        scores = conf.squeeze(0).data[:, 1]

        # NOTE(review): unlike the boxes, the decoded landmarks are returned
        # without being multiplied by the image width/height. The previous code
        # built a (w, h) * 5 scale tensor here but never applied it; confirm
        # whether callers expect unscaled landmarks before changing this.
        landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance'])

        dets = torch.hstack((boxes, scores.unsqueeze(-1)))
        keep = torchvision.ops.nms(dets[:, :4], dets[:, 4], self.cfg['nms_threshold'])
        dets = dets[keep, :]
        landms = landms[keep]

        confident = dets[:, 4] > self.cfg['confidence_threshold']
        return dets[confident], landms[confident]