# coding=utf-8
"""PyTorch RoBERTa model. """
import math
import warnings
import fitlog
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss, MarginRankingLoss
from transformers.activations import ACT2FN, gelu
from transformers.configuration_roberta import RobertaConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_roberta import (
    RobertaPreTrainedModel,
    RobertaModel,
)
from .CVAEModel import CVAEModel
from .Attention import AttentionInArgs
from .SelfAttention import SelfAttention
from .GATModel import GAT
import logging
import numpy as np

logger = logging.getLogger(__name__)
class RobertaLMHead(nn.Module):
"""Roberta Head for masked language modeling."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = self.decoder(x)
return x
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
# self.maxpool = nn.MaxPool2d((256, 1))
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
# x = self.maxpool(features).squeeze(1)
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class RobertaPDTBModel(RobertaPreTrainedModel):
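    """RoBERTa-based model for PDTB discourse relation classification. On top of the
    RoBERTa encoder it combines a classification head with an attention-guided masked
    language modeling auxiliary task and a CVAE-based contrastive objective."""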
authorized_missing_keys = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.roberta = RobertaModel(config, add_pooling_layer=False)
# self.roberta_for_mlm = RobertaModel(config, add_pooling_layer=False)
self.lm_head = RobertaLMHead(config)
self.classifier = RobertaClassificationHead(config)
self.laynorm = nn.LayerNorm(config.hidden_size)
        # attention modules over the two arguments
self.self_attention = SelfAttention(input_size=768,
embedding_dim=256,
output_size=768
)
# self.self_atten = nn.MultiheadAttention(768, 2)
self.attention = AttentionInArgs(input_size=768, # 768
embedding_dim=256, # 256
output_size=768 # 256, 768
)
self.self_attention2 = nn.MultiheadAttention(768, 2, dropout=0.2)
self.attention2 = nn.MultiheadAttention(768, 2)
self.attn_fc = nn.Linear(768, 128)
self.atten_fc = nn.Linear(128, 256)
# torch.Size([8, 256, 768]) torch.Size([8, 256, 256])
# self.gat = GAT(in_features=768, n_dim=256, n_class=768)
self.projector = nn.Sequential(
nn.Linear(768, 256),
nn.LayerNorm(256),
nn.ReLU(),
nn.Linear(256, 768)
)
self.cvae = CVAEModel(config.hidden_size, config.hidden_size)
self.init_weights()
def _select_attention(self, sequence_output, attention_mask, arg1_first=False):
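        """Split ``sequence_output`` into its Arg1/Arg2 halves, run ``SelfAttention``
        over each half, attend between the two halves with ``AttentionInArgs``
        (argument order controlled by ``arg1_first``), and project with ``attn_fc``
        (768 -> 128). Used by :meth:`_mlm_attention` to score token positions for
        masking."""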
        # apply attention over the two argument halves
arg_len = sequence_output.shape[1] // 2
# X: [8, 128, 768], [8, 128, 768] --> [8, 256, 768]
# X: [8, 82, 768], [8, 82, 768] --> [8, 164, 768]
arg1, arg2 = sequence_output[:, :arg_len, :], sequence_output[:, arg_len:, :]
        # arg1_mask, arg2_mask = attention_mask[:, :arg_len], attention_mask[:, arg_len:]
arg1 = self.self_attention(arg1, arg1)
arg2 = self.self_attention(arg2, arg2)
# arg1, _ = self.self_atten(arg1, arg1, arg1)
# arg2, _ = self.self_atten(arg2, arg2, arg2)
# print(arg1.shape)
if arg1_first:
# seq_out, _ = self.inter_atten(arg1, arg2, arg2) # [8, 128, 768]
seq_out = self.attention(arg1, arg2)
else:
# seq_out, _ = self.inter_atten(arg2, arg1, arg1) # [8, 128, 768]
seq_out = self.attention(arg2, arg1)
# print(seq_out.shape)
return self.attn_fc(seq_out)
# return seq_out
def _random_mask(self, sequence_output, arg1_first=False):
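        """Attention-free baseline for :meth:`_select_attention`: take one argument
        half of ``sequence_output`` (Arg1 if ``arg1_first`` else Arg2) and project
        it with ``attn_fc``."""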
        # split into the two argument halves (no attention applied here)
arg_len = sequence_output.shape[1] // 2
# X: [8, 128, 768], [8, 128, 768] --> [8, 256, 768]
# X: [8, 82, 768], [8, 82, 768] --> [8, 164, 768]
arg1, arg2 = sequence_output[:, :arg_len, :], sequence_output[:, arg_len:, :]
if arg1_first:
return self.attn_fc(arg1)
else:
return self.attn_fc(arg2)
def _add_attention(self, sequence_output, attention_mask):
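        """Run ``SelfAttention`` over each argument half and fuse the two halves
        with ``AttentionInArgs`` using ``attention_mask``; returns the fused
        representation. (Not called in :meth:`forward` at the moment.)"""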
        # apply attention over the two argument halves
arg_len = sequence_output.shape[1] // 2
# X: [8, 128, 768], [8, 128, 768] --> [8, 256, 768]
# X: [8, 82, 768], [8, 82, 768] --> [8, 164, 768]
arg1, arg2 = sequence_output[:, :arg_len, :], sequence_output[:, arg_len:, :]
        # arg1_mask, arg2_mask = attention_mask[:, :arg_len], attention_mask[:, arg_len:]
        # target shape: [8, 256, 256]
arg1 = self.self_attention(arg1, arg1)
arg2 = self.self_attention(arg2, arg2)
sequence_output = self.attention(arg1, arg2, attention_mask) # [8, 256]
# logging.info('adj: ' + str(adj[0]) + ' ' + str(adj.shape))
# sequence_output = self.gat(sequence_output, adj)
return sequence_output
def _add_attention_main(self, sequence_output, attention_mask=None):
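        """Interaction encoder: multi-head self-attention within each argument half,
        cross-attention between the two halves, and concatenation of the attended
        halves along the sequence dimension. Tensors are transposed to the
        (seq, batch, dim) layout expected by ``nn.MultiheadAttention``.
        (Not called in :meth:`forward` at the moment.)"""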
        # apply attention over the two argument halves
arg_len = sequence_output.shape[1] // 2
# X: [8, 128, 768], [8, 128, 768] --> [8, 256, 768]
# X: [8, 82, 768], [8, 82, 768] --> [8, 164, 768]
arg1, arg2 = sequence_output[:, :arg_len, :], sequence_output[:, arg_len:, :]
        # arg1_mask, arg2_mask = attention_mask[:, :arg_len], attention_mask[:, arg_len:]
        # target shape: [8, 256, 256]
arg1, _ = self.self_attention2(arg1.transpose(0, 1), arg1.transpose(0, 1), arg1.transpose(0, 1))
# print(arg1.shape)
arg2, _ = self.self_attention2(arg2.transpose(0, 1), arg2.transpose(0, 1), arg2.transpose(0, 1))
arg1_attn, _ = self.attention2(arg1, arg1, arg2)
arg2_attn, _ = self.attention2(arg2, arg2, arg1)
arg1_attn = arg1_attn.transpose(0, 1)
arg2_attn = arg2_attn.transpose(0, 1)
sequence_output = torch.cat([arg1_attn, arg2_attn], dim=1)
return sequence_output
def _mlm_attention(self, sequence_output, attention_mask, input_ids, args=None, tokenizer=None):
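        """Build masked inputs for the auxiliary MLM task.

        Token positions are scored with :meth:`_select_attention` (once per argument
        order), and for each example the ``args.mask_num`` highest-scoring positions
        are replaced with the RoBERTa ``<mask>`` id (50264). Returns
        ``(masked_input_ids, input_ids)``; the labels are the full, unmasked
        ``input_ids``, so the MLM loss is later computed over all positions.
        """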
torch.set_printoptions(profile="full")
masked_input_ids = input_ids.clone().detach()
sequence_for_mask = self._select_attention(sequence_output, attention_mask, arg1_first=True)
# sequence_for_mask = self._random_mask(sequence_output, arg1_first=True)
sequence_for_mask = torch.sum(sequence_for_mask, dim=2)
# sequence_for_mask = torch.argmax(sequence_for_mask, dim=2) # [8, 128]
# print('seq mask: ', sequence_for_mask)
mask_idx = torch.argsort(sequence_for_mask, dim=1, descending=True)[:, :args.mask_num]
# print(mask_idx.shape, mask_idx)
for elem, idx in zip(masked_input_ids, mask_idx):
elem[idx] = 50264 # '<mask>': 50264
# for elem in masked_input_ids:
# idx = list(range(0, 128))
# random.shuffle(idx)
# idx = idx[:args.mask_num]
# idx = sequence_for_mask[:, idx]
# elem[128 + idx] = 50264
        # print('masked input_ids_arg2: ', masked_input_ids, '\n')
# print('input_ids: ', input_ids)
# print(tokenizer.convert_ids_to_tokens(masked_input_ids[0]))
# print(tokenizer.convert_ids_to_tokens(input_ids[0]))
sequence_for_mask = self._select_attention(sequence_output, attention_mask)
# sequence_for_mask = self._random_mask(sequence_output)
sequence_for_mask = torch.sum(sequence_for_mask, dim=2)
# sequence_for_mask = torch.argmax(sequence_for_mask, dim=2)
mask_idx = torch.argsort(sequence_for_mask, dim=1, descending=True)[:, :args.mask_num]
for elem, idx in zip(masked_input_ids, mask_idx):
elem[idx] = 50264 # '<mask>': 50264
# for elem in masked_input_ids:
# idx = list(range(0, 128))
# random.shuffle(idx)
# idx = idx[:args.mask_num]
# idx = sequence_for_mask[:, idx]
# elem[idx] = 50264
# print(mask_idx.shape, mask_idx)
        # print('masked input_ids_arg1: ', masked_input_ids, '\n')
# print('input_ids: ', input_ids)
# print(tokenizer.convert_ids_to_tokens(masked_input_ids[0]))
# print(tokenizer.convert_ids_to_tokens(input_ids[0]))
return masked_input_ids, input_ids
def get_contrastive_loss(self, self_sample, positive_sample, reverse_sample):
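        """Contrastive loss based on cosine similarity with temperature 0.5: the
        anchor/positive similarity is contrasted against the anchor/negative
        (``reverse_sample``) similarity. (Not called in :meth:`forward` at the moment.)"""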
self.cos = nn.CosineSimilarity(dim=-1)
self_and_pos = self.cos(self_sample, positive_sample)
self_and_neg = self.cos(self_sample, reverse_sample)
temp1 = torch.div(self_and_pos, 0.5)
temp2 = torch.div(self_and_neg, 0.5)
loss = -nn.LogSoftmax(0)(torch.div(temp1, temp1 + temp2)).diag().sum()
return loss
    def loss_hardest_from_batchneg_and_nonclick(self, gap_value, self_sample, positive_sample, reverse_sample=None, labels=None, device=None):
        """Select the hardest negative from the in-batch negatives and the CVAE negative;
        the positive must score higher than that hardest case by at least ``gap_value``."""
        batch_size = self_sample.size(0)
        # mean-pool over the sequence dimension: [B, L, H] -> [B, H]
        self_sample = torch.mean(self_sample, 1)
        positive_sample = torch.mean(positive_sample, 1)
        reverse_sample = torch.mean(reverse_sample, 1)
# query_embeddings = fluid.layers.reshape(query_embeddings, shape=[batch_size, 768])
# pos_embeddings = fluid.layers.reshape(pos_embeddings, shape=[batch_size, 768])
# neg_embedding = fluid.layers.reshape(neg_embedding, shape=[batch_size, 768])
self_sample = self_sample.view(batch_size, 768)
positive_sample = positive_sample.view(batch_size, 768)
reverse_sample = reverse_sample.view(batch_size, 768)
# query_embeddings_norm = fluid.layers.l2_normalize(x=query_embeddings, axis=-1) #[b,768] # paddle
# pos_embeddings_norm = fluid.layers.l2_normalize(x=pos_embeddings, axis=-1)
# neg_embeddings_norm = fluid.layers.l2_normalize(x=neg_embedding, axis=-1)
        # together with the sums of element-wise products below, this normalization gives cosine similarity
self_sample_norm = F.normalize(self_sample, dim=-1, p=2)
positive_sample_norm = F.normalize(positive_sample, dim=-1, p=2)
reverse_sample_norm = F.normalize(reverse_sample, dim=-1, p=2)
        # build an identity-like mask (non-zero on the diagonal) to exclude self-matches from the in-batch negatives
self_mask = torch.matmul(self_sample_norm,self_sample_norm.transpose(0,1))
ones = torch.ones_like(self_mask, dtype = torch.float32).to(device)
self_mask = ones + torch.sign(self_mask + 1e-4 - ones)
pos_score = torch.sum(self_sample_norm*positive_sample_norm, dim=1, keepdim=True)
        # similarity between the anchor and the positive samples (across the batch)
cosdist_pos_self = torch.matmul(positive_sample_norm, self_sample_norm.transpose(0,1))
        # an optional trick; this term can be left out
cosdist_self_pos = torch.matmul(self_sample_norm, positive_sample_norm.transpose(0,1))
        # similarity between the anchor and the CVAE-generated negative
neg_score_sup = torch.sum(self_sample_norm*reverse_sample_norm, dim=1, keepdim=True)
        # concatenate the CVAE negative with the other in-batch negatives and use them together
cosdist_self_all = torch.cat([cosdist_self_pos - 10 * self_mask, neg_score_sup], dim=1)
neg_score_hardest, _ = torch.max(cosdist_self_all, dim=1, keepdim=True)
neg_score_for_pos, _ = torch.max(cosdist_pos_self - 10 * self_mask, dim=1, keepdim=True)
        # NOTE: MarginRankingLoss expects a target of +1 where pos_score should rank higher;
        # labels = torch.ones((batch_size, 1), dtype=torch.float32).to(device)
        # max(0, -y*(x1-x2) + margin)
margin_loss_func = MarginRankingLoss(margin=gap_value)
rank_loss_from_self = margin_loss_func(pos_score, neg_score_hardest, labels)
rank_loss_from_pos = margin_loss_func(pos_score, neg_score_for_pos, labels)
# rank_loss_from_self = F.margin_ranking_loss(pos_score, neg_score_hardest, labels, margin=gap_value)
# rank_loss_from_pos = F.margin_ranking_loss(pos_score, neg_score_for_pos, labels, margin=gap_value)
loss = rank_loss_from_self + rank_loss_from_pos
loss = torch.mean(loss)
return loss
def _inter_attention(self, sequence_output, attention_mask):
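        """Multi-head self-attention over the full concatenated sequence
        (``attention_mask`` is accepted but not applied); inputs are transposed to
        the (seq, batch, dim) layout expected by ``nn.MultiheadAttention``."""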
# print(arg1.shape)
seq_out, _ = self.attention2(sequence_output.transpose(0, 1),
sequence_output.transpose(0, 1),
sequence_output.transpose(0, 1)) # [8, 128, 768]
return seq_out.transpose(0, 1)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
Training=False,
tokenizer=None,
args=None,
global_step=0
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
        # custom modifications start here
sequence_output = self.laynorm(sequence_output)
###########################################################################################
        # MLM auxiliary task
input_ids_label = None
if Training and args.do_mlm > 0:
masked_ids, input_ids_label = self._mlm_attention(sequence_output, attention_mask, input_ids, args, tokenizer=tokenizer)
mlm_tasks = self.roberta(masked_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
mlm_sequence_output = mlm_tasks[0]
prediction_scores = self.lm_head(mlm_sequence_output)
            # debug: periodically inspect MLM predictions
if global_step % 100 == 0:
pre = torch.argmax(prediction_scores, dim=2)
print('pre: ', tokenizer.convert_ids_to_tokens(pre[0]))
print('label: ', tokenizer.convert_ids_to_tokens(input_ids[0]))
masked_mlm_loss = None
if input_ids_label is not None:
loss_fct_mlm = CrossEntropyLoss()
masked_mlm_loss = loss_fct_mlm(prediction_scores.view(-1, self.config.vocab_size), input_ids_label.view(-1))
###############################################################################################
        # main classification task
# sequence_output = self._add_attention(sequence_output, attention_mask)
# sequence_output = self.atten_fc(sequence_output.transpose(1, 2)).transpose(1, 2)
sequence_output = self._inter_attention(sequence_output, attention_mask)
out, mu, logvar = self.cvae(x=sequence_output, y=labels, Training=True, device=args.device)
cvae_loss = CVAEModel.loss_function(recon_x=out, x=sequence_output, mu=mu, logvar=logvar)
        # build the positive sample (a dropout view of the representation)
positive_sample = nn.Dropout(0.1)(sequence_output)
# positive_sample = self.cvae(x=sequence_output, y=labels, device=args.device)
        # build the negative sample (generated by the CVAE)
# reverse_sample = sequence_output[torch.randperm(sequence_output.shape[0])]
reverse_sample,_,_ = self.cvae(x=sequence_output, y=labels, Training=True, Use=True, device=args.device)
# reverse_sample = self.cvae(x=sequence_output, y=labels, device=args.device)
# contrastive_loss = self.get_contrastive_loss(sequence_output, positive_sample, reverse_sample)
contrastive_loss = self.loss_hardest_from_batchneg_and_nonclick(0.2, sequence_output, positive_sample,
reverse_sample, labels, args.device)
sequence_output = self.laynorm(sequence_output)
logits = self.classifier(sequence_output)
        # compute the loss
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
# [8, 2], [8]
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            # main loss + MLM loss
if Training and args.do_mlm > 0 and masked_mlm_loss is not None:
loss = loss + args.mlm_theta * masked_mlm_loss
# logger.info(str(contrastive_loss))
if Training:
if global_step % 100 == 0:
fitlog.add_loss(contrastive_loss, name = 'contrastive loss', step=global_step)
fitlog.add_loss(cvae_loss, name = 'cvae_loss', step=global_step)
# loss = loss + 0.5 * contrastive_loss + 0.01 * cvae_loss # 0.01
# loss = loss + 0.5 * contrastive_loss + 0.001 * cvae_loss
# loss = loss + 0.05 * contrastive_loss + 0.001 * cvae_loss
loss = loss + args.con_theta * contrastive_loss + args.cvae_theta * cvae_loss # 0.005, 0.001
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
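# ---------------------------------------------------------------------------
# Illustrative usage sketch (comments only, not part of the model code). It
# assumes an `args` namespace exposing the fields this module reads (device,
# do_mlm, mask_num, mlm_theta, con_theta, cvae_theta) and a data pipeline that
# concatenates Arg1 and Arg2 into a single, even-length sequence; the names
# and values below are placeholders, not the project's actual configuration.
#
#   config = RobertaConfig.from_pretrained("roberta-base", num_labels=2)
#   model = RobertaPDTBModel.from_pretrained("roberta-base", config=config)
#   outputs = model(
#       input_ids=batch["input_ids"],            # [batch, 2 * arg_len]
#       attention_mask=batch["attention_mask"],
#       labels=batch["labels"],                  # [batch]
#       Training=True,
#       tokenizer=tokenizer,
#       args=args,
#       global_step=step,
#       return_dict=False,
#   )
#   loss, logits = outputs[0], outputs[1]
# ---------------------------------------------------------------------------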