# Source code for towhee.models.lightning_dot.bi_encoder

# Built on top of the original implementation at https://github.com/intersun/LightningDOT
# Modifications by Copyright 2022 Zilliz. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
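# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.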
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch

from collections import defaultdict
from torch import nn

class BiEncoder(nn.Module):
    """
    Bi-Encoder model component. Encapsulates query/question and context/passage encoders.

    An image encoder (``self.img_model``) and a text encoder (``self.txt_model``) are built through the
    ``init_encoder`` factory of the encoder classes passed in.

    Args:
        UniterEncoder (nn.Module):
            Encoder class providing ``init_encoder(config, checkpoint_path=..., project_dim=...)``. Always used for
            the image side; also used for the text side when ``args.txt_model_type == 'uniter-base'``.
        BertEncoder (nn.Module):
            Encoder class with the same ``init_encoder`` factory, used for the text side when
            ``args.txt_model_type == 'bert-base'``.
        args (Any):
            Config object. Reads ``img_model_type``, ``img_model_config``, ``img_checkpoint``,
            ``txt_model_type``, ``txt_model_config`` and ``txt_checkpoint``.
        fix_img_encoder (bool):
            Freeze the image encoder: its parameters get ``requires_grad = False`` and it is run under
            ``torch.no_grad()`` in ``forward``.
        fix_txt_encoder (bool):
            Same as ``fix_img_encoder`` for the text encoder (which also encodes captions).
        project_dim (int):
            Forwarded unchanged to ``init_encoder`` of both encoders.

    Raises:
        ValueError: if ``args.img_model_type`` is not ``'uniter-base'``, or ``args.txt_model_type`` is neither
            ``'bert-base'`` nor ``'uniter-base'``.
    """

    def __init__(self, UniterEncoder: nn.Module, BertEncoder: nn.Module, args: Any,  # pylint: disable=invalid-name
                 fix_img_encoder: bool = False, fix_txt_encoder: bool = False, project_dim: int = 0):
        super().__init__()
        if args.img_model_type == 'uniter-base':
            self.img_model = UniterEncoder.init_encoder(args.img_model_config, checkpoint_path=args.img_checkpoint,
                                                        project_dim=project_dim)
        else:
            raise ValueError(f'image encoder does not support other types ({args.img_model_type}) for now')

        if args.txt_model_type == 'bert-base':
            txt_encoder_cls = BertEncoder
        elif args.txt_model_type == 'uniter-base':
            txt_encoder_cls = UniterEncoder
        else:
            raise ValueError(f'txt encoder does not support other types ({args.txt_model_type}) for now')
        self.txt_model = txt_encoder_cls.init_encoder(args.txt_model_config, checkpoint_path=args.txt_checkpoint,
                                                      project_dim=project_dim)

        self.fix_img_encoder = fix_img_encoder
        self.fix_txt_encoder = fix_txt_encoder
        self.project_dim = project_dim
        if fix_txt_encoder:
            self._freeze(self.txt_model)
        if fix_img_encoder:
            self._freeze(self.img_model)

    @staticmethod
    def _freeze(module: nn.Module) -> None:
        """Set ``requires_grad = False`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    @staticmethod
    def get_representation(sub_model, input_ids, attention_mask, position_ids, img_feat, img_pos_feat, img_masks,
                           gather_index=None, fix_encoder=False):
        """
        Run ``sub_model`` on one sub-batch.

        Args:
            sub_model (nn.Module): encoder called positionally with the remaining inputs; it must return a
                3-tuple ``(sequence_output, pooled_output, hidden_states)``.
            input_ids, attention_mask, position_ids, img_feat, img_pos_feat, img_masks, gather_index:
                passed through to ``sub_model`` unchanged.
            fix_encoder (bool): if True, the forward pass runs under ``torch.no_grad()``.

        Returns:
            tuple: ``(sequence_output, pooled_output, hidden_states)`` as returned by ``sub_model``.
        """
        model_inputs = (input_ids, attention_mask, position_ids, img_feat, img_pos_feat, img_masks, gather_index)
        if fix_encoder:
            with torch.no_grad():
                sequence_output, pooled_output, hidden_states = sub_model(*model_inputs)
        else:
            sequence_output, pooled_output, hidden_states = sub_model(*model_inputs)

        # NOTE(review): this guard was garbled in the scraped source; reconstructed as `sub_model.training`
        # (confirm against upstream LightningDOT). Outputs produced under no_grad do not require grad; marking
        # them presumably lets a loss built from a frozen encoder's outputs still call backward().
        if sub_model.training:
            sequence_output.requires_grad_(requires_grad=True)
            pooled_output.requires_grad_(requires_grad=True)
        return sequence_output, pooled_output, hidden_states

    def _encode(self, sub_model, sub_batch, fix_encoder):
        """Encode one sub-batch; return ``(sequence_output, pooled_output)``, or ``(None, None)`` if it is absent."""
        if sub_batch is None:
            return None, None
        sequence_output, pooled_output, _ = self.get_representation(
            sub_model, sub_batch['input_ids'], sub_batch['attention_mask'], sub_batch['position_ids'],
            sub_batch['img_feat'], sub_batch['img_pos_feat'], sub_batch['img_masks'], sub_batch['gather_index'],
            fix_encoder)
        return sequence_output, pooled_output

    def forward(self, batch, output_all_encoded_layers=False):
        """
        Encode whichever text, image and caption sub-batches are present.

        Args:
            batch (Mapping): may contain ``'txts'``, ``'imgs'`` and ``'caps'``. Each present sub-batch is indexed
                with ``'input_ids'``, ``'attention_mask'``, ``'position_ids'``, ``'img_feat'``,
                ``'img_pos_feat'``, ``'img_masks'`` and ``'gather_index'``. A missing or ``None`` sub-batch, and a
                ``'caps'`` sub-batch whose ``'input_ids'`` is ``None``, are skipped.
            output_all_encoded_layers (bool): return sequence outputs instead of pooled outputs.

        Returns:
            tuple: ``(txt, img, cap)`` -- sequence outputs if ``output_all_encoded_layers`` else pooled outputs;
            an entry is ``None`` when its sub-batch was skipped.
        """
        batch = defaultdict(lambda: None, batch)
        txt_seq, txt_pooled = self._encode(self.txt_model, batch['txts'], self.fix_txt_encoder)
        # Images are governed by their own freeze flag, not the text one.
        img_seq, img_pooled = self._encode(self.img_model, batch['imgs'], self.fix_img_encoder)

        caps = batch['caps']
        if caps is not None and caps['input_ids'] is None:
            caps = None  # caption sub-batch without token ids is treated as absent
        cap_seq, cap_pooled = self._encode(self.txt_model, caps, self.fix_txt_encoder)

        if output_all_encoded_layers:
            return txt_seq, img_seq, cap_seq
        return txt_pooled, img_pooled, cap_pooled