zoukankan      html  css  js  c++  java
  • tensorflow C++手写数字识别


    #include <iostream>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    #include "tensorflow/cc/ops/const_op.h"
    #include "tensorflow/cc/ops/image_ops.h"
    #include "tensorflow/cc/ops/standard_ops.h"
    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/lib/core/stringpiece.h"
    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/platform/types.h"
    #include "tensorflow/core/public/session.h"
    #include "tensorflow/core/util/command_line_flags.h"

    using namespace std;
    using namespace tensorflow;
    using namespace tensorflow::ops;
    using tensorflow::Flag;
    using tensorflow::Tensor;
    using tensorflow::Status;
    using tensorflow::string;
    using tensorflow::int32;
    // Reads the whole of `filename` through `env` and stores the raw bytes as the
    // single element of the scalar string tensor `output`.
    //
    // Parameters:
    //   env      - TensorFlow file-system abstraction used to open and read the file.
    //   filename - path of the file to read.
    //   output   - scalar DT_STRING tensor that receives the file contents.
    //
    // Returns Status::OK() on success, the underlying file-system error if the
    // file cannot be sized/opened/read, or DataLoss if fewer bytes than the
    // reported file size were read.
    static Status ReadEntireFile(tensorflow::Env* env,const string& filename,Tensor* output) {
    	tensorflow::uint64 file_size = 0;
    	TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));

    	// The buffer handed to Read() must already be large enough to hold the
    	// whole file; without this resize the read would write past the end of
    	// an empty string.
    	string contents;
    	contents.resize(file_size);

    	std::unique_ptr<tensorflow::RandomAccessFile> file;
    	TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

    	// `data` is a view of the bytes actually read (normally backed by `contents`).
    	tensorflow::StringPiece data;
    	TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0]));
    	if (data.size() != file_size) {
    		return tensorflow::errors::DataLoss("Truncated read of '", filename,
    			"' expected ", file_size, " got ", data.size());
    	}

    	// Copy the viewed bytes into the tensor; built from data()/size() so it does
    	// not depend on which StringPiece implementation the TensorFlow version uses.
    	output->scalar<string>()() = string(data.data(), data.size());
    	return Status::OK();
    }
    // Loads an image file, decodes it, and returns it as a float tensor with a
    // leading batch dimension of 1 and pixel values divided by 255.
    //
    // Parameters:
    //   file_name    - image path; the extension (.png/.gif/.bmp, anything else is
    //                  treated as JPEG) selects the decoder.
    //   input_height, input_width, input_mean, input_std
    //                - currently unused: the resize and mean/std normalisation
    //                  steps of the original pipeline are disabled, so the image
    //                  must already have the size the model expects. They are
    //                  kept so existing callers continue to compile.
    //   out_tensors  - receives the single output tensor (the "div" node).
    //
    // Returns Status::OK() on success or the first error encountered while
    // reading the file, building the graph, or running it.
    Status ReadTensorFromImageFile(const string& file_name,const int input_height,
    	const int input_width,const int input_mean,const int input_std,
    	std::vector<Tensor>* out_tensors) {
    	// Silence unused-parameter warnings for the disabled preprocessing knobs.
    	(void)input_height;
    	(void)input_width;
    	(void)input_mean;
    	(void)input_std;

    	auto root = tensorflow::Scope::NewRootScope();

    	// Read the encoded file bytes into a scalar string tensor.
    	Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape());
    	TF_RETURN_IF_ERROR(ReadEntireFile(tensorflow::Env::Default(), file_name, &input));

    	// The encoded bytes are fed to the graph through a placeholder named "input".
    	auto file_reader = Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING);
    	std::vector<std::pair<string, tensorflow::Tensor>> inputs = { {"input",input} };

    	// Decode to a single (grayscale) channel where the decoder supports it.
    	const int wanted_channels = 1;
    	tensorflow::Output image_reader;
    	if (tensorflow::StringPiece(file_name).ends_with(".png")) {
    		image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
    			DecodePng::Channels(wanted_channels));
    	}
    	else if (tensorflow::StringPiece(file_name).ends_with(".gif")) {
    		// gif decoder returns 4-D tensor, remove the first dim
    		image_reader =
    			Squeeze(root.WithOpName("squeeze_first_dim"),
    				DecodeGif(root.WithOpName("gif_reader"), file_reader));
    	}
    	else if (tensorflow::StringPiece(file_name).ends_with(".bmp")) {
    		image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader);
    	}
    	else {
    		// Assume if it's neither a PNG nor a GIF then it must be a JPEG.
    		image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
    			DecodeJpeg::Channels(wanted_channels));
    	}

    	// Now cast the image data to float so we can do normal math on it.
    	auto float_caster =
    		Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);

    	// The convention for image ops in TensorFlow is that all images are expected
    	// to be in batches, so that they're four-dimensional arrays with indices of
    	// [batch, height, width, channel]. Because we only have a single image, we
    	// have to add a batch dimension of 1 to the start with ExpandDims().
    	auto dims_expander = ExpandDims(root.WithOpName("expand"), float_caster, 0);

    	// Scale pixel values from [0, 255] down to [0, 1]. (Bilinear resizing and
    	// mean/std normalisation are intentionally not applied here.)
    	const float input_max = 255;
    	Div(root.WithOpName("div"), dims_expander, input_max);

    	// This runs the GraphDef network definition that we've just constructed, and
    	// returns the results in the output tensor.
    	tensorflow::GraphDef graph;
    	TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));

    	std::unique_ptr<tensorflow::Session> session(
    		tensorflow::NewSession(tensorflow::SessionOptions()));
    	if (!session) {
    		return tensorflow::errors::Internal("Could not create preprocessing session");
    	}
    	TF_RETURN_IF_ERROR(session->Create(graph));
    	TF_RETURN_IF_ERROR(session->Run({ inputs }, { "div" }, {}, out_tensors));
    	return Status::OK();
    }
    int main() {
    	Session* session;
    	Status status = NewSession(SessionOptions(), &session);
    	string model_path = "model.pb";
    	GraphDef graphdef; //定义一个图
    	Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef);
    	if (!status_load.ok()) {
    		std::cout << "ERROR:Loading model failed..." << model_path << endl;
    		std::cout << status_load.ToString() << "
    		return -1;
    	Status status_create = session->Create(graphdef);
    	if (!status_create.ok()) {
    		std::cout << "ERROR:create graph in session failed..." << status_create.ToString() << '
    		return -1;
    	std::cout << "Session successfully created." << '
    	string image_path = "digit.jpg";
    	int input_height = 28, input_width = 28;
    	int input_mean = 0, input_std = 1;
    	std::vector<Tensor> resized_tensors;
    	Status read_tensor_status = ReadTensorFromImageFile(image_path, input_height, input_width,
    		input_mean, input_std,&resized_tensors);
    	if (!read_tensor_status.ok()) {
    		LOG(ERROR) << read_tensor_status;
    		cout << "resing error" << '
    		return -1;
    	const Tensor& resized_tensor = resized_tensors[0];
    	std::cout << resized_tensor.DebugString() << endl;
    	vector<tensorflow::Tensor> outputs;
    	string output_node = "softmax";
    	virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
                         const std::vector<string>& output_tensor_names,
                         const std::vector<string>& target_node_names,
                         std::vector<Tensor>* outputs)
    	Status status_run = session->Run({ {"inputs",resized_tensor} }, 
    		{ output_node }, {}, &outputs);
    	if (!status_run.ok()) {
    		std::cout << "ERROR: RUN failed..." << std::endl;
    		std::cout << status_run.ToString() << "
    		return -1;
    	std::cout << "Output tensor size:" << outputs.size() << std::endl;
    	for (std::size_t i = 0; i < outputs.size(); i++) {
    		std::cout << outputs[i].DebugString() << endl;
    	Tensor t = outputs[0];
    	int ndim = t.shape().dims();
    	auto tmap = t.tensor<float, 2>();
    	int output_dim = t.shape().dim_size(1);
    	std::vector<double> tout;
    	int output_class_id = -1;
    	double output_prob = 0.0;
    	for (int j = 0; j < output_dim; j++) {
    		std::cout << "Class " << j << "prob:" << tmap(0, j) << "," << endl;
    		if (tmap(0, j) >= output_prob) {
    			output_class_id = j;
    			output_prob = tmap(0, j);
    	std::cout << "Final class id:" << output_class_id << endl;
    	std::cout << "Final prob:" << output_prob << endl;
    	return 0;
  • 相关阅读:
    复习 层叠样式表
    WindowsForm 增 删 查 改
    WindowsForm 计算器
    rabbitmq 使用心得
  • 原文地址:https://www.cnblogs.com/lutaishi/p/13436230.html
Copyright © 2011-2022 走看看