Java使用pytorch模型进行数据推算

本文介绍如何在Java后台利用DJL库调用PyTorch模型进行数据分析。首先,需添加对应的依赖,并确保PyTorch版本匹配。然后,将预训练的PyTorch模型转换为TorchScript,以便在Java中使用。通过示例展示了如何用tracing方法将模型转换为TorchScript,并保存。最后,提供了Java代码示例,展示如何加载模型并进行预测。在遇到UnsatisfiedLinkError错误时,调整了依赖项并解决了问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

我的Java后台需要对数据进行分析,但找不到合适的方法,就准备用pytorch写个模型凑活着用。

使用的DJL调用pytorch引擎

Github:djl/README.md at master · deepjavalibrary/djl · GitHub

pom.xml中添加依赖:

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.16.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <version>1.9.1</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>1.9.1-0.16.0</version>
    <scope>runtime</scope>
</dependency>

注意version与pytorch版本有一个对应关系

PyTorch engine versionPyTorch native library version
pytorch-engine:0.15.0pytorch-native-auto: 1.8.1, 1.9.1, 1.10.0
pytorch-engine:0.14.0pytorch-native-auto: 1.8.1, 1.9.0, 1.9.1
pytorch-engine:0.13.0pytorch-native-auto:1.9.0
pytorch-engine:0.12.0pytorch-native-auto:1.8.1
pytorch-engine:0.11.0pytorch-native-auto:1.8.1
pytorch-engine:0.10.0pytorch-native-auto:1.7.1
pytorch-engine:0.9.0pytorch-native-auto:1.7.0
pytorch-engine:0.8.0pytorch-native-auto:1.6.0
pytorch-engine:0.7.0pytorch-native-auto:1.6.0
pytorch-engine:0.6.0pytorch-native-auto:1.5.0
pytorch-engine:0.5.0pytorch-native-auto:1.4.0
pytorch-engine:0.4.0pytorch-native-auto:1.4.0

其他问题访问连接:PyTorch Engine - Deep Java Library


官方给出了一个图片分类的例子,我只需要纯数据不需要图片输入。

随便写了个例子 输入是[a, b] 输出一个0~1的数

还是建议用python先训练好模型,不要用Java训练。模型训练好后,首先要做的是把pytorch模型转为TorchScript,TorchScript会把模型结构和参数都加载进去的

官网原文:

There are two ways to convert your model to TorchScript: tracing and scripting. We will only demonstrate the first one, tracing, but you can find information about scripting from the PyTorch documentation. When tracing, we use an example input to record the actions taken and capture the the model architecture. This works best when your model doesn't have control flow. If you do have control flow, you will need to use the scripting approach. In DJL, we use tracing to create TorchScript for our ModelZoo models.

Here is an example of tracing in actions:

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18(pretrained=True)

# Switch the model to eval model
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")

如果你使用了dropout等 一定要记得加上model.eval()再保存

对于我的来说 就下面这样

model = LinearModel()

model.load_state_dict(torch.load("model.pth"))

input = torch.tensor([0.72, 0.94]).float() //根据你的模型随便创建一个输入
    
script = torch.jit.trace(model, input)
    
script.save("model.pt")

然后该写Java代码了

官网例子:Load a PyTorch Model - Deep Java Library

还有这个:03 image classification with your model - Deep Java Library

我的数据就不需要transform了 代码:

//首先创建一个模型
Model model = Model.newInstance("test");
        try {
            model.load(Paths.get("C:\\Users\\Administrator\\IdeaProjects\\PytorchInJava\\src\\main\\resources\\model.pt"));
            System.out.println(model);

            //Predictor<参数类型,返回值类型> 输入图片的话参数是Image
            //我的参数是float32 不要写成Double
            Predictor<float[], Object> objectObjectPredictor = model.newPredictor(new NoBatchifyTranslator<float[], Object>() {
                @Override
                public NDList processInput(TranslatorContext translatorContext, float[] input) throws Exception {
                    NDManager ndManager = translatorContext.getNDManager();
                    NDArray ndArray = ndManager.create(input);
                    //ndArray作为输入
                    System.out.println(ndArray);
                    return new NDList(ndArray);
                }
                @Override
                public Object processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
                    System.out.println("process: " + ndList.get(0).getFloat());
                    return ndList.get(0).getFloat();
                }
            });

            float result = objectObjectPredictor.predict(new float[]{0.6144011f, 0.952401f});

            System.out.println("result: " + result);
        } catch (IOException e) {
            e.printStackTrace();
        } catch (MalformedModelException e) {
            e.printStackTrace();
        } catch (Exception e) {
            System.out.println("qunimade ");
            e.printStackTrace();
        }

输出:

更新

当我打包成jar到centos7的linux中运行时,报错UnsatisfiedLinkError,经过大神的指导,问题出在我引的依赖。

修改后的依赖:

    <properties>
        <java.version>8</java.version>
        <jna.version>5.3.0</jna.version>
    </properties>


    <dependencies>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.16.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu-precxx11</artifactId>
            <classifier>linux-x86_64</classifier>
            <version>1.9.1</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-jni</artifactId>
            <version>1.9.1-0.16.0</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
    </dependencies>

评论 24
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

欧内的手好汗

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值