  • Recommender System (Part 2): A Text-DeepFM Based on PyTorch

    1. The main function

    import torch
    import tqdm
    from sklearn.metrics import roc_auc_score
    from torch.utils.data import DataLoader
    
    from torchfm.dataset.avazu import AvazuDataset
    from torchfm.dataset.criteo import CriteoDataset
    from torchfm.dataset.movielens import MovieLens1MDataset, MovieLens20MDataset
    from torchfm.dataset.zhihu import ZhihuDataset
    from torchfm.model.afi import AutomaticFeatureInteractionModel
    from torchfm.model.afm import AttentionalFactorizationMachineModel
    from torchfm.model.dcn import DeepCrossNetworkModel
    from torchfm.model.dfm import DeepFactorizationMachineModel
    from torchfm.model.ffm import FieldAwareFactorizationMachineModel
    from torchfm.model.fm import FactorizationMachineModel
    from torchfm.model.fnfm import FieldAwareNeuralFactorizationMachineModel
    from torchfm.model.fnn import FactorizationSupportedNeuralNetworkModel
    from torchfm.model.hofm import HighOrderFactorizationMachineModel
    from torchfm.model.lr import LogisticRegressionModel
    from torchfm.model.ncf import NeuralCollaborativeFiltering
    from torchfm.model.nfm import NeuralFactorizationMachineModel
    from torchfm.model.pnn import ProductNeuralNetworkModel
    from torchfm.model.wd import WideAndDeepModel
    from torchfm.model.xdfm import ExtremeDeepFactorizationMachineModel
    from torchfm.model.afn import AdaptiveFactorizationNetwork
    import numpy as np
    
    def get_word_vector():
        """Load the 64-d topic vectors: each line is 'token<TAB>v1 v2 ... v64'."""
        topic_word_vector = {}
        with open("word_vector/word_vectors_64d.txt", "r") as f:
            for fLine in f:
                rowList = fLine.rstrip("\n").split("\t")
                k = rowList[0]
                v = rowList[1].split(" ")
                topic_word_vector[k] = v
        return topic_word_vector
    
    def pad_sequences(x, topic_word_vector, b):
        """Turn a batch of comma-separated topic strings into a (b, maxlen, 64)
        float array, zero-padding or truncating each sequence to maxlen."""
        maxlen = 20
        data_a_vec = []
        for sequence_a in x:
            # "-1" marks a missing title: substitute an all-zero sequence for
            # this sample only, instead of zeroing out the whole batch
            if sequence_a == "-1":
                data_a_vec.append(np.zeros((maxlen, 64)))
                continue
            sequence_vec = []
            for t in sequence_a.split(","):
                v = topic_word_vector[t]
                sequence_vec.append([float(s) for s in v])
            sequence_vec = np.array(sequence_vec)
            if maxlen > sequence_vec.shape[0]:
                pad = np.zeros((maxlen - sequence_vec.shape[0], 64))
                sequence_vec = np.vstack((sequence_vec, pad))
            else:
                sequence_vec = sequence_vec[:maxlen]
            data_a_vec.append(sequence_vec)
        return np.array(data_a_vec)
    
    
    def get_dataset(name, path, text_col):
        if name == 'movielens1M':
            return MovieLens1MDataset(path)
        elif name == 'movielens20M':
            return MovieLens20MDataset(path)
        elif name == 'criteo':
            return CriteoDataset(path)
        elif name == 'avazu':
            return AvazuDataset(path)
        elif name == 'zhihu':
            return ZhihuDataset(text_col, path)
        else:
            raise ValueError('unknown dataset name: ' + name)
    
    
    def get_model(name, dataset):
        """
        Hyperparameters are empirically determined, not optimized.
        """
        field_dims = dataset.field_dims
        print(field_dims)
        if name == 'lr':
            return LogisticRegressionModel(field_dims)
        elif name == 'fm':
            return FactorizationMachineModel(field_dims, embed_dim=16)
        elif name == 'hofm':
            return HighOrderFactorizationMachineModel(field_dims, order=3, embed_dim=16)
        elif name == 'ffm':
            return FieldAwareFactorizationMachineModel(field_dims, embed_dim=4)
        elif name == 'fnn':
            return FactorizationSupportedNeuralNetworkModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)
        elif name == 'wd':
            return WideAndDeepModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)
        elif name == 'ipnn':
            return ProductNeuralNetworkModel(field_dims, embed_dim=16, mlp_dims=(16,), method='inner', dropout=0.2)
        elif name == 'opnn':
            return ProductNeuralNetworkModel(field_dims, embed_dim=16, mlp_dims=(16,), method='outer', dropout=0.2)
        elif name == 'dcn':
            return DeepCrossNetworkModel(field_dims, embed_dim=16, num_layers=3, mlp_dims=(16, 16), dropout=0.2)
        elif name == 'nfm':
            return NeuralFactorizationMachineModel(field_dims, embed_dim=64, mlp_dims=(64,), dropouts=(0.2, 0.2))
        elif name == 'ncf':
            # only supports the MovieLens datasets because for other datasets user/item columns are indistinguishable
            assert isinstance(dataset, MovieLens20MDataset) or isinstance(dataset, MovieLens1MDataset)
            return NeuralCollaborativeFiltering(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2,
                                                user_field_idx=dataset.user_field_idx,
                                                item_field_idx=dataset.item_field_idx)
        elif name == 'fnfm':
            return FieldAwareNeuralFactorizationMachineModel(field_dims, embed_dim=4, mlp_dims=(64,), dropouts=(0.2, 0.2))
        elif name == 'dfm':
            #[6040, 3952]
            return DeepFactorizationMachineModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)
    
        elif name == 'xdfm':
            return ExtremeDeepFactorizationMachineModel(
                field_dims, embed_dim=16, cross_layer_sizes=(16, 16), split_half=False, mlp_dims=(16, 16), dropout=0.2)
        elif name == 'afm':
            return AttentionalFactorizationMachineModel(field_dims, embed_dim=16, attn_size=16, dropouts=(0.2, 0.2))
        elif name == 'afi':
            return AutomaticFeatureInteractionModel(
                 field_dims, embed_dim=16, atten_embed_dim=64, num_heads=2, num_layers=3, mlp_dims=(400, 400), dropouts=(0, 0, 0))
        elif name == 'afn':
            print("Model:AFN")
            return AdaptiveFactorizationNetwork(
                field_dims, embed_dim=16, LNN_dim=1500, mlp_dims=(400, 400, 400), dropouts=(0, 0, 0))
        else:
            raise ValueError('unknown model name: ' + name)
    
    
    class EarlyStopper(object):
    
        def __init__(self, num_trials, save_path):
            self.num_trials = num_trials
            self.trial_counter = 0
            self.best_accuracy = 0
            self.save_path = save_path
    
        def is_continuable(self, model, accuracy):
            if accuracy > self.best_accuracy:
                self.best_accuracy = accuracy
                self.trial_counter = 0
                torch.save(model, self.save_path)
                return True
            elif self.trial_counter + 1 < self.num_trials:
                self.trial_counter += 1
                return True
            else:
                return False
    
    
    def train(model, optimizer, data_loader, criterion, device, vev, b, log_interval=100):
        model.train()
        total_loss = 0
        tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
        for i, (fields, target, text) in enumerate(tk0):
            words = pad_sequences(text, vev, b)
    
            fields, target = fields.to(device), target.to(device)
            y = model(fields, words)
            loss = criterion(y, target.float())
            model.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if (i + 1) % log_interval == 0:
                tk0.set_postfix(loss=total_loss / log_interval)
                total_loss = 0
    
    
    def tst(model, data_loader, device, vev, b):
        model.eval()
        targets, predicts = list(), list()
        with torch.no_grad():
            for fields, target, text in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
    
                words = pad_sequences(text, vev, b)
                fields, target = fields.to(device), target.to(device)
                y = model(fields, words)
                targets.extend(target.tolist())
                predicts.extend(y.tolist())
    
        return roc_auc_score(targets, predicts)
    
    
    def main(dataset_name,
             dataset_path,
             model_name,
             epoch,
             learning_rate,
             batch_size,
             weight_decay,
             device,
             save_dir,
             text_col):
        vec = get_word_vector()
        device = torch.device(device)
        # load the dataset
        dataset = get_dataset(dataset_name, dataset_path, text_col)
        # len(dataset) gives the total number of rows (1000209 for MovieLens-1M)
        train_length = int(len(dataset) * 0.8)
        valid_length = int(len(dataset) * 0.1)
        test_length = len(dataset) - train_length - valid_length
        # split into train / validation / test (8:1:1)
        train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
            dataset, (train_length, valid_length, test_length))
        train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0)
        valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=0)
        test_data_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=0)
        # build the model (DeepFM in this run)
    
        model = get_model(model_name, dataset).to(device)
        criterion = torch.nn.BCELoss()
        optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        early_stopper = EarlyStopper(num_trials=2, save_path=f'{save_dir}/{model_name}.pt')
        for epoch_i in range(epoch):
            train(model, optimizer, train_data_loader, criterion, device, vec, batch_size)
            auc = tst(model, valid_data_loader, device, vec, batch_size)
            print('epoch:', epoch_i, 'validation: auc:', auc)
            if not early_stopper.is_continuable(model, auc):
                print(f'validation: best auc: {early_stopper.best_accuracy}')
                break
        auc = tst(model, test_data_loader, device, vec, batch_size)
        print(f'test auc: {auc}')
    
    
    if __name__ == '__main__':
        # import argparse
        #
        # parser = argparse.ArgumentParser()
        # parser.add_argument('--dataset_name', default='criteo')
        # parser.add_argument('--dataset_path', help='criteo/train.txt, avazu/train, or ml-1m/ratings.dat')
        # parser.add_argument('--model_name', default='afi')
        # parser.add_argument('--epoch', type=int, default=100)
        # parser.add_argument('--learning_rate', type=float, default=0.001)
        # parser.add_argument('--batch_size', type=int, default=2048)
        # parser.add_argument('--weight_decay', type=float, default=1e-6)
        # parser.add_argument('--device', default='cuda:0')
        # parser.add_argument('--save_dir', default='chkpt')
        # args = parser.parse_args()
        # main(args.dataset_name,
        #      args.dataset_path,
        #      args.model_name,
        #      args.epoch,
        #      args.learning_rate,
        #      args.batch_size,
        #      args.weight_decay,
        #      args.device,
        #      args.save_dir)
        main("zhihu",
             "zhihu/zhihu.txt",
             "dfm",
             3,
             0.001,
             2,
             1e-6,
             "cpu",
             "chkpt",
             "q_title_words")

    2. Dataset

        When handling the categorical variables, adding values that occur only rarely to the index dictionary would hurt what the model can learn, so while building the dictionary we filter out low-frequency values; the threshold is already set (min_threshold=10 below).

        What we end up with is an index for every value in each column. A toy sketch of this mapping follows.
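
    The following is a minimal sketch of the same counting/filtering logic as __get_feat_mapper below, reduced to a single toy column (the values are hypothetical):

    from collections import defaultdict

    min_threshold = 2
    column = ["a", "a", "b", "b", "b", "c"]  # one categorical column
    cnt = defaultdict(int)
    for v in column:
        cnt[v] += 1
    # keep only values seen at least min_threshold times, then index them
    kept = {feat for feat, c in cnt.items() if c >= min_threshold}
    mapper = {feat: idx for idx, feat in enumerate(kept)}
    default = len(mapper)                    # shared index for every rare value
    print(mapper.get("c", default))          # "c" appears once -> mapped to 2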

    import math
    import shutil
    import struct
    from collections import defaultdict
    from functools import lru_cache
    from pathlib import Path
    
    import lmdb
    import numpy as np
    import torch.utils.data
    from tqdm import tqdm
    import pandas as pd
    
    
    class ZhihuDataset(torch.utils.data.Dataset):
        """
        Criteo Display Advertising Challenge Dataset
    
        Data prepration:
            * Remove the infrequent features (appearing in less than threshold instances) and treat them as a single feature
            * Discretize numerical values by log2 transformation which is proposed by the winner of Criteo Competition
    
        :param dataset_path: criteo train.txt path.
        :param cache_path: lmdb cache path.
        :param rebuild_cache: If True, lmdb cache is refreshed.
        :param min_threshold: infrequent feature threshold.
    
        Reference:
            https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset
            https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf
        """
    
        def __init__(self, text_columns, dataset_path=None, cache_path='.zhihu', rebuild_cache=False, min_threshold=10):
            # zhihu.txt is written without a header row (see section 7), so pass
            # header=None to avoid swallowing the first sample as column names
            self.text = pd.read_csv(dataset_path, sep="\t", header=None)
            self.ALL = 25
            self.NUM_FEATS = 23
            self.NUM_INT_FEATS = 7
            self.min_threshold = min_threshold
            if rebuild_cache or not Path(cache_path).exists():
                shutil.rmtree(cache_path, ignore_errors=True)
                if dataset_path is None:
                    raise ValueError('create cache: failed: dataset_path is None')
                self.__build_cache(dataset_path, cache_path)
            self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True)
            with self.env.begin(write=False) as txn:
                self.length = txn.stat()['entries'] - 1
                self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32)
    
        def __getitem__(self, index):
            with self.env.begin(write=False) as txn:
                np_array = np.frombuffer(
                    txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(dtype=np.int64)
            x = np_array[1:]  # feature value indices
            y = np_array[0]   # label
            try:
                # column 24 holds the comma-separated title words
                _text = self.text.iloc[index, 24]
            except Exception:
                _text = "-1"  # sentinel for a missing title, zero-filled in pad_sequences
            return x, y, _text
    
    
        def __len__(self):
            return self.length
    
        def __build_cache(self, path, cache_path):
            feat_mapper, defaults = self.__get_feat_mapper(path)
            with lmdb.open(cache_path, map_size=int(1e10)) as env:
                field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32)
                for i, fm in feat_mapper.items():
                    field_dims[i - 1] = len(fm) + 1
                with env.begin(write=True) as txn:
                    txn.put(b'field_dims', field_dims.tobytes())
                for buffer in self.__yield_buffer(path, feat_mapper, defaults):
                    with env.begin(write=True) as txn:
                        for key, value in buffer:
                            txn.put(key, value)
    
        def __get_feat_mapper(self, path):
            feat_cnts = defaultdict(lambda: defaultdict(int))
            with open(path) as f:
                pbar = tqdm(f, mininterval=1, smoothing=0.1)
                pbar.set_description('Create criteo dataset cache: counting features')
                for line in pbar:
                    values = line.rstrip('\n').split('\t')
                    len_v = len(values)
                    if len_v != self.ALL:
                        continue
                    for i in range(1, self.NUM_INT_FEATS + 1):
                        feat_cnts[i][convert_numeric_feature(values[i])] += 1
                    for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
                        feat_cnts[i][values[i]] += 1
            feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()}
            feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()}
            defaults = {i: len(cnt) for i, cnt in feat_mapper.items()}
            return feat_mapper, defaults
    
        def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)):
            item_idx = 0
            buffer = list()
            with open(path) as f:
                pbar = tqdm(f, mininterval=1, smoothing=0.1)
                pbar.set_description('Create criteo dataset cache: setup lmdb')
                for line in pbar:
                    values = line.rstrip('\n').split('\t')
                    len_v = len(values)
                    if len_v != self.ALL:
                        continue
                    np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32)
                    np_array[0] = int(values[0])
                    for i in range(1, self.NUM_INT_FEATS + 1):
                        np_array[i] = feat_mapper[i].get(convert_numeric_feature(values[i]), defaults[i])
                    for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
                        np_array[i] = feat_mapper[i].get(values[i], defaults[i])
                    buffer.append((struct.pack('>I', item_idx), np_array.tobytes()))
                    item_idx += 1
                    if item_idx % buffer_size == 0:
                        yield buffer
                        buffer.clear()
                yield buffer
    
    
    
    @lru_cache(maxsize=None)
    def convert_numeric_feature(val: str):
        if val == '':
            return 'NULL'
        v = float(val)
        if v > 2:
            return str(int(math.log(v) ** 2))
        else:
            return str(v - 2)
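
    Despite the docstring's "log2 transformation", math.log above is the natural logarithm: values above 2 are bucketed as int(log(v) ** 2), values at or below 2 are shifted by -2, and empty fields become 'NULL'. A few sample mappings:

    print(convert_numeric_feature(''))     # 'NULL'
    print(convert_numeric_feature('1.5'))  # '-0.5'  (v - 2 for v <= 2)
    print(convert_numeric_feature('293'))  # '32'    (int(math.log(293) ** 2) = int(5.68 ** 2))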

    3. DeepFM

    import torch
    
    from torchfm.layer import FactorizationMachine, FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron
    
    
    class DeepFactorizationMachineModel(torch.nn.Module):
        """
        A pytorch implementation of DeepFM.
    
        Reference:
            H Guo, et al. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, 2017.
        """
    
        def __init__(self, field_dims, embed_dim, mlp_dims, dropout):
            super().__init__()
            self.linear = FeaturesLinear(field_dims)
            self.fm = FactorizationMachine(reduce_sum=True)
    
            self.embedding = FeaturesEmbedding(field_dims, embed_dim)
            # flattened embedding size: num_fields * embed_dim
            # (2 * 16 = 32 on MovieLens, 23 * 16 = 368 on this dataset)
            self.embed_output_dim = len(field_dims) * embed_dim
    
            self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout)
    
        def forward(self, x, text):
            """
            :param x: Long tensor of size ``(batch_size, num_fields)``
            :param text: Float array of size ``(batch_size, 20, 64)`` with padded word vectors
            """
            # (batch_size, num_fields, embed_dim)
            embed_x = self.embedding(x)
            # first-order linear term + second-order FM term + text-aware deep term
            linear_term = self.linear(x)
            fm_term = self.fm(embed_x)
            deep_term = self.mlp(embed_x.view(-1, self.embed_output_dim), text)
            out = linear_term + fm_term + deep_term
            return torch.sigmoid(out.squeeze(1))
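
    A minimal forward-shape sketch under this post's configuration (23 feature fields, embed_dim=16, so the flattened embedding is 23 * 16 = 368, plus the (batch, 20, 64) text array fed to the MLP). The field_dims values are the ones this dataset produces, as printed in the annotated code at the end of the post:

    import numpy as np
    import torch

    field_dims = [17, 25, 15, 44, 32, 12, 44, 4, 6, 3, 3, 3,
                  3, 3, 94, 18, 42, 109, 3, 12, 5, 6, 2]
    model = DeepFactorizationMachineModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)
    fields = torch.zeros(2, len(field_dims), dtype=torch.long)  # batch of 2, one id per field
    text = np.zeros((2, 20, 64), dtype=np.float32)              # padded title word vectors
    print(model(fields, text).shape)                            # torch.Size([2])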

    4. Layers

    import numpy as np
    import torch
    import torch.nn.functional as F
    from torch import nn
    
    
    class FeaturesLinear(torch.nn.Module):

        def __init__(self, field_dims, output_dim=1):
            super().__init__()
            # one shared weight table of size (sum(field_dims), output_dim),
            # e.g. (6040 + 3952, 1) = (9992, 1) for MovieLens-1M
            self.fc = torch.nn.Embedding(sum(field_dims), output_dim)
            self.bias = torch.nn.Parameter(torch.zeros((output_dim,)))
            # per-field offsets into the shared table, e.g. [0, 6040]
            self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)

        def forward(self, x):
            """
            :param x: Long tensor of size ``(batch_size, num_fields)``
            """
            # shift each field's raw index by its offset
            # (e.g. add 6040 to the second column's ids)
            x = x + x.new_tensor(self.offsets).unsqueeze(0)
            # sum the per-field weights: each sample reduces to a single scalar
            return torch.sum(self.fc(x), dim=1) + self.bias
    
    
    class FeaturesEmbedding(torch.nn.Module):

        def __init__(self, field_dims, embed_dim):
            super().__init__()
            # one shared embedding table, e.g. (6040 + 3952, 16) for MovieLens-1M
            self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
            # per-field offsets, e.g. [0, 6040]
            self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
            torch.nn.init.xavier_uniform_(self.embedding.weight.data)

        def forward(self, x):
            """
            :param x: Long tensor of size ``(batch_size, num_fields)``
            """
            # shift each field's raw index by its offset
            x = x + x.new_tensor(self.offsets).unsqueeze(0)
            # each user/movie id gets its own embed_dim-dimensional vector
            return self.embedding(x)
    
    
    class FieldAwareFactorizationMachine(torch.nn.Module):
    
        def __init__(self, field_dims, embed_dim):
            super().__init__()
            self.num_fields = len(field_dims)
            self.embeddings = torch.nn.ModuleList([
                torch.nn.Embedding(sum(field_dims), embed_dim) for _ in range(self.num_fields)
            ])
            self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
            for embedding in self.embeddings:
                torch.nn.init.xavier_uniform_(embedding.weight.data)
    
        def forward(self, x):
            """
            :param x: Long tensor of size ``(batch_size, num_fields)``
            """
            x = x + x.new_tensor(self.offsets).unsqueeze(0)
            xs = [self.embeddings[i](x) for i in range(self.num_fields)]
            ix = list()
            for i in range(self.num_fields - 1):
                for j in range(i + 1, self.num_fields):
                    ix.append(xs[j][:, i] * xs[i][:, j])
            ix = torch.stack(ix, dim=1)
            return ix
    
    
    class FactorizationMachine(torch.nn.Module):
    
        def __init__(self, reduce_sum=True):
            super().__init__()
            self.reduce_sum = reduce_sum
    
        def forward(self, x):
            """
            :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
            """
            # Pairwise interactions via the identity
            # sum_{i<j} x_i * x_j = 0.5 * ((sum_i x_i)^2 - sum_i x_i^2), per embedding dim
            square_of_sum = torch.sum(x, dim=1) ** 2  # (batch_size, embed_dim)
            sum_of_square = torch.sum(x ** 2, dim=1)  # (batch_size, embed_dim)
            ix = square_of_sum - sum_of_square
            if self.reduce_sum:
                ix = torch.sum(ix, dim=1, keepdim=True)  # (batch_size, 1)
            return 0.5 * ix
    
    
    class MultiLayerPerceptron(torch.nn.Module):
        """Text-aware MLP: the padded word vectors go through an LSTM, the last
        layer's outputs are flattened and concatenated with the flattened field
        embeddings before the fully connected stack."""

        def __init__(self, input_dim, embed_dims, dropout, output_layer=True):
            super().__init__()
            # text branch: per sample (20, 64) -> LSTM -> (20, 10) -> flattened to 200
            self.lstm = nn.LSTM(input_size=64, hidden_size=10, num_layers=4, batch_first=True)
            # input_dim is the flattened field embedding (23 * 16 = 368 here);
            # adding the 20 * 10 = 200 LSTM features gives 568 in this setup.
            # embed_dims / dropout / output_layer are kept for torchfm
            # compatibility, but this text variant uses a fixed stack instead.
            self.mlp = nn.Sequential(
                nn.Linear(input_dim + 20 * 10, 200),
                nn.ReLU(inplace=True),
                nn.Linear(200, 200),
                nn.ReLU(inplace=True),
                nn.Linear(200, 1),
                nn.ReLU(inplace=True)
            )

        def forward(self, x, text):
            """
            :param x: Float tensor of size ``(batch_size, input_dim)``
            :param text: Float array of size ``(batch_size, 20, 64)``
            """
            text = torch.from_numpy(np.float32(text)).to(x.device)
            out1, (h1, c1) = self.lstm(text)                    # (batch_size, 20, 10)
            out = out1.contiguous().view(x.shape[0], 20 * 10)   # (batch_size, 200)
            b = torch.cat([x, out], dim=1)                      # (batch_size, input_dim + 200)
            return self.mlp(b)
    
    
    class InnerProductNetwork(torch.nn.Module):
    
        def forward(self, x):
            """
            :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
            """
            num_fields = x.shape[1]
            row, col = list(), list()
            for i in range(num_fields - 1):
                for j in range(i + 1, num_fields):
                    row.append(i), col.append(j)
            return torch.sum(x[:, row] * x[:, col], dim=2)
    
    
    class OuterProductNetwork(torch.nn.Module):
    
        def __init__(self, num_fields, embed_dim, kernel_type='mat'):
            super().__init__()
            num_ix = num_fields * (num_fields - 1) // 2
            if kernel_type == 'mat':
                kernel_shape = embed_dim, num_ix, embed_dim
            elif kernel_type == 'vec':
                kernel_shape = num_ix, embed_dim
            elif kernel_type == 'num':
                kernel_shape = num_ix, 1
            else:
                raise ValueError('unknown kernel type: ' + kernel_type)
            self.kernel_type = kernel_type
            self.kernel = torch.nn.Parameter(torch.zeros(kernel_shape))
            torch.nn.init.xavier_uniform_(self.kernel.data)
    
        def forward(self, x):
            """
            :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
            """
            num_fields = x.shape[1]
            row, col = list(), list()
            for i in range(num_fields - 1):
                for j in range(i + 1, num_fields):
                    row.append(i), col.append(j)
            p, q = x[:, row], x[:, col]
            if self.kernel_type == 'mat':
                kp = torch.sum(p.unsqueeze(1) * self.kernel, dim=-1).permute(0, 2, 1)
                return torch.sum(kp * q, -1)
            else:
                return torch.sum(p * q * self.kernel.unsqueeze(0), -1)
    
    
    class CrossNetwork(torch.nn.Module):
    
        def __init__(self, input_dim, num_layers):
            super().__init__()
            self.num_layers = num_layers
            self.w = torch.nn.ModuleList([
                torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)
            ])
            self.b = torch.nn.ParameterList([
                torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)
            ])
    
        def forward(self, x):
            """
            :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
            """
            x0 = x
            for i in range(self.num_layers):
                xw = self.w[i](x)
                x = x0 * xw + self.b[i] + x
            return x
    
    
    class AttentionalFactorizationMachine(torch.nn.Module):
    
        def __init__(self, embed_dim, attn_size, dropouts):
            super().__init__()
            self.attention = torch.nn.Linear(embed_dim, attn_size)
            self.projection = torch.nn.Linear(attn_size, 1)
            self.fc = torch.nn.Linear(embed_dim, 1)
            self.dropouts = dropouts
    
        def forward(self, x):
            """
            :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
            """
            num_fields = x.shape[1]
            row, col = list(), list()
            for i in range(num_fields - 1):
                for j in range(i + 1, num_fields):
                    row.append(i), col.append(j)
            p, q = x[:, row], x[:, col]
            inner_product = p * q
            attn_scores = F.relu(self.attention(inner_product))
            attn_scores = F.softmax(self.projection(attn_scores), dim=1)
            attn_scores = F.dropout(attn_scores, p=self.dropouts[0], training=self.training)
            attn_output = torch.sum(attn_scores * inner_product, dim=1)
            attn_output = F.dropout(attn_output, p=self.dropouts[1], training=self.training)
            return self.fc(attn_output)
    
    
    class CompressedInteractionNetwork(torch.nn.Module):
    
        def __init__(self, input_dim, cross_layer_sizes, split_half=True):
            super().__init__()
            self.num_layers = len(cross_layer_sizes)
            self.split_half = split_half
            self.conv_layers = torch.nn.ModuleList()
            prev_dim, fc_input_dim = input_dim, 0
            for i in range(self.num_layers):
                cross_layer_size = cross_layer_sizes[i]
                self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1,
                                                        stride=1, dilation=1, bias=True))
                if self.split_half and i != self.num_layers - 1:
                    cross_layer_size //= 2
                prev_dim = cross_layer_size
                fc_input_dim += prev_dim
            self.fc = torch.nn.Linear(fc_input_dim, 1)
    
        def forward(self, x):
            """
            :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
            """
            xs = list()
            x0, h = x.unsqueeze(2), x
            for i in range(self.num_layers):
                x = x0 * h.unsqueeze(1)
                batch_size, f0_dim, fin_dim, embed_dim = x.shape
                x = x.view(batch_size, f0_dim * fin_dim, embed_dim)
                x = F.relu(self.conv_layers[i](x))
                if self.split_half and i != self.num_layers - 1:
                    x, h = torch.split(x, x.shape[1] // 2, dim=1)
                else:
                    h = x
                xs.append(x)
            return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
    
    
    class AnovaKernel(torch.nn.Module):
    
        def __init__(self, order, reduce_sum=True):
            super().__init__()
            self.order = order
            self.reduce_sum = reduce_sum
    
        def forward(self, x):
            """
            :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
            """
            batch_size, num_fields, embed_dim = x.shape
            a_prev = torch.ones((batch_size, num_fields + 1, embed_dim), dtype=torch.float).to(x.device)
            for t in range(self.order):
                a = torch.zeros((batch_size, num_fields + 1, embed_dim), dtype=torch.float).to(x.device)
                a[:, t+1:, :] += x[:, t:, :] * a_prev[:, t:-1, :]
                a = torch.cumsum(a, dim=1)
                a_prev = a
            if self.reduce_sum:
                return torch.sum(a[:, -1, :], dim=-1, keepdim=True)
            else:
                return a[:, -1, :]
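
    The FactorizationMachine layer above relies on the identity sum_{i<j} x_i * x_j = 0.5 * ((sum_i x_i) ** 2 - sum_i x_i ** 2), applied per embedding dimension. A brute-force check against explicit pairwise products:

    import torch

    x = torch.randn(3, 5, 16)                       # (batch, num_fields, embed_dim)
    fm = FactorizationMachine(reduce_sum=True)
    brute = torch.zeros(3, 1)
    for i in range(5):
        for j in range(i + 1, 5):
            brute += torch.sum(x[:, i] * x[:, j], dim=1, keepdim=True)
    print(torch.allclose(fm(x), brute, atol=1e-5))  # True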

    5. Project layout

    6. Sample data

    0    293    24    15    0    0    4    1    unknown    daily    1    0    0    0    0    MD470265    BR470265    PV002320    CT300532    PF470265    10    0    2    0    W81677,W3677,W3045,W4539
    0    582    20    14    20    3    5    0    male    monthly    0    1    0    1    0    MD470265    BR470265    PV419710    CT884433    PF470265    0    0    4    0    W624,W317,W1467,W243,W614
    0    351    16    34    66    20    9    1    unknown    daily    0    1    0    0    0    MD116493    BR641329    PV419710    CT378940    PF470265    10    0    2    0    W321,W3379,W1239,W5012,W321,W118,W6112,W19051,W928
    0    329    9    17    0    0    3    26    female    daily    1    0    0    1    0    MD470265    BR470265    PV789063    CT718720    PF470265    10    1    2    0    W3931,W1233,W188
    0    296    17    25    68    18    6    3    female    unknown    0    0    0    0    0    MD356265    BR896654    PV596305    CT956718    PF470265    1    1    3    0    W4867,W149,W18289,W8912,W600,W272
    0    630    28    11    30    10    4    0    male    daily    1    0    0    0    0    MD470265    BR470265    PV625797    CT470265    PF470265    10    0    4    0    W7975,W16005,W5975,W40542
    0    500    52    19    0    0    6    1    male    daily    0    1    0    1    0    MD596647    BR641329    PV921755    CT470265    PF470265    10    0    2    0    W206,W2106,W1884,W243,W166,W145
    0    324    30    24    0    0    11    2    unknown    daily    1    0    0    0    0    MD470265    BR470265    PV008215    CT470265    PF470265    10    0    2    0    W3046,W221,W2011,W5640,W3037,W796,W1887,W386,W200,W1290,W1061
    1    441    5    15    7    2    5    1    male    daily    1    0    0    1    0    MD079507    BR803759    PV596305    CT470265    PF470265    10    0    5    0    W4692,W7395,W17695,W3280,W2323
    0    725    100    27    0    0    9    5    male    daily    1    0    0    1    0    MD470265    BR470265    PV545833    CT545833    PF470265    10    0    3    0    W6249,W1307,W8223,W2720,W997,W390135,W16165,W628,W1026
    0    419    6    8    0    0    3    0    male    weekly    1    0    0    0    0    MD283943    BR800099    PV261182    CT296109    PF470265    10    0    3    0    W14887,W41914,W661
    0    318    0    21    137    52    8    4    unknown    daily    1    0    0    0    0    MD180416    BR214617    PV426364    CT449425    PF470265    10    0    4    0    W8618,W4632,W3037,W78640,W164,W1528,W27911,W4745
    1    427    29    21    0    0    5    5    unknown    daily    1    0    0    0    0    MD470265    BR470265    PV625797    CT383873    PF470265    10    0    3    0    W17996,W148421,W2642,W97,W1748
    0    329    11    11    118    26    2    0    unknown    weekly    0    1    0    0    0    MD634010    BR641329    PV596305    CT470265    PF470265    0    0    3    0    W5074,W4418
    0    593    49    18    86    21    4    1    unknown    daily    1    0    0    1    0    MD470265    BR470265    PV419710    CT378940    PF072986    10    0    1    0    W5516,W605,W149,W263
    0    368    1    20    0    0    9    1    female    unknown    0    0    0    0    0    MD665915    BR896654    PV280311    CT470265    PF470265    0    0    3    0    W11058,W19321,W54281,W2323,W48556,W6805,W3286,W49355,W272
    0    260    0    23    0    0    8    15    unknown    daily    1    0    0    0    0    MD577600    BR896654    PV625797    CT022556    PF470265    10    0    2    0    W973,W145,W11626,W118,W711,W10364,W3694,W2105
    1    479    2    36    93    29    11    4    male    daily    1    0    0    1    0    MD470265    BR470265    PV929066    CT929066    PF470265    10    0    4    0    W1114,W1462,W26500,W57,W33105,W2767,W6912,W13276,W7051,W26500,W4937
    1    506    65    19    80    19    4    0    unknown    daily    1    0    0    1    0    MD470265    BR470265    PV426364    CT470265    PF470265    10    0    3    0    W9887,W151,W973,W164
    1    395    0    8    0    0    3    1    unknown    daily    0    1    0    0    0    MD782190    BR641329    PV587951    CT587951    PF470265    10    0    3    0    W3647,W242,W1136
    1    585    58    16    72    24    3    5    male    daily    1    0    0    1    0    MD281170    BR800099    PV421141    CT965658    PF470265    10    0    2    0    W856,W1886,W18221
    0    606    100    30    0    0    10    6    male    weekly    0    1    0    1    0    MD661088    BR641329    PV545833    CT545833    PF470265    2    0    3    0    W51,W5056,W8565,W685,W57,W1209,W57,W26039,W463,W360
    0    463    42    17    40    17    6    0    female    weekly    0    0    1    1    0    MD470265    BR470265    PV462839    CT470265    PF470265    10    0    5    0    W10442,W1277,W1218,W2508,W10810,W1772
    0    325    9    31    0    0    10    1    unknown    monthly    0    1    0    0    0    MD470265    BR470265    PV625797    CT692736    PF470265    1    0    4    0    W3828,W6881,W272,W156789,W2323,W3370,W2928,W605,W381,W622
    0    527    5    15    0    0    3    95    female    weekly    0    1    0    1    0    MD566292    BR641329    PV462839    CT470265    PF470265    10    0    4    0    W3413,W4572,W54239
    1    361    7    32    0    0    10    1    unknown    weekly    0    1    0    0    0    MD470265    BR470265    PV596305    CT923918    PF470265    10    0    3    0    W23051,W4876,W730,W734,W117376,W1238,W316,W609,W896,W243
    1    351    0    12    74    14    4    16
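
    Each tab-separated row follows the layout assembled by the preprocessing script in section 7: column 0 is the label, columns 1-7 the seven numerical features, columns 8-23 the sixteen categorical features, and column 24 the comma-separated title words. Reading it back with pandas (header=None, since zhihu.txt is written without a header row):

    import pandas as pd

    df = pd.read_csv("zhihu/zhihu.txt", sep="\t", header=None)
    print(df.iloc[0, 0])   # label of the first sample, e.g. 0
    print(df.iloc[0, 24])  # e.g. 'W81677,W3677,W3045,W4539'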

    7. Data preprocessing

    import torch
    import pandas as pd
    import numpy as np
    # df = pd.read_csv('zhihutest.csv', sep="\t")
    # # categorical features (16)
    # fixlen_category_columns = ['m_sex', 'm_access_frequencies', 'm_twoA', 'm_twoB', 'm_twoC',
    #                            'm_twoD', 'm_twoE', 'm_categoryA', 'm_categoryB', 'm_categoryC',
    #                            'm_categoryD', 'm_categoryE', 'm_num_interest_topic', 'num_topic_attention_intersection',
    #                            'q_num_topic_words',
    #                            'num_topic_interest_intersection'
    #                          ]
    # # numerical features (7)
    # fixlen_number_columns = ['m_salt_score', 'm_num_atten_topic', 'q_num_title_chars_words',
    #                          'q_num_desc_chars_words', 'q_num_desc_words', 'q_num_title_words',
    #                          'days_to_invite'
    #                         ]
    # target = ['label']
    # text = ["q_title_words"]
    # # total columns = 25
    # # numerical columns: 7
    # # numerical + categorical = 23
    # cols = target + fixlen_number_columns + fixlen_category_columns + text
    # fout = df[cols]
    # print(fout)
    # fout.to_csv("zhihu.txt", mode='a', header=False, index=False,  sep='\t')
    # df = pd.read_csv('zhihu.txt', sep="\t", usecols=[24])
    # print(df.iloc[1])
    # Smoke test for the text branch: batch of 2, 20 time steps, 64-d word vectors
    a = np.random.rand(2, 20, 64)
    b = np.float32(a)
    lstm = torch.nn.LSTM(input_size=64, hidden_size=10, num_layers=4, batch_first=True)
    out1, (h1, c1) = lstm(torch.from_numpy(b))
    print(out1.size())  # torch.Size([2, 20, 10]) -> flattened to 200 features per sample

    8. Word vectors

    W1    0.12561196 -0.57268924 -0.14478925 -0.05249426 -0.036886618 -0.2870279 0.02139771 -0.30615488 -0.25206482 -0.3992112 -0.17992422 0.028214198 0.11070523 -0.1460403 0.04890066 -0.15179704 0.21902226 0.11398847 0.17869796 -0.37419954 -0.017541448 -0.44878265 0.16766284 -0.27377427 -0.28875342 -0.045036126 0.31909382 -0.25074694 -0.47739008 0.23619196 -0.22078764 -0.0702696 0.06263166 -0.20763765 0.09704907 -0.00838474 -0.1655464 0.20232989 0.056676675 -0.22352925 -0.097002655 0.008474487 0.12560087 -0.14395201 -0.12497431 0.04856251 -0.43297425 -0.03700408 0.02355051 -0.19106597 -0.29243338 -0.15640189 0.17767948 0.11662535 0.18274452 0.5316436 -0.2775053 -0.33759427 -0.47088715 -0.33324385 -0.10586255 0.15533929 -0.3057963 0.068118274
    W2    3.224765 2.2482696 -0.511986 -0.5329892 -0.94346815 0.7848761 -1.7466401 3.4368124 -1.3447977 1.507573 1.763892 0.04300465 -0.53149086 -0.65491444 -0.2475965 -0.7804506 -3.8735473 -1.604378 -2.8054268 1.866276 -1.2134124 -0.30723703 -1.6847026 0.97294986 0.16995417 1.5299996 1.7447681 -0.9420429 -0.18870392 2.814148 2.876312 0.4483537 -1.5696942 -1.5657848 0.30403557 -3.5486064 1.671584 1.8766971 -0.17914975 -1.1457291 -2.621439 1.3077521 0.5922553 -1.5568334 0.43938446 1.4022862 -0.52703506 -0.30781657 0.8833728 1.6657466 1.6585606 0.9285377 0.6063592 1.0806038 2.1549644 5.8111606 -1.8977596 -3.3637497 -2.6063447 0.18249024 4.9076834 0.36147368 1.0398824 0.9631124
    W3    -0.985937 0.11307016 0.012898494 -0.6822068 -0.7477715 -1.5083971 0.18873732 0.8862959 -0.13333559 0.44745678 -1.6474508 0.54304326 0.1943109 -1.5173006 0.052836645 0.6476899 -0.76287407 -1.7811123 1.4354367 0.9168619 -1.1915683 0.5675262 1.8146971 0.77335197 1.4627154 -1.3452206 -0.2794901 0.58791244 0.30307648 -0.33495015 -0.22242078 -0.8549375 0.8336274 0.508776 0.22765012 -0.2259076 1.1074381 -0.5277541 -1.149053 0.26546764 0.10992061 -1.0641412 -1.0234233 0.03276292 -0.09635724 0.6330204 0.2264861 1.5840359 -0.8735956 -0.7754864 0.58081776 -1.372525 0.4713049 0.058009893 -0.48929968 -1.101398 1.3791919 -0.4130507 0.2744553 0.5020652 1.0926367 0.5763596 -0.5411724 -0.014770236
    W4    -0.3367663 0.039051324 0.8155926 0.8351733 -0.3697751 0.9960453 -0.13953939 -0.049142037 -0.45639792 0.4392651 0.6774147 0.9220153 1.0198072 -0.91693723 0.13157529 -0.6226927 0.44958356 -0.7831065 -0.42610598 0.11303744 0.35572204 -0.87150127 -0.035929345 0.86534244 -0.5618806 -0.27772513 1.1201656 -0.461846 0.33991888 0.66242754 -0.30612975 0.4649827 -0.7512291 0.036252256 0.25191557 -0.75410306 -0.05073685 0.88312984 0.09898741 -0.19736929 -0.010180523 -0.16381735 0.51063424 -0.17956156 -0.052385904 -0.1254029 0.8307363 0.55609304 -0.8516192 -0.8239702 0.17293625 1.1451373 0.5673448 -0.04732473 1.1474786 -0.017807899 -0.7725623 -1.4019336 -0.12553573 0.4392915 -0.26784056 0.5315127 0.3335459 0.5802117
    W5    0.3074205 -1.0977745 0.7528213 0.6299011 0.1975374 0.12329541 -0.4517387 1.1735839 -0.29371184 -0.67508066 0.87006754 0.39526275 1.9724985 -0.38147122 0.15378407 -0.76870894 0.5208314 -0.7383249 0.06193004 0.14570965 0.14756544 -1.2027895 0.8386833 0.39646402 0.039275337 -0.25650737 2.2237034 0.13269994 -0.71340007 2.1263545 -0.20604254 0.6451037 0.4070587 -0.1542667 -0.14662863 -1.3520072 0.048581414 1.6555899 0.8380814 -0.3378566 -1.2679185 0.46683076 0.63580114 -0.6043699 0.28784633 -0.08640877 0.607381 0.9474251 -0.5116208 -0.5747555 -0.01787345 0.30581647 1.0210615 -0.55992776 0.29841763 0.6909648 -1.0312475 -1.847506 -1.0236931 -0.3247338 0.10282218 1.0663713 -0.519434 0.19721127
    W6    0.61494493 0.5444025 2.0673835 3.2731245 0.77909213 1.4477565 0.6443283 0.12583794 -1.2945272 -0.24951132 -0.7445469 -0.8159158 0.8766542 2.9739044 1.6059983 2.847475 0.15741627 -1.8566948 -0.77201325 -0.069421366 0.19065484 0.7409739 -2.4119422 -0.43610424 0.075945236 4.1809616 -2.794368 -0.6762555 0.6160988 -0.27621445 -2.6909072 1.6613462 -1.222234 1.379047 -0.8762335 -1.0866759 -1.227303 0.8094839 -0.991168 0.17862324 -0.8094492 1.3342611 0.28272808 0.56088215 0.5823514 0.19371018 2.4717267 0.40118575 0.6642596 0.70348346 -0.019687142 1.3274803 2.0931566 0.2689294 -1.8053719 1.6665238 0.62242544 1.0591005 -0.47918996 -4.0045266 -1.3526242 1.8648475 -0.0633559 0.12525558
    W7    -1.0922098 -2.088952 -1.9467407 -0.095274135 0.73766685 -1.3324981 -0.23863605 2.1223712 -0.5489143 0.10345685 -2.6016045 0.9703814 1.6614031 -0.15970723 -1.9562075 -1.6666706 1.0506083 -0.2557899 1.8368233 -0.067096 -1.8985215 -1.0996249 1.3290927 2.4161265 -0.47213927 0.3849108 0.0034949605 8.546595E-4 -2.08523 1.1498582 -0.7590559 -2.146992 1.9317857 0.49643072 1.058629 -0.46381965 -0.30364427 1.3321518 -0.63078755 -0.3178737 0.8979813 -0.28108698 -0.25992736 -1.065179 -1.1152307 -0.19449191 -0.67832875 2.0070937 0.9670678 -0.04272201 -0.90389216 -2.5703933 0.9788946 1.5174438 2.2033823 0.6190921 1.6634784 -0.84447604 -2.3429253 2.1273015 1.1511455 1.0568366 -2.179524 -0.12413034
    W8    1.8850589 -1.3846935 -1.3290175 -2.2212234 1.6148094 -0.21146151 1.894543 -0.024923552 -1.005133 -1.5330166 -0.9205202 -0.09463944 -0.98946047 -1.2322261 0.7949223 -1.8138707 0.96810377 1.449417 2.6712353 -0.51707065 0.72894305 0.8826398 -2.0346794 -1.7057327 -1.7577009 0.838245 1.0771643 0.6127623 -3.0707555 -2.7187696 0.5289711 0.2569929 2.7275715 1.4968799 -0.2452353 -0.6800362 -1.8762752 2.193181 0.4007349 0.011191922 -0.6259204 0.9192341 0.40032268 -0.9789072 -1.5940189 2.2916598 -1.561065 0.557704 0.6337721 -0.3848408 -1.9655528 0.67603004 -1.2684052 2.781325 1.8380938 1.5696958 -1.6227767 0.070153154 -2.0467916 -2.485381 0.30974072 1.0064175 -0.13441803 -1.329612
    W9    3.5464244 1.6036447 -2.5293458 3.1333735 -2.0442762 -2.9851632 1.4346856 -2.4663894 -0.025672281 0.08756648 0.6147293 0.023265101 -0.10919095 2.7192488 0.87204725 -0.2820551 -0.30785862 0.12090092 -3.5130498 -2.6789558 -0.5764017 -0.5846258 0.5859751 -0.35626814 -2.3668826 0.44950518 1.2879786 1.4998803 1.0923749 -0.00976043 -0.8428258 -1.0809498 1.7901529 0.2144969 -1.7267967 -0.090383604 -1.363677 -2.576643 -1.1217225 -2.468348 -1.0396751 0.53285617 1.3147327 0.81647974 -1.5917714 2.1906781 0.9483048 -2.4917877 1.8687235 0.27996638 -0.43738556 -0.095006086 -0.020946838 -1.6918502 -1.3026488 1.3057652 -0.9144045 2.3955936 -0.19458054 -3.215039 -3.495818 1.2492332 2.0508204 -0.35456043
    W10    -2.2662137 -1.9574718 -0.13302466 -4.3436356 1.7448403 -2.1677175 -0.26545736 2.2401462 2.2677085 -1.309996 -2.1214437 -1.177141 -2.4372704 -1.3198556 -1.1695244 1.168789 1.015884 -0.594806 -1.1740255 -0.57405066 0.7847314 1.2742159 -1.7266214 1.6874171 2.032907 -0.17446232 1.4693882 -1.7394896 -0.18483235 -0.53408825 1.0647908 -4.9963064 2.1754081 0.43848997 -1.308212 3.0940924 -0.8896101 0.23257956 -1.4059472 1.0995824 0.3212307 -1.3197386 0.8861918 0.43288517 0.55570036 2.0030572 -1.6898836 4.789943 2.012685 -0.87376535 1.3392564 -0.39256766 0.12622279 -1.9374942 1.4249638 -0.3851611 1.288376 -1.7455869 -2.9280279 0.11417221 0.39837018 -0.08797984 0.39745057 -1.1663383
    W11    -1.8233486 2.0506058 -3.2515428 -1.1880394 -1.3893188 1.0621682 -1.3315074 2.1461437 1.3952633 -1.6875589 -2.5420825 -1.9849988 -0.47683784 -1.8435638 1.514539 -0.80441093 -0.99173623 -1.4936146 1.9138579 0.7374752 1.9431698 -0.96729654 1.6452389 1.7321718 0.911836 0.8322483 -3.1768286 2.255713 3.3464186 1.1475248 -0.16955696 -1.7612045 -5.397299 -2.5172055 -0.2939885 -2.041203 1.5774808 -1.1870066 0.16685703 2.8498793 1.9320439 -1.8997726 2.4276707 -0.7750878 -3.3764153 -2.0251844 -0.87747693 -0.39764187 1.0188824 -0.039743427 1.9104943 -2.0641706 -1.1724603 0.58007103 1.0246912 0.8077267 1.3156842 -2.4531717 -2.1099017 -0.19247615 0.46387395 0.19127782 0.6958765 -0.5772965
    W12    0.48556754 -0.3782651 2.2997975 2.261273 0.4083756 3.128249 1.3454522 -1.7309117 0.4082361 0.42954805 -0.10541588 -0.0100395875 0.30382836 0.80302244 1.6591026 -0.93889457 1.3209151 0.3527157 -0.63149524 -0.72112906 1.0052388 -0.032466557 -3.2750657 -3.146142E-4 -2.0860028 1.3553876 -1.198162 -0.52659255 -0.22112496 0.06325757 -2.5704465 0.06943303 -0.43335226 1.0315912 0.69085115 -0.6642619 -1.1102613 -0.08351819 0.83566225 -1.1449703 0.8499442 -0.0996228 0.74162143 0.6656831 1.24156 -0.97076166 3.9931488 -1.1789719 -0.6533613 -0.4903609 -1.2098149 1.9589139 0.26404676 -0.37190273 -0.16028062 0.525885 -1.6114401 2.066881 0.21920817 -2.1344752 -2.06417 0.51205325 -0.110927485 0.29716808
    W13    0.19286776 0.97593105 1.1760665 1.6522682 -0.7669645 1.1909955 1.0924758 -2.266674 1.0367627 0.7968129 -0.28660434 0.4161866 0.7810446 -0.99575454 -0.16184656 -0.49383563 0.8447127 -0.16373028 -1.593914 1.0633731 1.0397971 -1.7984707 -0.14076096 0.11114727 0.5544285 0.63537514 1.7927763 -1.0619704 -0.440642 -0.23222017 -1.1853625 0.29143685 0.08879236 0.5184061 1.3055124 0.25131506 0.61567587 0.98425937 -0.18469898 -0.6938831 -1.7825798 -0.96643585 -1.0872706 0.27524996 0.6470144 1.2886372 1.4324383 1.8343577 -1.5068318 -1.7749896 -0.26814702 1.9943746 -0.58553463 -2.0208232 0.43590528 -0.6344104 -1.0268904 0.6773552 0.42696643 0.44570908 -1.4512285 -0.4421764 -0.16073075 0.9451965
    W14    -1.4975872 -0.05171344 -1.5837607 2.6378388 0.5145237 3.5511696 0.33359003 -2.8335233 2.8076153 0.21341299 -1.8630219 0.65512085 -1.1321408 -0.19878303 0.86941725 -1.0771896 0.3892032 -1.1617596 -1.9551601 -0.9577428 1.016048 -1.167049 0.64454824 -0.24170674 -1.0976492 -0.7169287 -1.5239555 0.2005307 -0.59923834 -1.1073819 -1.8658218 -1.3666234 2.4345636 1.9346207 0.28856465 2.6725085 0.41708738 -1.827212 0.38932738 1.1355253 -0.7327769 -1.5357171 -2.3008456 3.1287837 -0.61515254 0.3323807 0.57125115 1.2867411 -1.6832256 -1.9866133 0.033925824 3.0166345 1.7175115 0.07027325 3.2417045 -1.2321491 -0.91068125 2.1055794 0.013131445 1.0040492 -1.9852926 -1.4327652 0.023684088 -0.13692991
    W15    3.5057127 -0.043885194 -0.8190997 -2.2906077 -1.5311654 -0.499458 1.5530281 -1.6159873 -0.15570664 1.528915 0.84011084 1.2226365 -2.3276448 -0.24316585 0.9175231 -2.561437 -0.14275528 2.7427154 -1.5982548 -1.6611183 0.44843268 0.9835045 -1.5639778 -1.2193594 -1.581374 0.2868971 0.47565898 4.240041 -0.5649336 0.6045978 0.38314477 0.07715987 3.9825556 1.4035237 -2.9439008 0.23149264 -3.2482007 -2.0905168 1.8459865 1.058474 0.3007178 -0.0017837499 0.5709634 -0.9801643 0.7286936 5.197704 1.5633762 -2.095106 -0.590826 -0.5677418 -2.1619842 2.2878773 0.12194729 0.25743377 -2.256691 -0.7038811 -1.2787365 3.2720375 1.902351 -2.6107993 -2.0846295 -0.54492134 0.5879578 -0.5057746
    W16    -2.0452962 2.7790704 0.8599819 -4.8954825 -2.7227967 -2.4925845 -0.20474386 2.0819283 -0.18221791 -1.5121988 -0.3095648 1.3191587 -0.04623105 -1.705462 -3.7203848 2.1164687 -1.0359333 -5.023697 1.8302708 0.07236947 -1.2637256 -0.5877139 0.49410626 3.4786513 2.2011156 1.2362626 0.52787966 -1.7504953 -2.0405428 1.6479661 -0.78672725 0.94336957 0.67697954 0.5614243 0.7758589 0.67669255 -0.46426576 1.7613895 -0.90594846 -1.0265179 2.3652992 -1.7416269 -0.18174197 0.6209257 -3.5529306 -0.9391469 -1.575419 1.7552205 1.0178684 -1.1110094 2.0775204 -0.75388825 -2.1323745 2.0622365 0.96742326 0.10655046 -1.2731837 -0.94141895 -1.8334563 1.0716884 2.8153598 -0.43824026 1.9744577 -0.28706712
    W17    -1.7620399 4.528447 -1.7671988 0.5494214 2.3973732 0.02015205 -0.18269481 2.4271953 -1.1222895 -0.9618898 1.0856631 -1.5881574 2.8190672 2.087121 6.075676 0.6773515 -1.0749277 -0.014855039 -0.58664525 -3.5484717 -4.6698184 5.4795012 -4.5385756 2.3218167 0.035409868 -3.143663 -4.8714123 -3.0491023 0.84198755 -4.497052 0.9888961 4.2981386 2.3218338 -0.022358343 -4.0753236 0.13450176 1.5776345 -4.4275513 -1.3627177 0.7150768 -0.785076 -1.1661426 3.6368656 1.9779291 1.6386815 2.65138 -3.120982 0.9146857 3.7670047 3.9425747 2.5634406 -2.686382 -1.929005 -6.3999925 0.35940954 2.7331047 3.0736244 -2.3717353 4.7949843 -2.1107779 0.6746109 1.2574309 3.2318926 -2.500539
    W18    0.15473528 0.6167804 1.7167026 2.9121165 0.269503 0.10636123 3.4708831 -1.5382771 1.1992759 0.18168284 2.5019782 3.612141 3.9854047 -1.5882664 -1.5477096 -3.0422456 1.7823238 2.0636547 -1.4370937 1.3052343 5.6378145 -2.3541868 -3.6185682 -0.40783572 -3.478546 -1.3083613 1.7696565 -0.93460566 -2.4779005 -2.0160322 1.5108185 1.0226414 1.5649385 2.4810264 -1.0504055 1.5927712 -1.2125096 0.17816901 1.1041695 0.82321566 3.084892 0.049549546 -1.0349815 -0.20464003 4.185163 -0.27878857 1.4038125 0.021971399 4.4434958 -2.2899518 -0.72584987 2.5909247 -3.9081511 -0.26109782 0.0013557074 0.2290644 -2.0305889 2.272198 1.7540497 2.0196679 -0.9229598 1.0903741 1.6841339 0.24193175
    W19    -2.447269 -0.8793359 -1.1523393 -1.4507253 0.9732452 -1.527036 -2.315722 3.159439 -3.6261277 1.7409768 -4.9887567 1.2373376 -1.0411987 1.991606 1.9732735 1.0008464 0.69035655 0.6105124 1.1336687 0.1371391 -3.5243258 1.2731444 1.0261899 1.1641322 -1.2143499 0.8357258 -2.544022 0.4296619 1.7687589 -1.1008877 0.94163805 -1.4822719 -0.9027046 1.8346741 0.176702 -0.15226924 1.937083 -1.995543 -1.2816978 -1.1957624 1.4393616 -3.2401497 -0.22859575 1.2283462 -1.3554864 1.7207322 1.0686255 0.16442522 1.5106804 1.7412056 -2.0441098 -2.939095 3.0576096 0.05931877 -2.057031 0.87080085 2.4195716 1.8542806 -1.4897007 4.05886 1.440093 -1.5045505 -0.341551 -0.8826472
    W20    2.4317594 -2.5110888 -3.7020335 -0.59279317 -0.81826645 -0.73744965 0.26426792 -1.479243 3.2323906 0.13760854 -0.2138856 -1.0425788 0.3526553 0.1300069 0.71234524 -4.468805 1.0853753 -0.62058234 0.7129842 -0.8540014 -4.1106744 2.2727025 1.6877469 2.0950649 0.7688551 2.1442516 2.98514 -2.1409988 -5.0907016 1.5532246 -0.9199374 -3.2675993 0.17646134 1.6836916 4.1119466 -0.22145776 -3.7103643 4.236737 -2.9532087 -0.88752675 -0.49633628 1.5139341 -0.4011149 1.8202754 -1.7638682 -1.6365256 1.3082557 -2.4150834 0.96422815 0.587179 -1.5374656 0.86646986 0.7660016 -2.4612026 2.227851 1.3380407 -1.6148647 1.0061222 -3.6261077 1.2789214 -2.1962628 1.699905 -1.995816 -1.5530183
    W21    -0.2703056 1.8392719 -1.2054503 -0.48395538 -5.7971487 -2.6331599 4.0352526 -1.5096587 -2.3522365 0.10684795 1.370795 1.4191186 -3.4433208 0.49266464 -1.1942794 0.26438513 -0.9935388 4.7443166 -1.7586547 -0.5104129 -0.8527214 1.042357 -2.320489 -2.2102757 -0.86960524 -2.686497 1.2472658 4.5115643 5.1399703 0.69941145 -0.26024258 4.8025093 2.0294304 3.244914 -2.8054326 -0.4343817 -0.911383 -3.5767043 4.016076 3.488138 0.01705379 -0.58362496 0.18510656 -1.924577 3.2083511 3.7377274 1.9063605 -0.41186696 1.2022467 0.6376134 -0.7062932 -0.94675505 0.2987619 -0.059485022 -4.064135 -1.4685118 -0.85666335 4.5056853 2.71125 -2.231179 -2.8568563 -1.2534839 -0.81981933 -1.4769105
    W22    -0.2617614 -0.99235636 -1.3360858 0.012544698 3.7825391 0.36606872 0.3767669 2.223355 0.6381537 -1.3306204 -0.85466087 -0.96644515 0.9590631 2.5559342 2.4542904 -0.07758791 0.145989 -1.2972796 0.8034251 0.53936785 -1.5072265 2.3272536 -0.5694311 -0.10740515 1.7682558 -1.1184568 -0.77690285 0.20744331 -2.3669522 -0.7972892 -0.8343902 -0.040297605 3.9376059 0.5414876 -1.1917127 -0.90247643 -0.1319108 -0.9595053 0.08407973 1.3535333 -1.7921218 1.0143303 -0.6247334 -1.3810456 1.6846753 2.7751825 0.50540644 1.2734851 1.1276027 2.1919713 -0.6157536 -1.3058081 0.113103576 -2.5760565 -0.5319981 1.0121036 2.246896 0.9857944 1.3536208 0.036636755 -0.5948883 0.6924922 -1.4375504 -1.3249679
    W23    -0.9651987 1.489567 3.7892187 -3.7961268 -0.29648548 0.053300433 0.6702064 -2.208451 -2.7420046 -0.63721174 1.0220832 1.7955016 1.1505986 -0.70014924 0.21482033 -0.37031448 1.1962228 1.9465573 1.6602011 -3.2899258 -1.1198668 2.1969705 -4.3133745 0.3307541 -1.5535967 2.6950045 -0.45252553 2.102209 -0.023604078 1.0128834 -0.07702835 0.3234905 -1.1599221 -1.5770309 -1.531138 0.11932727 0.697383 -3.0229428 3.1187625 -1.9312394 -0.030920366 -0.30986938 -1.5411919 1.8363559 0.9686007 1.0706614 0.48427907 0.5915161 -1.5009121 -1.1401902 -0.18640403 -0.78009117 -1.8338839 1.1936028 -1.535172 -2.3523526 -2.109703 -0.8195621 -1.2982402 -1.0543036 2.4673889 -1.792173 1.935543 2.516964
    W24    -1.8251535 1.8876748 1.0971074 -1.4578272 1.3089255 -1.9490236 1.6183308 0.26262164 3.3819761 -2.1310601 -2.011434 2.4080157 0.0038511725 -0.26461595 -1.7176443 0.91868275 0.89175886 -0.14072378 0.12834296 0.3248617 2.096536 1.138677 -3.444761 1.2823188 0.6991723 0.1784932 0.9018187 -1.0168108 0.76319486 -2.3365304 -1.036325 -3.0408838 2.3307629 0.34960595 -2.049877 3.4054108 -2.1324139 -0.29900455 0.445328 3.2364163 2.3072093 -0.0030180116 -1.2624956 -0.3893402 -1.4439074 1.735823 1.4525877 5.1423 3.5513415 -1.6825833 -0.27627033 1.696878 -1.2332087 -0.10024076 -1.4734092 -0.9189101 0.16650778 0.6176878 -2.812203 -0.6697208 -1.8221841 0.029978639 -0.15918644 -0.34443417
    W25    -0.33015317 -0.38741946 1.6975459 -0.89117235 -1.5155373 2.9153888 0.8343981 -0.26450408 -1.1157738 -2.2361076 2.7365768 0.74073046 -0.5137639 -0.1309951 -0.24929214 -2.1488292 0.84538704 1.5118271 0.7902999 -1.2011646 0.7888036 1.3562698 -2.5850587 -0.29860246 -4.388407 1.7568411 -0.018577736 0.009535411 1.2861177 -0.39966074 1.4036655 2.6377625 1.0923429 1.3092021 -1.6747708 -0.14070179 -2.3365026 0.2586771 1.8804626 1.289512 2.9491673 1.5510907 -0.679748 0.31118867 0.8567329 -0.8637471 2.0013115 0.029997077 -0.42404655 -2.0511773 -1.9803513 0.5438882 -0.047828138 1.214576 -1.1743958 -1.8469224 0.20354868 2.088674 -0.10474008 -1.1028297 -0.40986 -1.366427 1.6552207 0.17932093
    W26    0.08060814 -3.5432673 1.0444154 3.1851668 -1.3206809 -1.4131405 1.3266301 -0.666004 -0.63391036 4.94799 0.493011 0.75151974 -2.107727 1.4825441 -1.1863682 -3.371569 -0.132423 0.9060172 -3.7868505 -1.2878495 5.8609447 -4.816618 2.8310385 -3.53249 0.042789314 -1.9698722 1.4370093 2.9523568 2.135083 1.5827739 1.0286837 0.96793324 2.2594564 3.6815593 -2.2796483 -2.6946282 -1.5791732 -2.1242776 -0.22565378 0.4873966 -0.5912663 -2.5352976 2.42151 0.27029833 4.5077214 4.1300635 -0.28895676 -1.4161577 -1.6664641 0.43432903 2.890621 2.5134094 -1.7249513 0.50292605 1.0906894 -0.15024044 -1.8207229 2.071012 1.4593502 2.6433182 -2.4981005 -0.048957057 1.1318222 -3.0151772
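
    Each line read by get_word_vector in section 1 is a token, a tab, then 64 space-separated floats; occasional scientific notation such as the 8.546595E-4 in W7 above is handled by float(). A quick sanity check over the file:

    with open("word_vector/word_vectors_64d.txt") as f:
        for line in f:
            token, values = line.rstrip("\n").split("\t")
            assert len(values.split(" ")) == 64, token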

    9. References

    1. https://rixwew.github.io/pytorch-fm/

    2. https://github.com/rixwew/pytorch-fm

    Annotated code

    import math
    import shutil
    import struct
    from collections import defaultdict
    from functools import lru_cache
    from pathlib import Path
    
    import lmdb
    import numpy as np
    import torch.utils.data
    from tqdm import tqdm
    import pandas as pd
    
    
    class ZhihuDataset(torch.utils.data.Dataset):
        """
        Criteo Display Advertising Challenge Dataset
    
        Data prepration:
            * Remove the infrequent features (appearing in less than threshold instances) and treat them as a single feature
            * Discretize numerical values by log2 transformation which is proposed by the winner of Criteo Competition
    
        :param dataset_path: criteo train.txt path.
        :param cache_path: lmdb cache path.
        :param rebuild_cache: If True, lmdb cache is refreshed.
        :param min_threshold: infrequent feature threshold.
    
        Reference:
            https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset
            https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf
        """
    
        def __init__(self, text_columns, dataset_path=None, cache_path='.zhihu', rebuild_cache=False, min_threshold=10):
            self.text = pd.read_csv(dataset_path, sep="\t")
            # 25 columns per row: label + 23 feature columns + 1 text column
            self.ALL = 25
            # 23 = total number of numeric + categorical feature columns
            self.NUM_FEATS = 23
            # 7 = number of numeric feature columns
            self.NUM_INT_FEATS = 7
            self.min_threshold = min_threshold
            if rebuild_cache or not Path(cache_path).exists():
                shutil.rmtree(cache_path, ignore_errors=True)
                if dataset_path is None:
                    raise ValueError('create cache: failed: dataset_path is None')
                self.__build_cache(dataset_path, cache_path)
            self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True)
            with self.env.begin(write=False) as txn:
                self.length = txn.stat()['entries'] - 1
                self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32)
    
        def __getitem__(self, index):
            with self.env.begin(write=False) as txn:
                np_array = np.frombuffer(
                    txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(np.int64)  # np.long was removed in recent NumPy
            x = np_array[1:]
            y = np_array[0]
            try:
                _text = self.text.iloc[index, 24]  # column 24 holds the raw comma-separated topic ids
            except IndexError:
                _text = "-1"  # "-1" is the "no text" marker understood by pad_sequences
            return x, y, _text
    
    
        def __len__(self):
            return self.length
    
        def __build_cache(self, path, cache_path):
            """Build the lmdb cache from the raw dataset at ``path``."""
            # feat_mapper: {column: {value: index, value2: index, ...}, ...}
            # defaults:    {column: number of distinct values, ...}
            feat_mapper, defaults = self.__get_feat_mapper(path)
            with lmdb.open(cache_path, map_size=int(1e10)) as env:
                field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32)
                # vocabulary size per field, e.g. [17 25 15 44 32 12 44 4 6 3 3 3 3 3 94 18 42 109 3 12 5 6 2]
                for i, fm in feat_mapper.items():
                    field_dims[i - 1] = len(fm) + 1  # +1 for the out-of-vocabulary index
                with env.begin(write=True) as txn:
                    txn.put(b'field_dims', field_dims.tobytes())
                for buffer in self.__yield_buffer(path, feat_mapper, defaults):
                    # buffer holds (packed row number, encoded feature-index row) pairs
                    with env.begin(write=True) as txn:
                        for key, value in buffer:
                            txn.put(key, value)
    
        def __get_feat_mapper(self, path):
            """Count feature values in the dataset at ``path`` and index the frequent ones."""
            feat_cnts = defaultdict(lambda: defaultdict(int))
            with open(path) as f:
                pbar = tqdm(f, mininterval=1, smoothing=0.1)
                pbar.set_description('Create zhihu dataset cache: counting features')
                for line in pbar:
                    values = line.rstrip('\n').split('\t')
                    # skip malformed rows that do not have exactly self.ALL columns
                    if len(values) != self.ALL:
                        continue
                    # column 0 is the label; columns 1..NUM_INT_FEATS are numeric and bucketized first
                    for i in range(1, self.NUM_INT_FEATS + 1):
                        feat_cnts[i][convert_numeric_feature(values[i])] += 1
                    # the remaining feature columns are categorical and counted as-is
                    for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
                        feat_cnts[i][values[i]] += 1
            # feat_cnts: {column: {value: count, ...}, ...}
            # keep only the values that appear at least min_threshold times
            feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()}
            # re-index the surviving values per column,
            # e.g. {1: {'32': 0, '37': 1, '40': 2, ...}, 2: {'19': 0, '12': 1, ...}, ...}
            feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()}
            # defaults: {column: number of distinct values}, used as the out-of-vocabulary index
            defaults = {i: len(cnt) for i, cnt in feat_mapper.items()}
            return feat_mapper, defaults
    
        def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)):
            """
            :param path: dataset path
            :param feat_mapper: {column: {value: index, value2: index, ...}, ...}
            :param defaults: {column: number of distinct values, ...}
            :param buffer_size: number of rows to accumulate before yielding
            :return: batches of (packed row number, encoded row bytes) pairs
            """
            item_idx = 0
            buffer = list()
            with open(path) as f:
                pbar = tqdm(f, mininterval=1, smoothing=0.1)
                pbar.set_description('Create zhihu dataset cache: setup lmdb')
                for line in pbar:
                    values = line.rstrip('\n').split('\t')
                    # skip malformed rows
                    if len(values) != self.ALL:
                        continue
                    # NUM_FEATS + 1 zeros: slot 0 for the label, slots 1..NUM_FEATS for feature indices
                    np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32)
                    np_array[0] = int(values[0])
                    for i in range(1, self.NUM_INT_FEATS + 1):
                        # bucketize the numeric value, then look up its index (unseen values fall back to defaults[i])
                        np_array[i] = feat_mapper[i].get(convert_numeric_feature(values[i]), defaults[i])
                    for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
                        # look up the categorical value's index, falling back to the OOV index
                        np_array[i] = feat_mapper[i].get(values[i], defaults[i])
                    buffer.append((struct.pack('>I', item_idx), np_array.tobytes()))  # (row number, encoded row)
                    item_idx += 1
                    if item_idx % buffer_size == 0:
                        # buffer full: hand it to the caller and start a new one
                        yield buffer
                        buffer.clear()
                yield buffer
    
    @lru_cache(maxsize=None)
    def convert_numeric_feature(val: str):
        """对每个数值特征进行先loge,再平方返回"""
        if val == '':
            return 'NULL'
        v = float(val)
        if v > 2:
            return str(int(math.log(v) ** 2))
        else:
            return str(v - 2)
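
    As a quick sanity check (a sketch, not from the original post), this is how convert_numeric_feature buckets a few raw strings:

    print(convert_numeric_feature(''))     # 'NULL'  (missing value marker)
    print(convert_numeric_feature('1'))    # '-1.0'  (v <= 2 is shifted by -2)
    print(convert_numeric_feature('100'))  # '21'    (int(ln(100) ** 2) = int(21.2...))

    The second revision below drops the per-item pandas lookup: it vectorizes every row's text while the cache is built, keeps the result in a dict keyed by row number, and persists it with np.save.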
    import math
    import shutil
    import struct
    from collections import defaultdict
    from functools import lru_cache
    from pathlib import Path
    
    import lmdb
    import numpy as np
    import torch.utils.data
    from tqdm import tqdm
    import pandas as pd
    
    
    class ZhihuDataset(torch.utils.data.Dataset):
        """
        Criteo Display Advertising Challenge Dataset
    
        Data prepration:
            * Remove the infrequent features (appearing in less than threshold instances) and treat them as a single feature
            * Discretize numerical values by log2 transformation which is proposed by the winner of Criteo Competition
    
        :param dataset_path: criteo train.txt path.
        :param cache_path: lmdb cache path.
        :param rebuild_cache: If True, lmdb cache is refreshed.
        :param min_threshold: infrequent feature threshold.
    
        Reference:
            https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset
            https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf
        """
    
        def __init__(self, text_columns, dataset_path=None, cache_path='.zhihu', rebuild_cache=False, min_threshold=10):
            self.word_vec = get_word_vector()
            self.text_vec = {}  # row number -> padded (40, 64) text matrix, filled during cache building
            # 25 columns per row: label + 23 feature columns + 1 text column
            self.ALL = 25
            # 23 = total number of numeric + categorical feature columns
            self.NUM_FEATS = 23
            # 7 = number of numeric feature columns
            self.NUM_INT_FEATS = 7
            self.min_threshold = min_threshold
            if rebuild_cache or not Path(cache_path).exists():
                shutil.rmtree(cache_path, ignore_errors=True)
                if dataset_path is None:
                    raise ValueError('create cache: failed: dataset_path is None')
                self.__build_cache(dataset_path, cache_path)
                np.save('.zhihu/text_vec.npy', self.text_vec)  # np.save requires the object to save as its second argument
    
            self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True)
            with self.env.begin(write=False) as txn:
                self.length = txn.stat()['entries'] - 1
                self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32)
    
            self.load_dict = np.load('.zhihu/text_vec.npy', allow_pickle=True).item()
    
        def __getitem__(self, index):
            with self.env.begin(write=False) as txn:
                np_array = np.frombuffer(
                    txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(np.int64)  # np.long was removed in recent NumPy
            x = np_array[1:]
            y = np_array[0]
            return x, y, self.load_dict[index]
    
    
        def __len__(self):
            return self.length
    
        def __build_cache(self, path, cache_path):
            """Build the lmdb cache from the raw dataset at ``path``."""
            # feat_mapper: {column: {value: index, value2: index, ...}, ...}
            # defaults:    {column: number of distinct values, ...}
            feat_mapper, defaults = self.__get_feat_mapper(path)
            with lmdb.open(cache_path, map_size=int(1e10)) as env:
                field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32)
                for i, fm in feat_mapper.items():
                    field_dims[i - 1] = len(fm) + 1  # +1 for the out-of-vocabulary index
                with env.begin(write=True) as txn:
                    txn.put(b'field_dims', field_dims.tobytes())
                for buffer in self.__yield_buffer(path, feat_mapper, defaults):
                    # buffer holds (packed row number, encoded feature-index row) pairs
                    with env.begin(write=True) as txn:
                        for key, value in buffer:
                            txn.put(key, value)
    
        def __get_feat_mapper(self, path):
            """Count feature values in the dataset at ``path`` and index the frequent ones."""
            feat_cnts = defaultdict(lambda: defaultdict(int))
            with open(path) as f:
                pbar = tqdm(f, mininterval=1, smoothing=0.1)
                pbar.set_description('Create zhihu dataset cache: counting features')
                for line in pbar:
                    values = line.rstrip('\n').split('\t')
                    # skip malformed rows that do not have exactly self.ALL columns
                    if len(values) != self.ALL:
                        continue
                    # column 0 is the label; columns 1..NUM_INT_FEATS are numeric and bucketized first
                    for i in range(1, self.NUM_INT_FEATS + 1):
                        feat_cnts[i][convert_numeric_feature(values[i])] += 1
                    # the remaining feature columns are categorical and counted as-is
                    for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
                        feat_cnts[i][values[i]] += 1
            # feat_cnts: {column: {value: count, ...}, ...}
            # keep only the values that appear at least min_threshold times
            feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()}
            # re-index the surviving values per column,
            # e.g. {1: {'32': 0, '37': 1, '40': 2, ...}, 2: {'19': 0, '12': 1, ...}, ...}
            feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()}
            # defaults: {column: number of distinct values}, used as the out-of-vocabulary index
            defaults = {i: len(cnt) for i, cnt in feat_mapper.items()}
            return feat_mapper, defaults
    
        def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)):
            """
            :param path: dataset path
            :param feat_mapper: {column: {value: index, value2: index, ...}, ...}
            :param defaults: {column: number of distinct values, ...}
            :param buffer_size: number of rows to accumulate before yielding
            :return: batches of (packed row number, encoded row bytes) pairs
            """
            item_idx = 0
            buffer = list()
            with open(path) as f:
                pbar = tqdm(f, mininterval=1, smoothing=0.1)
                pbar.set_description('Create zhihu dataset cache: setup lmdb')
                for line in pbar:
                    values = line.rstrip('\n').split('\t')
                    # skip malformed rows
                    if len(values) != self.ALL:
                        continue
                    # NUM_FEATS + 1 zeros: slot 0 for the label, slots 1..NUM_FEATS for feature indices
                    np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32)
                    np_array[0] = int(values[0])
                    for i in range(1, self.NUM_INT_FEATS + 1):
                        np_array[i] = feat_mapper[i].get(convert_numeric_feature(values[i]), defaults[i])
                    for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
                        # look up the categorical value's index, falling back to the OOV index
                        np_array[i] = feat_mapper[i].get(values[i], defaults[i])
                    # vectorize this row's text column now, keyed by row number
                    self.text_vec[item_idx] = pad_sequences(values[self.NUM_FEATS + 1], self.word_vec)
                    buffer.append((struct.pack('>I', item_idx), np_array.tobytes()))  # (row number, encoded row)
                    item_idx += 1
                    if item_idx % buffer_size == 0:
                        # buffer full: hand it to the caller and start a new one
                        yield buffer
                        buffer.clear()
                yield buffer
    
    
    
    @lru_cache(maxsize=None)
    def convert_numeric_feature(val: str):
        """对每个数值特征进行先loge,再平方返回"""
        if val == '':
            return 'NULL'
        v = float(val)
        if v > 2:
            return str(int(math.log(v) ** 2))
        else:
            return str(v - 2)
    
    def get_word_vector():
        topic_word_vector = {}
        with open(r"D:\SCI资料\textDeepFM_attention\examples\word_vector\word_vectors_64d.txt", "r") as f:
            fList = f.readlines()
            for fLine in fList:
                rowList = fLine.split("\t")
                k = rowList[0]
                v = rowList[1].replace("\\n", "").replace("\n", "").split(" ")
                topic_word_vector[k] = v
        return topic_word_vector
    
    def pad_sequences(text, topic_word_vector):
        """Map a comma-separated topic-id string to a fixed (maxlen, 64) float32 matrix."""
        maxlen = 40
        sentence = []
        for t in text.split(","):
            t = t.replace(" ", "").replace("-1", "")
            if not t:
                # "-1" marks a row without text: return an all-zero matrix
                return np.zeros((maxlen, 64), dtype=np.float32)
            v = topic_word_vector[t]
            sentence.append(v)
        if maxlen > len(sentence):
            # pad short sequences with zero rows; keep one consistent dtype throughout
            _add = np.zeros((maxlen - len(sentence), 64), dtype=np.float32)
            sentence_vec = np.vstack((np.array(sentence, dtype=np.float32), _add))
        else:
            sentence_vec = np.array(sentence[:maxlen], dtype=np.float32)
        return sentence_vec
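
    A minimal sketch (with a made-up two-topic vocabulary; real vectors come from word_vectors_64d.txt) of what pad_sequences produces:

    toy_vocab = {"T1": ["0.1"] * 64, "T2": ["0.2"] * 64}   # hypothetical vocabulary
    vec = pad_sequences("T1,T2", toy_vocab)
    print(vec.shape, vec.dtype)                  # (40, 64) float32: 2 real rows + 38 zero rows
    print(pad_sequences("-1", toy_vocab).sum())  # 0.0, the "-1" marker yields an all-zero matrix

    The third revision replaces the in-memory dict and the .npy side file: the raw text strings go into a second LMDB (.zhihu_text) at cache-build time and are vectorized on the fly in __getitem__.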
    import math
    import shutil
    import struct
    from collections import defaultdict
    from functools import lru_cache
    from pathlib import Path
    
    import lmdb
    import numpy as np
    import torch.utils.data
    from tqdm import tqdm
    import pandas as pd
    
    
    class ZhihuDataset(torch.utils.data.Dataset):
        """
        Criteo Display Advertising Challenge Dataset
    
        Data prepration:
            * Remove the infrequent features (appearing in less than threshold instances) and treat them as a single feature
            * Discretize numerical values by log2 transformation which is proposed by the winner of Criteo Competition
    
        :param dataset_path: criteo train.txt path.
        :param cache_path: lmdb cache path.
        :param rebuild_cache: If True, lmdb cache is refreshed.
        :param min_threshold: infrequent feature threshold.
    
        Reference:
            https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset
            https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf
        """
    
        def __init__(self, text_columns, dataset_path=None, cache_path='.zhihu', rebuild_cache=False, min_threshold=10):
            self.word_vec = get_word_vector()
            # 25 columns per row: label + 23 feature columns + 1 text column
            self.ALL = 25
            # 23 = total number of numeric + categorical feature columns
            self.NUM_FEATS = 23
            # 7 = number of numeric feature columns
            self.NUM_INT_FEATS = 7
            self.min_threshold = min_threshold

            if rebuild_cache or not Path(cache_path).exists():
                shutil.rmtree(cache_path, ignore_errors=True)
                shutil.rmtree(".zhihu_text", ignore_errors=True)  # the text cache is rebuilt together with the feature cache
                if dataset_path is None:
                    raise ValueError('create cache: failed: dataset_path is None')
                self.__build_cache(dataset_path, cache_path)

            self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True)
            self.text_env = lmdb.open(".zhihu_text", create=False, lock=False, readonly=True)
            with self.env.begin(write=False) as txn:
                self.length = txn.stat()['entries'] - 1
                self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32)
            print("number of rows:", self.length)
    
        def __getitem__(self, index):
            with self.env.begin(write=False) as txn:
                np_array = np.frombuffer(
                    txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(np.int64)  # np.long was removed in recent NumPy

            # look up the raw text by the same packed row key, then vectorize it on the fly
            with self.text_env.begin(write=False) as text_txn:
                text = text_txn.get(struct.pack('>I', index)).decode()
            text_vec = pad_sequences(text, self.word_vec)

            x = np_array[1:]
            y = np_array[0]
            return x, y, text_vec
    
    
        def __len__(self):
            return self.length
    
        def __build_cache(self, path, cache_path):
            """Build the feature lmdb at ``cache_path`` and the raw-text lmdb at .zhihu_text."""
            # feat_mapper: {column: {value: index, value2: index, ...}, ...}
            # defaults:    {column: number of distinct values, ...}
            feat_mapper, defaults = self.__get_feat_mapper(path)
            with lmdb.open(".zhihu_text", map_size=int(1e10)) as text_env:
                with lmdb.open(cache_path, map_size=int(1e10)) as env:
                    field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32)
                    for i, fm in feat_mapper.items():
                        field_dims[i - 1] = len(fm) + 1  # +1 for the out-of-vocabulary index
                    with env.begin(write=True) as txn:
                        txn.put(b'field_dims', field_dims.tobytes())
                    for buffer in self.__yield_buffer(path, feat_mapper, defaults):
                        with env.begin(write=True) as txn:
                            with text_env.begin(write=True) as text_txn:
                                for key, value, text in buffer:  # key: packed row number; value: encoded row; text: raw text bytes
                                    txn.put(key, value)
                                    text_txn.put(key, text)
    
        def __get_feat_mapper(self, path):
            """Count feature values in the dataset at ``path`` and index the frequent ones."""
            feat_cnts = defaultdict(lambda: defaultdict(int))
            with open(path) as f:
                pbar = tqdm(f, mininterval=1, smoothing=0.1)
                pbar.set_description('Create zhihu dataset cache: counting features')
                for line in pbar:
                    values = line.rstrip('\n').split('\t')
                    # skip malformed rows that do not have exactly self.ALL columns
                    if len(values) != self.ALL:
                        continue
                    # column 0 is the label; columns 1..NUM_INT_FEATS are numeric and bucketized first
                    for i in range(1, self.NUM_INT_FEATS + 1):
                        feat_cnts[i][convert_numeric_feature(values[i])] += 1
                    # the remaining feature columns are categorical and counted as-is
                    for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
                        feat_cnts[i][values[i]] += 1
            # feat_cnts: {column: {value: count, ...}, ...}
            # keep only the values that appear at least min_threshold times
            feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()}
            # re-index the surviving values per column,
            # e.g. {1: {'32': 0, '37': 1, '40': 2, ...}, 2: {'19': 0, '12': 1, ...}, ...}
            feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()}
            # defaults: {column: number of distinct values}, used as the out-of-vocabulary index
            defaults = {i: len(cnt) for i, cnt in feat_mapper.items()}
            return feat_mapper, defaults
    
        def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)):
            """
            :param path: dataset path
            :param feat_mapper: {column: {value: index, value2: index, ...}, ...}
            :param defaults: {column: number of distinct values, ...}
            :param buffer_size: number of rows to accumulate before yielding
            :return: batches of (packed row number, encoded row bytes, raw text bytes) triples
            """
            item_idx = 0
            buffer = list()
            with open(path) as f:
                pbar = tqdm(f, mininterval=1, smoothing=0.1)
                pbar.set_description('Create zhihu dataset cache: setup lmdb')
                for line in pbar:
                    values = line.rstrip('\n').split('\t')
                    # skip malformed rows
                    if len(values) != self.ALL:
                        continue
                    # NUM_FEATS + 1 zeros: slot 0 for the label, slots 1..NUM_FEATS for feature indices
                    np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32)
                    np_array[0] = int(values[0])

                    for i in range(1, self.NUM_INT_FEATS + 1):
                        np_array[i] = feat_mapper[i].get(convert_numeric_feature(values[i]), defaults[i])
                    for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
                        # look up the categorical value's index, falling back to the OOV index
                        np_array[i] = feat_mapper[i].get(values[i], defaults[i])
                    # store the raw text alongside the encoded row; it is vectorized lazily in __getitem__
                    buffer.append((struct.pack('>I', item_idx), np_array.tobytes(), values[self.NUM_FEATS + 1].encode()))
                    item_idx += 1
                    if item_idx % buffer_size == 0:
                        # buffer full: hand it to the caller and start a new one
                        yield buffer
                        buffer.clear()
                yield buffer
    
    
    
    @lru_cache(maxsize=None)
    def convert_numeric_feature(val: str):
        """对每个数值特征进行先loge,再平方返回"""
        if val == '':
            return 'NULL'
        v = float(val)
        if v > 2:
            return str(int(math.log(v) ** 2))
        else:
            return str(v - 2)
    
    def get_word_vector():
        topic_word_vector = {}
        with open(r"D:\SCI资料\textDeepFM_attention\examples\word_vector\word_vectors_64d.txt", "r") as f:
            fList = f.readlines()
            for fLine in fList:
                rowList = fLine.split("\t")
                k = rowList[0]
                v = rowList[1].replace("\\n", "").replace("\n", "").split(" ")
                topic_word_vector[k] = v
        return topic_word_vector
    
    def pad_sequences(text, topic_word_vector):
        """Map a comma-separated topic-id string to a fixed (maxlen, 64) float32 matrix."""
        maxlen = 40
        sentence = []
        for t in text.split(","):
            t = t.replace(" ", "").replace("-1", "")
            if not t:
                # "-1" marks a row without text: return an all-zero matrix
                return np.zeros((maxlen, 64), dtype=np.float32)
            v = topic_word_vector[t]
            sentence.append(v)
        if maxlen > len(sentence):
            # pad short sequences with zero rows; keep one consistent dtype throughout
            _add = np.zeros((maxlen - len(sentence), 64), dtype=np.float32)
            sentence_vec = np.vstack((np.array(sentence, dtype=np.float32), _add))
        else:
            sentence_vec = np.array(sentence[:maxlen], dtype=np.float32)
        return sentence_vec
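
    A hedged sketch (file name and batch size are placeholders) of how this version plugs into a DataLoader; each batch carries the encoded fields, the labels, and the (40, 64) text matrices:

    dataset = ZhihuDataset(text_columns=None, dataset_path='zhihu_train.txt')  # assumed path
    loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
    for x, y, text_vec in loader:
        print(x.shape, y.shape, text_vec.shape)  # (256, 23), (256,), (256, 40, 64)
        break

    The fourth revision moves the vectorization into the cache build itself: the second LMDB stores the already padded float32 matrices as raw bytes, so __getitem__ only has to call np.frombuffer and reshape.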
    import math
    import shutil
    import struct
    from collections import defaultdict
    from functools import lru_cache
    from pathlib import Path
    
    import lmdb
    import numpy as np
    import torch.utils.data
    from tqdm import tqdm
    import pandas as pd
    
    
    class ZhihuDataset(torch.utils.data.Dataset):
        """
        Criteo Display Advertising Challenge Dataset
    
        Data prepration:
            * Remove the infrequent features (appearing in less than threshold instances) and treat them as a single feature
            * Discretize numerical values by log2 transformation which is proposed by the winner of Criteo Competition
    
        :param dataset_path: criteo train.txt path.
        :param cache_path: lmdb cache path.
        :param rebuild_cache: If True, lmdb cache is refreshed.
        :param min_threshold: infrequent feature threshold.
    
        Reference:
            https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset
            https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf
        """
    
        def __init__(self, text_columns, dataset_path=None, cache_path='.zhihu', rebuild_cache=False, min_threshold=10):
            self.word_vec = get_word_vector()
            # 25 columns per row: label + 23 feature columns + 1 text column
            self.ALL = 25
            # 23 = total number of numeric + categorical feature columns
            self.NUM_FEATS = 23
            # 7 = number of numeric feature columns
            self.NUM_INT_FEATS = 7
            self.min_threshold = min_threshold

            if rebuild_cache or not Path(cache_path).exists():
                shutil.rmtree(cache_path, ignore_errors=True)
                shutil.rmtree(".zhihu_text", ignore_errors=True)  # the text cache is rebuilt together with the feature cache
                if dataset_path is None:
                    raise ValueError('create cache: failed: dataset_path is None')
                self.__build_cache(dataset_path, cache_path)

            self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True)
            self.text_env = lmdb.open(".zhihu_text", create=False, lock=False, readonly=True)
            with self.env.begin(write=False) as txn:
                self.length = txn.stat()['entries'] - 1
                self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32)
            print("number of rows:", self.length)
    
        def __getitem__(self, index):
            with self.env.begin(write=False) as txn:
                np_array = np.frombuffer(
                    txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(np.int64)  # np.long was removed in recent NumPy

            # the text lmdb stores the padded matrix as raw float32 bytes (matching pad_sequences), keyed by the row number
            with self.text_env.begin(write=False) as text_txn:
                text = np.frombuffer(text_txn.get(str(index).encode()), dtype=np.float32)
            text_vec = text.reshape(40, 64)

            x = np_array[1:]
            y = np_array[0]
            return x, y, text_vec
    
    
        def __len__(self):
            return self.length
    
        def __build_cache(self, path, cache_path):
            """Build the feature lmdb at ``cache_path`` and the padded-text-vector lmdb at .zhihu_text."""
            # feat_mapper: {column: {value: index, value2: index, ...}, ...}
            # defaults:    {column: number of distinct values, ...}
            feat_mapper, defaults = self.__get_feat_mapper(path)
            with lmdb.open(".zhihu_text", map_size=int(1e11)) as text_env:
                with lmdb.open(cache_path, map_size=int(1e10)) as env:
                    field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32)
                    for i, fm in feat_mapper.items():
                        field_dims[i - 1] = len(fm) + 1  # +1 for the out-of-vocabulary index
                    with env.begin(write=True) as txn:
                        txn.put(b'field_dims', field_dims.tobytes())
                    for buffer in self.__yield_buffer(path, feat_mapper, defaults):
                        with env.begin(write=True) as txn:
                            with text_env.begin(write=True) as text_txn:
                                for key, value, _id, text in buffer:  # key: packed row number; value: encoded row
                                    txn.put(key, value)
                                    # vectorize once at build time and store the fixed-size matrix as raw bytes
                                    text_vec = pad_sequences(text, self.word_vec)
                                    text_txn.put(str(_id).encode(), text_vec.tobytes())  # lmdb values must be bytes
    
        def __get_feat_mapper(self, path):
            """Count feature values in the dataset at ``path`` and index the frequent ones."""
            feat_cnts = defaultdict(lambda: defaultdict(int))
            with open(path) as f:
                pbar = tqdm(f, mininterval=1, smoothing=0.1)
                pbar.set_description('Create zhihu dataset cache: counting features')
                for line in pbar:
                    values = line.rstrip('\n').split('\t')
                    # skip malformed rows that do not have exactly self.ALL columns
                    if len(values) != self.ALL:
                        continue
                    # column 0 is the label; columns 1..NUM_INT_FEATS are numeric and bucketized first
                    for i in range(1, self.NUM_INT_FEATS + 1):
                        feat_cnts[i][convert_numeric_feature(values[i])] += 1
                    # the remaining feature columns are categorical and counted as-is
                    for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
                        feat_cnts[i][values[i]] += 1
            # feat_cnts: {column: {value: count, ...}, ...}
            # keep only the values that appear at least min_threshold times
            feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()}
            # re-index the surviving values per column,
            # e.g. {1: {'32': 0, '37': 1, '40': 2, ...}, 2: {'19': 0, '12': 1, ...}, ...}
            feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()}
            # defaults: {column: number of distinct values}, used as the out-of-vocabulary index
            defaults = {i: len(cnt) for i, cnt in feat_mapper.items()}
            return feat_mapper, defaults
    
        def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)):
            """
            :param path: dataset path
            :param feat_mapper: {column: {value: index, value2: index, ...}, ...}
            :param defaults: {column: number of distinct values, ...}
            :param buffer_size: number of rows to accumulate before yielding
            :return: batches of (packed row number, encoded row bytes, row number, raw text) tuples
            """
            item_idx = 0
            buffer = list()
            with open(path) as f:
                pbar = tqdm(f, mininterval=1, smoothing=0.1)
                pbar.set_description('Create zhihu dataset cache: setup lmdb')
                for line in pbar:
                    values = line.rstrip('\n').split('\t')
                    # skip malformed rows
                    if len(values) != self.ALL:
                        continue
                    # NUM_FEATS + 1 zeros: slot 0 for the label, slots 1..NUM_FEATS for feature indices
                    np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32)
                    np_array[0] = int(values[0])

                    for i in range(1, self.NUM_INT_FEATS + 1):
                        np_array[i] = feat_mapper[i].get(convert_numeric_feature(values[i]), defaults[i])
                    for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1):
                        # look up the categorical value's index, falling back to the OOV index
                        np_array[i] = feat_mapper[i].get(values[i], defaults[i])
                    # pass the row number and raw text along; __build_cache vectorizes and stores them
                    buffer.append((struct.pack('>I', item_idx), np_array.tobytes(), item_idx, values[self.NUM_FEATS + 1]))
                    item_idx += 1
                    if item_idx % buffer_size == 0:
                        # buffer full: hand it to the caller and start a new one
                        yield buffer
                        buffer.clear()
                yield buffer
    
    
    
    @lru_cache(maxsize=None)
    def convert_numeric_feature(val: str):
        """对每个数值特征进行先loge,再平方返回"""
        if val == '':
            return 'NULL'
        v = float(val)
        if v > 2:
            return str(int(math.log(v) ** 2))
        else:
            return str(v - 2)
    
    def get_word_vector():
        topic_word_vector = {}
        with open(r"D:\SCI资料\textDeepFM_attention\examples\word_vector\word_vectors_64d.txt", "r") as f:
            fList = f.readlines()
            for fLine in fList:
                rowList = fLine.split("\t")
                k = rowList[0]
                v = rowList[1].replace("\\n", "").replace("\n", "").split(" ")
                topic_word_vector[k] = v
        return topic_word_vector
    
    def pad_sequences(text, topic_word_vector):
        """Map a comma-separated topic-id string to a fixed (maxlen, 64) float32 matrix."""
        maxlen = 40
        sentence = []
        for t in text.split(","):
            t = t.replace(" ", "").replace("-1", "")
            if not t:
                # "-1" marks a row without text: return an all-zero matrix
                return np.zeros((maxlen, 64), dtype=np.float32)
            v = topic_word_vector[t]
            sentence.append(v)
        if maxlen > len(sentence):
            # pad short sequences with zero rows; keep one consistent dtype so the lmdb bytes decode predictably
            _add = np.zeros((maxlen - len(sentence), 64), dtype=np.float32)
            sentence_vec = np.vstack((np.array(sentence, dtype=np.float32), _add))
        else:
            sentence_vec = np.array(sentence[:maxlen], dtype=np.float32)
        return sentence_vec
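
    As a final sanity check (a sketch, not from the original post), the tobytes / frombuffer round trip used above preserves the padded matrix exactly:

    toy_vocab = {"T1": ["0.5"] * 64}           # hypothetical one-topic vocabulary
    original = pad_sequences("T1", toy_vocab)  # (40, 64) float32
    restored = np.frombuffer(original.tobytes(), dtype=np.float32).reshape(40, 64)
    assert np.array_equal(original, restored)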