最新Java教程:在Java中使用便携式ONNX AI模型
在我们关于2020年使用便携式神经网络的系列文章中,您将了解如何在x64架构上安装ONNX并在Java中使用它。
微软与Facebook和AWS共同开发了ONNX。ONNX格式和ONNXRuntime都得到了业界的支持,以确保所有重要的框架都能够将它们的图导出到ONNX,并且这些模型可以在任何硬件配置上运行。
ONNXRuntime是一个用于运行已转换为ONNX格式的机器学习模型的引擎。传统的机器学习模型和深度学习模型(神经网络)都可以输出到ONNX格式。该运行时可以在Linux、Windows和Mac上运行,并且可以在各种芯片架构上运行。它还可以利用GPU和TPU等硬件加速器。不过,并不是每一种操作系统、芯片架构和加速器的组合都有安装包,所以如果您没有使用常见的组合,可能需要从源码中构建运行时。请查看ONNX运行时网站,获取所需组合的安装说明。本文将介绍如何在使用默认CPU的x64架构和使用GPU的x64架构上安装ONNXRuntime。
除了能够在多种硬件配置上运行,运行时还可以从大多数流行的编程语言中调用。本文的目的是展示如何在Java中使用ONNX运行时。我将展示如何安装onnxruntime包。安装ONNXRuntime后,我将把之前导出的MNIST模型加载到ONNXRuntime中,并使用它进行预测。
安装和导入ONNX运行时系统
在使用ONNX运行时之前,您需要为您的构建工具添加适当的依赖性。Maven资源库是为Maven和Gradle等各种工具设置ONNX运行时的良好来源。要在x64架构和默认CPU上使用运行时,请参考以下链接。https://mvnrepository.com/artifact/org.bytedeco/onnxruntime-platform
要在x64架构的GPU上使用运行时,请使用以下链接。https://mvnrepository.com/artifact/org.bytedeco/onnxruntime-platform-gpu
一旦安装了运行时,就可以通过下图所示的导入语句将其导入到你的Java代码文件中。导入TensorProto工具的导入语句将帮助我们为ONNX模型创建输入,它还将帮助解释ONNX模型的输出(预测)。
import ai.onnxruntime.OnnxMl.TensorProto;import ai.onnxruntime.OnnxMl.TensorProto.DataType;import ai.onnxruntime.OrtSession.Result;import ai.onnxruntime.OrtSession.SessionOptions;import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode;import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;
加载ONNX模型
下面的片段显示了如何将ONNX模型加载到以Java运行的ONNXRuntime中。这段代码创建了一个会话对象,可用于进行预测。这里使用的模型是从PyTorch导出的ONNX模型。
这里有几件事值得注意。首先,您需要查询会话以获取其输入。这是通过会话的getInputInfo方法完成的。我们的MNIST模型只有一个输入参数:一个由784个浮点组成的数组,代表MNIST数据集中的一张图像。如果您的模型有多个输入参数,那么InputMetadata将为每个参数设置一个条目。
Utilities.LoadTensorData(); String modelPath = "pytorch_mnist.onnx";try (OrtSession session = env.createSession(modelPath, options)) { Map<String, NodeInfo> inputMetaMap = session.getInputInfo(); Map<String, OnnxTensor> container = new HashMap<>(); NodeInfo inputMeta = inputMetaMap.values().iterator().next(); float[] inputData = Utilities.ImageData[imageIndex]; string label = Utilities.ImageLabels[imageIndex]; System.out.println("Selected image is the number: " + label); // this is the data for only one input tensor for this model Object tensorData = OrtUtil.reshape(inputData, ((TensorInfo) inputMeta.getInfo()).getShape()); OnnxTensor inputTensor = OnnxTensor.createTensor(env, tensorData); container.put(inputMeta.getName(), inputTensor); // Run code omitted for brevity.}
上面的代码中没有显示的是读取原始MNIST图像并将每幅图像转换为784个浮动数组的实用程序。每个图像的标签也从MNIST数据集中读取,这样就可以确定预测的准确性。这段代码是标准的Java代码,但我们仍然鼓励你检查并使用它。如果您需要读取与MNIST数据集相似的图像,它将为您节省时间。
使用ONNX运行时间进行预测。
下面的功能显示了如何使用我们加载ONNX模型时创建的ONNX会话。
try (OrtSession session = env.createSession(modelPath, options)) { // Load code not shown for brevity. // Run the inference try (OrtSession.Result results = session.run(container)) { // Only iterates once for (Map.Entry<String, OnnxValue> r : results) { OnnxValue resultValue = r.getValue(); OnnxTensor resultTensor = (OnnxTensor) resultValue; resultTensor.getValue() System.out.println("Output Name: {0}", r.Name); int prediction = MaxProbability(resultTensor); System.out.println("Prediction: " + prediction.ToString()); } } }
大多数神经网络不直接返回预测。它们会返回每个输出类的概率列表。在我们MNIST模型的情况下,每个图像的返回值将是一个10个概率的列表。具有最高概率的条目就是预测。您可以做的一个有趣的测试是,当ONNX模型在创建模型的框架内运行时,比较ONNX模型返回的概率和原始模型返回的概率。理想情况下,模型格式和运行时的变化不应改变任何产生的概率。这将是一个很好的单元测试,每次模型发生变化时都会运行。
总结和下一步
在本文中,我简要介绍了ONNX运行时和ONNX格式。然后,我展示了如何在ONNX运行时使用Java加载和运行ONNX模型。
本文的代码示例包含一个工作控制台应用程序,演示了这里所展示的所有技术。该代码示例是Github资源库的一部分,该资源库探讨了如何使用神经网络来预测MNIST数据集中发现的数字。具体来说,有一些样本展示了如何在Keras、PyTorch、TensorFlow1.0和TensorFlow2.0中创建神经网络。