#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>
using namespace std; // fine for a demo; avoid in real projects

void printTitle(const string& title)
{
    cout << endl;
    cout << "******【" << title << "】******" << endl;
}

int main()
{
    // e.g. 2.4
    {
        printTitle("e.g. 2.4");
        // Construct a tensor from nested initializer lists
        torch::Tensor t1 = torch::tensor({ { 1, 2, 3 }, { 2, 3, 4 } }, torch::kByte);
        cout << "t1.dtype() = " << t1.dtype() << endl; // unsigned char (kByte), as requested above
        t1.print();
        cout << "t1 = " << t1 << endl;
        t1 = torch::arange(1, 11, torch::kByte); // 1..10; arange is preferred over the deprecated torch::range

        // random tensor
        t1 = torch::randn({ 3, 3 }, torch::kFloat) * 10;
        cout << "t1 = " << t1 << endl;
        torch::Tensor t2 = t1.to(torch::kInt8);
        cout << "t2 = " << t2 << endl;
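        // Note: .to() never modifies its source tensor; it returns a new
        // tensor, here truncating the fractional parts toward zero.
        cout << "t1.dtype() after to() = " << t1.dtype() << endl; // still float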
    }
    // e.g. 2.5 random tensors
    {
        printTitle("e.g. 2.5 random");
        torch::Tensor t1 = torch::rand({ 3, 3 }, torch::kFloat32); // uniform [0, 1)
        t1 = torch::randn({ 2, 3, 4 });                            // standard normal
        t1 = torch::zeros({ 2, 2, 2 }, torch::kUInt8);
        t1 = torch::ones({ 3, 4 }) * 9;
        t1 = torch::eye(3, torch::kFloat);                         // identity matrix
        t1 = torch::randint(0, 4, { 3, 3 });                       // integers in [0, 4)
        cout << "t1 = " << t1 << endl;

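        // A possible alternative to ones({3, 4}) * 9 (a sketch, assuming the
        // standard factory API): torch::full fills a fresh tensor directly.
        torch::Tensor t_full = torch::full({ 3, 4 }, 9.0f);
        cout << "t_full = " << t_full << endl;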
    }
    // e.g. 2.6 *_like and new_* factories
    {
        printTitle("e.g. 2.6 *_like");
        torch::Tensor t1 = torch::rand({ 3, 3 }, torch::kFloat32);
        // copy the shape of t1
        torch::Tensor t2 = torch::zeros_like(t1);
        t2 = torch::ones_like(t1);
        t2 = torch::randn_like(t1);

        // copy the dtype of t1
        torch::Tensor t3 = t1.new_zeros({ 3, 3 });  // same dtype/device as t1, filled with zeros
        t3 = torch::ones(t1.sizes(), t1.dtype());   // explicit shape + dtype, much like cv::Mat factories in OpenCV
        t3 = torch::zeros(t1.sizes(), t1.dtype());

        cout << "t2 = " << t2 << endl;
        cout << "t3 = " << t3 << endl;
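        // t1.options() bundles dtype, device, and layout in one object,
        // so a single call can copy all of them at once (a minimal sketch
        // using the same factory functions as above):
        torch::Tensor t4 = torch::ones(t1.sizes(), t1.options());
        cout << "t4 = " << t4 << endl;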
    }
    // e.g. 2.8 devices
    {
        printTitle("e.g. 2.8 devices");
        torch::Tensor t1 = torch::randn({ 3, 3 }, torch::Device("cpu"));
        cout << "t1 = " << t1 << endl;
        if (torch::cuda::is_available()) // guard: .to("cuda:0") throws on CPU-only builds
        {
            auto device = torch::Device("cuda:0");
            torch::Tensor t2 = torch::randn({ 3, 3 }, torch::kF32).to(device);
            cout << "t2 = " << t2 << endl;
            cout << "t2.device = " << t2.device() << endl;
        }
        else
        {
            cout << "CUDA not available, skipping the GPU part" << endl;
        }
    }
    // e.g. 2.9 pointers
    {
        printTitle("e.g. 2.9 pointers");
        torch::Tensor t1 = torch::randn({ 3, 4, 5 });
        cout << t1 << endl;
        int nd = t1.ndimension(); // number of dimensions, 3 here
        int nc = t1.size(0);      // size of dim 0
        int nw = t1.size(1);      // size of dim 1
        int nh = t1.size(2);      // size of dim 2
        cout << nd << " " << nc << endl;
        auto sz = t1.sizes();     // IntArrayRef [3, 4, 5]
        cout << "sz = " << sz << endl;
        t1 = torch::randn({ 12 });
        torch::Tensor t2 = t1.view({ -1, 3 }); // second dim becomes 3, first is inferred; throws if the length is not divisible by 3
        t2[0][0] = 99; // element access
        cout << "t2 = " << t2 << endl;
        float* t2_ptr = (float*)t2.data_ptr(); // raw pointer to the underlying buffer
        cout << "t2_ptr = " << t2_ptr << endl;
        void* t22_ptr = (void*)t2.data_ptr(); // same address, just cast to void*
        cout << "t22_ptr = " << t22_ptr << endl;
        auto t222_ptr = t2.contiguous().data_ptr(); // t2 is already contiguous, so the address is unchanged
        cout << "t222_ptr = " << t222_ptr << endl;
        auto t2222_ptr = t2.transpose(0, 1).contiguous().data_ptr(); // strides no longer row-major, so contiguous() allocates a new buffer: different address
        cout << "t2222_ptr = " << t2222_ptr << endl;
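        // The raw pointer walks the contiguous buffer in row-major order,
        // so the element written through t2[0][0] above is visible here:
        cout << "t2_ptr[0] = " << t2_ptr[0] << endl; // 99
        // A typed accessor is also available (assuming a reasonably recent
        // LibTorch release):
        float* typed_ptr = t2.data_ptr<float>();
        cout << "typed_ptr[0] = " << typed_ptr[0] << endl; // 99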
    }
    // e.g. 2.10 mask
    {
        printTitle("e.g. 2.10 mask");
        torch::Tensor t1 = torch::randn({ 2, 3, 4 });
        cout << "t1 = " << t1 << endl;
        torch::Tensor ele = t1[1][2][3];
        cout << "ele = " << ele << endl;
        double ele_ = ele.item().toDouble(); // tensor -> double
        cout << "ele_ = " << ele_ << endl;
        torch::Tensor mask = t1.ge(0); // elementwise "greater or equal": true where t1 >= 0
        cout << "mask = " << mask << endl;
        torch::Tensor t2 = t1.masked_select(mask); // t2 is a flat 1-D tensor of the selected elements
        cout << "t2 = " << t2 << endl;
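        // Sanity check (a sketch): summing the boolean mask counts the
        // selected elements, which must match masked_select's output size.
        cout << "mask.sum() = " << mask.sum().item().toLong()
             << ", t2.numel() = " << t2.numel() << endl;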
    }
    // e.g. 2.11 sqrt && sum
    {
        printTitle("e.g. 2.11 sqrt");
        torch::Tensor t1 = torch::randint(1, 9, { 3, 3 }); // no dtype given, so the result dtype follows the default TensorOptions of your build
        cout << "t1 = " << t1 << endl;
        torch::Tensor t2 = t1.to(torch::kFloat32);
        torch::Tensor t3 = t2.sqrt(); // does not modify t2
        t3 = torch::sqrt(t2);         // same, free-function form
        cout << "t3 = " << t3 << endl;
        t2.sqrt_(); // in-place square root, modifies t2 itself
        cout << "t2 = " << t2 << endl;

        // sum() is also available as a member function
        cout << "t1 = " << t1 << endl;
        torch::Tensor sum1 = torch::sum(t1);    // sum over all elements
        torch::Tensor sum2 = torch::sum(t1, 0); // sum along dim 0, i.e. column sums
        torch::Tensor sum3 = torch::sum(t1, { 1, 0 }); // sum over dims 0 and 1 (written as { 1, 0 } because { 0, 1 } failed to compile in the author's setup)
        cout << "sum3 = " << sum3.item().toFloat() << endl;

        torch::Tensor mean1 = t1.mean();  // mean over all elements; torch::mean works too
        torch::Tensor mean2 = t1.mean(0); // mean along dim 0
        // { 0, 1 } failed to compile here as well, see above
        torch::Tensor mean3 = torch::mean(t1, { 1, 0 }); // mean over dims 0 and 1; a 0-dim scalar tensor
        cout << "mean1 = " << mean1.item().toFloat() << endl;
        cout << "mean2 = " << mean2 << endl;
        cout << "mean3 = " << mean3 << endl;
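        // keepdim variant (a sketch using the optional second argument):
        // the reduced dimension is kept with size 1, which is handy for
        // broadcasting the result back against t1.
        torch::Tensor colSums = t1.sum(0, /*keepdim=*/true); // shape [1, 3]
        cout << "colSums.sizes() = " << colSums.sizes() << endl;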
    }
    // e.g. 2.12 elementwise add, subtract, multiply, divide (the operators are also overloaded; see the examples below)
    {
        printTitle("e.g. 2.12");
        torch::Tensor t1 = torch::rand({ 2, 3 });
        torch::Tensor t2 = torch::rand({ 2, 3 });
        torch::Tensor t3 = t1 + t2;
        torch::Tensor t4 = t1.sub(t2);
        torch::Tensor t5 = t1.mul(t2);
        torch::Tensor t6 = t1.div(2);
        cout << "t1 = " << t1 << endl;
        cout << "t2 = " << t2 << endl;
        cout << "t3 = " << t3 << endl;
        cout << "t4 = " << t4 << endl;
        cout << "t5 = " << t5 << endl;
        cout << "t6 = " << t6 << endl;
        t6.add_(1); // in-place, modifies t6
        cout << "t6 = " << t6 << endl;
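        // The overloaded operators mentioned above produce the same results
        // as the named member functions:
        cout << "t1 - t2 = " << (t1 - t2) << endl; // same as t1.sub(t2)
        cout << "t1 * t2 = " << (t1 * t2) << endl; // same as t1.mul(t2)
        cout << "t1 / 2  = " << (t1 / 2) << endl;  // same as t1.div(2)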
    }
    // e.g. 2.13 min max argmax
    {
        printTitle("e.g. 2.13 min max argmax");
        torch::Tensor t1 = torch::randn({ 3, 4 }, torch::kFloat64);
        cout << "t1 = " << t1 << endl;
        torch::Tensor mask_argmax = torch::argmax(t1, 0); // positions of the maxima along dim 0
        cout << "mask_argmax = " << mask_argmax << endl;
        // max
        std::tuple<torch::Tensor, torch::Tensor> maxVals = torch::max(t1, -1); // free function; returns (values, indices) along the last dimension
        torch::Tensor mask_max = std::get<0>(maxVals);     // max values
        torch::Tensor mask_max_idx = std::get<1>(maxVals); // indices of the max values
        cout << "mask_max = " << mask_max << endl;
        cout << "mask_max_idx = " << mask_max_idx << endl;
        // min
        std::tuple<torch::Tensor, torch::Tensor> minVals = t1.min(0); // member function; returns (values, indices) along dim 0
        torch::Tensor mask_min = std::get<0>(minVals);     // min values
        torch::Tensor mask_min_idx = std::get<1>(minVals); // indices of the min values
        cout << "mask_min = " << mask_min << endl;
        cout << "mask_min_idx = " << mask_min_idx << endl;
        // sort
        std::tuple<torch::Tensor, torch::Tensor> sortVals = t1.sort(-1); // sort along the last dimension; returns the sorted tensor and the original positions of its elements
        torch::Tensor tensorVal = std::get<0>(sortVals);
        torch::Tensor tensorValIdx = std::get<1>(sortVals);
        cout << "tensorVal = " << tensorVal << endl;
        cout << "tensorValIdx = " << tensorValIdx << endl;
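        // With C++17, structured bindings unpack these tuples more cleanly
        // (a sketch; assumes compiling with -std=c++17 or later):
        auto [sortedVal, sortedIdx] = t1.sort(-1);
        cout << "sortedVal = " << sortedVal << endl;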
    }
    // e.g. 2.14 matrix multiplication
    {
        printTitle("e.g. 2.14 matrix multiplication");
        torch::Tensor t1 = torch::tensor({ { 1, 2 }, { 3, 4 } }, torch::kFloat64);       // 2x2
        torch::Tensor t2 = torch::tensor({ { 1, 1, 1 }, { 2, 3, 1 } }, torch::kFloat64); // 2x3
        auto t3 = t1.mm(t2); // matrix multiplication; torch::mm is the free-function form
        cout << "t1 = " << t1 << endl;
        cout << "t2 = " << t2 << endl;
        cout << "t3 = " << t3 << endl;
        //
        t1 = torch::randn({ 2, 3, 4 });
        t2 = torch::randn({ 2, 4, 3 });
        torch::Tensor t4 = t1.bmm(t2); // batched matrix multiplication, member-function form; result is 2x3x3
        cout << "t1 = " << t1 << endl;
        cout << "t2 = " << t2 << endl;
        cout << "t4 = " << t4 << endl;
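        // torch::matmul dispatches to mm/bmm (with broadcasting) based on
        // the inputs' dimensionality, so it covers both cases above:
        torch::Tensor t5 = torch::matmul(t1, t2); // same 2x3x3 result as bmm
        cout << "t5.sizes() = " << t5.sizes() << endl;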
    }
    // e.g. 2.16 stacking and concatenating tensors
    {
        printTitle("e.g. 2.16 Tensor stack and cat");
        auto t1 = torch::randn({ 2, 3 });
        auto t2 = torch::randn({ 2, 3 });
        auto t3 = torch::stack({ t1, t2 }, -1); // stack inserts a new last dimension, giving a 2x3x2 tensor
        cout << "t1.sizes() = " << t1.sizes() << endl;
        cout << "t2.sizes() = " << t2.sizes() << endl;
        cout << "t3.sizes() = " << t3.sizes() << endl;
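        // cat covers the concatenation half of this example: unlike stack,
        // torch::cat joins along an existing dimension without inserting a
        // new one.
        auto t4 = torch::cat({ t1, t2 }, 0); // 4x3
        auto t5 = torch::cat({ t1, t2 }, 1); // 2x6
        cout << "t4.sizes() = " << t4.sizes() << endl;
        cout << "t5.sizes() = " << t5.sizes() << endl;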
    }
    // e.g. 2.17 2.18 expanding and squeezing dimensions
    {
        printTitle("e.g. 2.17 2.18 expand/squeeze dims");
        torch::Tensor t1 = torch::rand({ 3, 4 });
        cout << "t1.sizes() = " << t1.sizes() << endl;
        auto t11 = t1.unsqueeze(-1); // append a new last dimension -> 3x4x1
        cout << "t11.sizes() = " << t11.sizes() << endl;
        auto t12 = t1.unsqueeze(-1).unsqueeze(-1); // append twice -> 3x4x1x1
        cout << "t12.sizes() = " << t12.sizes() << endl;
        auto t13 = t1.unsqueeze(1); // insert a new dimension at index 1 -> 3x1x4
        cout << "t13.sizes() = " << t13.sizes() << endl;

        auto t2 = torch::rand({ 1, 3, 4, 1 });
        cout << "t2.sizes() = " << t2.sizes() << endl;
        auto t21 = t2.squeeze(); // drop every dimension of size 1 -> 3x4
        cout << "t21.sizes() = " << t21.sizes() << endl;
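        // squeeze(dim) removes only the given dimension, and only if its
        // size is 1; otherwise the tensor's shape is left unchanged:
        auto t22 = t2.squeeze(0); // 3x4x1
        auto t23 = t2.squeeze(1); // dim 1 has size 3, not 1: stays 1x3x4x1
        cout << "t22.sizes() = " << t22.sizes() << endl;
        cout << "t23.sizes() = " << t23.sizes() << endl;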
    }
    return 0;
}