当前位置: 首页>编程语言>正文

模型提前载入python python怎么加载模型

简介

工作流程:

  • python: 使用keras训练模型并保存为h5(keras以tensorflow为引擎)
  • python: 转换h5为pb文件
  • python: 加载模型,并验证模型无误
  • c/c++: 加载并使用模型

依赖安装:

# 指定版本安装
# 如果已经存在该包的更高版本,会自动卸载并重新安装
# 显示卸载示例:pip uninstall opencv-python
pip install tensorflow==1.13.1
pip install h5py==2.10
pip install opencv-python-headless

遇到问题可以查看:

  • keras加载模型load_model时报错:AttributeError: ‘str‘ object has no attribute ‘decode‘ “
  • opencv ImportError: libGL.so.1: cannot open shared object file: No such file or directory

版本:

模型提前载入python python怎么加载模型,模型提前载入python python怎么加载模型_tensorflow,第1张

至于使用c/c++,其实都是差不多的。但是它们的环境构建不同:

  • c++:使用bazel构建,每个平台需要自行构建,繁索
  • c:TensorFlow提供了编译好的包(头文件和库文件),可以根据平台直接下载使用

c相关包的下载:https://tensorflow.google.cn/install/lang_c

模型提前载入python python怎么加载模型,模型提前载入python python怎么加载模型_模型提前载入python_02,第2张

注意,如果要下载之前的版本,只需要把网址中相应的版本号修改一下即可。

所以这里使用c为例,来介绍整个流程。

keras训练模型

以fashion_mnist为例。

首先使用python训练模型,假设使用下面的模型:

from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D

def create_model():
    model = Sequential()
    # Must define the input shape in the first layer of the neural network
    model.add(Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28, 28, 1), name='input_image'))
    model.add(MaxPooling2D(pool_size=2))
    model.add(Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=2))
    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(10, activation='softmax', name='output_class'))

    return model

它的模型结构是这样的:

模型提前载入python python怎么加载模型,模型提前载入python python怎么加载模型_c++_03,第3张

训练并保存模型,保存文件为:fashion_mnist.h5

转换h5为pb

使用 keras_to_tensorflow 工具进行转换。

这个工具原理:

  • 用 Keras 读取 .h5 模型文件
  • 用 tensorflow 的 convert_variables_to_constants 函数将所有变量转换成常量
  • 再 write_graph 就是一个包含了网络以及参数值的 .pb 文件了

这个参考网上开源工具即可。

转换完成后,生成:fashion_mnist.h5.pb。

python验证模型

这个代码就比较简单了:

#!/usr/bin/env python

import tensorflow as tf
import numpy as np

from tensorflow.python.platform import gfile

# Fix for ros kinetic users
import sys
print(sys.path)
#sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')

# OpenCV
import cv2

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Read image
img = cv2.imread('fashion_0.png', cv2.IMREAD_GRAYSCALE)
print('img.shape = ', img.shape)
img = img.astype('float32')
img /= 255.0
img = img.reshape(1, 28, 28, 1)

# Initialize a tensorflow session
with tf.Session() as sess:
    # Load the protobuf graph
    with gfile.FastGFile("models/fashion_mnist.h5.pb",'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Add the graph to the session
        tf.import_graph_def(graph_def, name='')

    # Get graph
    graph = tf.get_default_graph()

    # Get tensor from graph
    pred = graph.get_tensor_by_name("output_class/Softmax:0")

    # Run the session, evaluating our "c" operation from the graph
    res = sess.run(pred, feed_dict={'input_image_input:0': img})

    # Print test accuracy
    pred_index = np.argmax(res[0])

    # Print test accuracy
    print('Predict:', pred_index, ' Label:', class_names[pred_index])

这个过程先是准备了一个样例,使用模型来预测该样例。

加载模型,以样例作为输入,如果执行成功,会得到:

模型提前载入python python怎么加载模型,模型提前载入python python怎么加载模型_c++_04,第4张

c/c++加载模型

c++加载模型并使用的总体流程与python加载验证模型的流程类似。

#include "../utils/TFUtils.hpp"
#include "utils/mat2tensor_c_cpi.h"

#include <iostream>
#include <vector>

// OpenCV
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>

std::string class_names[] = {"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"};

int main(int argc, char* argv[])
{
    if (argc != 3)
    {
        std::cerr << std::endl << "Usage: ./project path_to_graph.pb path_to_image.png" << std::endl;
        return 1;
    }

    // Load graph
    std::string graph_path = argv[1];

    // TFUtils init
    TFUtils TFU;
    TFUtils::STATUS status = TFU.LoadModel(graph_path);

    if (status != TFUtils::SUCCESS) {
        std::cerr << "Can't load graph" << std::endl;
        return 1;
    }

    // Load image and convert to tensor
    std::string image_path = argv[2];
    cv::Mat image = cv::imread(image_path, CV_LOAD_IMAGE_GRAYSCALE);

    const std::vector<std::int64_t> input_dims = {1, image.size().height, image.size().width, image.channels()};

    TF_Tensor* input_image = Mat2Tensor(image, 1/255.0);

    // Input Tensor/Ops Create
    const std::vector<TF_Tensor*> input_tensors = {input_image};

    const std::vector<TF_Output> input_ops = {TFU.GetOperationByName("input_image_input", 0)};

    // Output Tensor/Ops Create
    const std::vector<TF_Output> output_ops = {TFU.GetOperationByName("output_class/Softmax", 0)};

    std::vector<TF_Tensor*> output_tensors = {nullptr};

    status = TFU.RunSession(input_ops, input_tensors,
                            output_ops, output_tensors);

    if (status == TFUtils::SUCCESS) {
        const std::vector<std::vector<float>> data = TFUtils::GetTensorsData<float>(output_tensors);
        const std::vector<float> result = data[0];

        int pred_index = ArgMax(result);

        // Print test accuracy
        printf("Predict: %d Label: %s", pred_index, class_names[pred_index].c_str());

    } else {
        std::cout << "Error run session";
        return 2;
    }

    TFUtils::DeleteTensors(input_tensors);
    TFUtils::DeleteTensors(output_tensors);

    return 0;
}

大概流程如下:

  • 加载模型文件:LoadModel
  • 使用opencv的imread函数读取图像,以此图像作为输入,使用模型对它进行识别
  • Mat2Tensor把图像数据转化为TensorFlow能够识别的数据
  • 构建输入输出
  • 执行session
  • 获取结果

整个流程是这样的,要想真正理解它们的含义,还是需要学习一下TensorFlow这些函数背后的意义。

小结

至此,整个流程可以跑起来了,可以有一个直观上认识和感觉。建立一下学习的信心和希望。

下一步,就是要深入到这些代码背后,并结合TensorFlow的使用,真正理解为什么要这样做。

参考资料

Tensorflow C++ 从训练到部署(3):使用 Keras 训练和部署 CNN https://github.com/skylook/tensorflow_cpp https://github.com/Neargye/hello_tf_c_api


https://www.xamrdz.com/lan/5d81923828.html

相关文章: