zoukankan      html  css  js  c++  java
  • libtorch 中哪些函数比较常用？

    如何打印模型?

    // Print a module's name, then one line per named parameter ("<key>\t<sizes>"),
    // indented by `level + 1` tabs, followed by a closing ")" indented by `level` tabs.
    //
    // Usage:
    //   auto Tiny_Net = std::make_shared<VGG9>();
    //   print_modules(Tiny_Net);
    //
    // NOTE(review): this function does not itself walk child modules; `level` only controls
    // indentation. Whether parameters of registered submodules are listed depends on the
    // default `recurse` argument of named_parameters() in the libtorch version used — TODO confirm.
    // A null `module` prints nothing.
    void print_modules(const std::shared_ptr<torch::nn::Module> &module, size_t level = 0) {
    	if (!module) {
    		return;
    	}
    
    	// Emit `num` tab characters; captures nothing, so no reference capture is needed.
    	auto tabs = [](size_t num) {
    		for (size_t i = 0; i < num; i++) {
    			std::cout << '\t';
    		}
    	};
    
    	// Escapes (\n, \t) are required here: a raw line break inside a literal does not compile.
    	std::cout << module->name() << " (\n";
    	for (const auto& parameter : module->named_parameters()) {
    		tabs(level + 1);
    		std::cout << parameter.key() << '\t';
    		std::cout << parameter.value().sizes() << '\n';
    	}
    
    	tabs(level);
    	std::cout << ")\n";
    }
    
    		//输入32x32 3通道图片
    		auto input = torch::rand({ 1,3,32,32 });
    
    		//输出
    		auto output_bilinear = torch::upsample_bilinear2d(input, { 8,8 }, false);
    		auto output_nearest = torch::upsample_nearest2d(input, { 5,5 });
    		auto output_avg = torch::adaptive_avg_pool2d(input, { 3,9 });
    		
    		std::cout << output_bilinear << std::endl;
    		std::cout << output_nearest << std::endl;
    		std::cout << output_avg << std::endl;
    

    libtorch 加载 PyTorch 模型进行预测的示例

    // Load the image at `path` and convert it into a normalized float32 tensor.
    //
    // Steps: resize to 224x224, convert BGR -> RGB, scale pixel values to [0, 1],
    // reorder OpenCV's H x W x C layout to torch's C x H x W (with a leading batch
    // dimension of 1), then normalize channels 0..2 with the mean/std constants below.
    //
    // On success `output` receives a tensor of shape {1, channels, 224, 224} that owns its
    // own memory (it is cloned from the OpenCV buffer). If the image cannot be read, an
    // error is printed and `output` is reset to an undefined tensor; check output.defined().
    void mat2tensor(const char * path, torch::Tensor &output)
    {
    	// Read the image; OpenCV signals failure with an empty Mat.
    	cv::Mat img = cv::imread(path);
    	if (img.empty()) {
    		printf("load image failed!\n");
    		system("pause");
    		// Stop here: continuing would hand an empty Mat to cv::resize.
    		output = torch::Tensor();
    		return;
    	}
    
    	// Resize and convert to RGB channel order.
    	cv::resize(img, img, { 224,224 });
    	cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
    	// Convert to float in [0, 1].
    	img.convertTo(img, CV_32F, 1.0 / 255.0);
    
    	// from_blob wraps img's buffer without copying, so the clone() at the end is what
    	// keeps the result valid after `img` goes out of scope.
    	torch::TensorOptions option(torch::kFloat32);
    	auto img_tensor = torch::from_blob(img.data, { 1,img.rows,img.cols,img.channels() }, option);// opencv H x W x C  torch C x H x W
    	img_tensor = img_tensor.permute({ 0,3,1,2 });// reorder the OpenCV dimensions to match torch's layout
    
    	// Per-channel normalization (x - mean) / std, applied in place on channels 0..2.
    	const double mean[3] = { 0.485, 0.456, 0.406 };
    	const double stdev[3] = { 0.229, 0.224, 0.225 };
    	for (int64_t c = 0; c < 3; ++c) {
    		img_tensor[0][c].sub_(mean[c]).div_(stdev[c]);
    	}
    
    	output = img_tensor.clone();
    }
    
    int main() 
    {
    	torch::Tensor dog;
    	mat2tensor("dog.png", dog);
    
    	// Load model.
    	std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("model.pt");
    	
    	assert(module != nullptr);
    	std::cout << "ok
    " << std::endl;
    
    	// Create a vector of inputs.
    	std::vector<torch::jit::IValue> inputs;
    	torch::Tensor tensor = torch::rand({ 1, 3, 224, 224 });
    	inputs.push_back(dog);
    
    	// Execute the model and turn its output into a tensor.
    	at::Tensor output = module->forward(inputs).toTensor();
    
    	//加载标签文件
    	std::string label_file = "synset_words.txt";
    	std::fstream fs(label_file, std::ios::in);
    	if (!fs.is_open()) {
    		printf("label open failed!
    ");
    		system("pause");
    	}
    	std::string line;
    	std::vector<std::string> labels;
    	while (std::getline(fs,line))
    	{
    		labels.push_back(line);
    	}
    
    	//排序 {1,1000} 矩阵取前10个元素(预测值),返回一个矩阵和一个矩阵的下标索引
    	std::tuple<torch::Tensor,torch::Tensor> result = output.topk(10, -1);
    	auto top_scores = std::get<0>(result).view(-1);//{1,10} 变成 {10}
    	auto top_idxs = std::get<1>(result).view(-1);
    	std::cout << top_scores << "
    " << top_idxs << std::endl;
    
    	//打印结果
    	for (int i = 0; i < 10; ++i) {
    		std::cout << "score: " << top_scores[i].item().toFloat() << "	" << "label: " << labels[top_idxs[i].item().toInt()] << std::endl;
    	}
    	
    	system("pause");
    	return 0;
    ]
    

    torch::sort

    	// Sort each row of a random 3x3 tensor along the last dimension.
    	torch::Tensor x = torch::rand({ 3,3 });
    	std::cout << x << std::endl;
    
    	// sort(dim, descending): true sorts largest-to-smallest, false smallest-to-largest.
    	// Returns a (values, indices) tuple; indices give each value's original column.
    	auto out = x.sort(-1, true);
    
    	// Escaped "\n" — a raw line break inside the literal does not compile.
    	std::cout << std::get<0>(out) << "\n" << std::get<1>(out) << std::endl;
    

    输出如下:

     0.0855  0.4925  0.4323
     0.8314  0.8954  0.0709
     0.0996  0.3108  0.6845
    [ Variable[CPUFloatType]{3,3} ]
    
     0.4925  0.4323  0.0855
     0.8954  0.8314  0.0709
     0.6845  0.3108  0.0996
    [ Variable[CPUFloatType]{3,3} ]
    
     1  2  0
     1  0  2
     2  1  0
    [ Variable[CPULongType]{3,3} ]
    
  • 相关阅读:
    系统管理员必须掌握的20个Linux监控工具
    JavaWeb基础—MVC与三层架构
    JavaWeb基础—JavaBean
    JavaWeb基础—JSP
    myeclipse(eclipse)IDE配置
    JavaWeb基础—会话管理之Session
    JavaWeb基础—会话管理之Cookie
    JavaWeb基础—项目名的写法
    JavaWeb基础—HttpServletRequest
    JavaWeb基础—VerifyCode源码
  • 原文地址:https://www.cnblogs.com/cheungxiongwei/p/10721547.html
Copyright © 2011-2022 走看看