zoukankan      html  css  js  c++  java
  • Pytorch list tensor 转 onehot

    def test_onehot():
        v = torch.tensor([[0.1, 0.2, 0.7],
                          [0.1, 0.6, 0.3],
                          [0.1, 0.5, 0.4],
                          [0.8, 0.1, 0.1], ])
    
        print('v', v.size(), v)
        # 按照形状创建全0张量
        result = torch.zeros_like(v, dtype=torch.long)
        # 目标维度
        dim = -1
        # 根据索引将值改为1
        result.scatter_(dim,
                        v.argmax(dim).unsqueeze(dim),
                        torch.ones(4, dtype=torch.long).unsqueeze(dim))
    
        print('result', result.size(), result)
    
    

    输出

    v torch.Size([4, 3]) tensor([[0.1000, 0.2000, 0.7000],
            [0.1000, 0.6000, 0.3000],
            [0.1000, 0.5000, 0.4000],
            [0.8000, 0.1000, 0.1000]])
    result torch.Size([4, 3]) tensor([[0, 0, 1],
            [0, 1, 0],
            [0, 1, 0],
            [1, 0, 0]])
    

    自用两个方法

    def list_onehot(actions: list, n: int) -> torch.Tensor:
        """
        列表动作值转 onehot
        actions: 动作列表
        n: 动作总个数
        """
        result = []
        for action in actions:
            result.append([int(k == action) for k in range(n)])
        result = torch.tensor(result, dtype=torch.long)
        if torch.cuda.is_available():
            result = result.cuda()
        return result
    
    def max_onehot(props: torch.Tensor, dim=-1) -> torch.Tensor:
        """
        动作概率 tensor 转 onehot
        props: 动作概率表
        dim: 目标维度
        """
        result = torch.zeros_like(props, dtype=torch.long)
        src = torch.ones(self.batchSize, dtype=torch.long).unsqueeze(dim)
        if torch.cuda.is_available():
            result = result.cuda()
            src = src.cuda()
        result.scatter_(dim, props.argmax(dim).unsqueeze(dim), src)
        return result
    
  • 相关阅读:
    CentOS 6.3 下编译Nginx(笔记整理)
    XMPP协议相关基础概念(Strophe学习笔记)
    StackMapTable format error
    hibernate的子查询
    strophe与openfire模拟的XMPP简单hello程序
    Strophe.Status的所有值
    博客园添加SyntaxHighlighter
    P5395 【模板】第二类斯特林数·行
    test20191001
    test20190829
  • 原文地址:https://www.cnblogs.com/congxinglong/p/15587476.html
Copyright © 2011-2022 走看看