微信扫码
添加专属顾问
我要投稿
model
的文件夹中,以便其他脚本调用。import torch
import torch.nn as nn
# 定义简单的全连接神经网络
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(784, 128)# 28*28 = 784 输入节点, 128 输出节点
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 2)# 输出10类,对应MNIST的10个数字
def forward(self, x):
x = x.view(-1, 784)# flatten the image
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return torch.log_softmax(x, dim=1)
# 从model模块导入SimpleNet类from model.simple_nn import SimpleNet# 导入PyTorch库import torch# 设置保存模型参数的文件路径path_state_dict = '3_simple_nn_model.pth'# 创建SimpleNet类的实例,这会初始化模型model = SimpleNet()# 加载保存的模型参数到模型实例model.load_state_dict(torch.load(path_state_dict))# 将模型切换到评估模式,这对于进行预测是必要的,因为它会禁用一些特定于训练阶段的操作,比如Dropoutmodel.eval()# 生成随机输入数据:1个样本,每个样本10个特征input_data = torch.randn(1, 784)# 使用模型进行预测,此时不计算梯度以提高性能和减少内存使用with torch.no_grad(): output = model(input_data)
2. 载入和使用模型
# 从model模块导入SimpleNet类
from model.simple_nn import SimpleNet
# 导入PyTorch库
import torch
# 设置保存模型参数的文件路径
path_state_dict = '3_simple_nn_model.pth'
# 创建SimpleNet类的实例,这会初始化模型
model = SimpleNet()
# 加载保存的模型参数到模型实例
model.load_state_dict(torch.load(path_state_dict))
# 将模型切换到评估模式,这对于进行预测是必要的,因为它会禁用一些特定于训练阶段的操作,比如Dropout
model.eval()
# 生成随机输入数据:1个样本,每个样本10个特征
input_data = torch.randn(1, 784)
# 使用模型进行预测,此时不计算梯度以提高性能和减少内存使用
with torch.no_grad():
output = model(input_data)
SimpleNet_Predictor
) 来展示如何优雅地在应用中使用模型,并将其存储在python包中(utils_AI
)。import torch
from torchvision import transforms
from model.simple_nn import SimpleNet
class SimpleNet_Predictor:
def __init__(self, model_path):
"""
初始化模型预测器,加载模型并进行热身。
:param model_class: 模型类,用于实例化模型
:param model_path: 预训练模型的路径,用于加载模型参数
"""
# 实例化模型
self.model = SimpleNet()
# 加载模型参数
self.model.load_state_dict(torch.load(model_path))
# 切换到评估模式
self.model.eval()
# 定义数据预处理
self.transform = transforms.Compose([
# 此处为空,可自定义
])
# 进行模型热身,以确保模型在首次调用时响应迅速
self.warmup()
def warmup(self):
"""
执行一次前向传递以热身模型。
"""
with torch.no_grad():
# 使用一些随机数据进行热身
random_input = torch.randn(1, 784)
random_input = self.transform(random_input)# 应用转换
self.model(random_input)
def __call__(self, input_data):
"""
使得该类实例可以像函数那样被调用,进行模型预测。
:param input_data: 输入数据,应为torch.Tensor格式
:return: 模型的输出结果
"""
with torch.no_grad():
input_data = self.transform(input_data)# 应用转换
return self.model(input_data)
from utils_AI.simplenet_predictor import SimpleNet_Predictor
import torch
model_path = '3_simple_nn_model.pth'# 假设模型参数已经保存在这个路径
predictor = SimpleNet_Predictor(model_path)
# 测试模型预测
test_input = torch.randn(1, 784)
output = predictor(test_input)
print("Model output shape:", output.shape)
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-07-06
比Kimi还好用?AI写作神器「橙篇」来势汹汹 欲夺长文创作之未来
2024-07-06
暴走WAIC:跟AI+教育有关的,都在这儿↑
2024-07-02
【研究成果】ArchGPT:利用大语言模型支持传统建筑遗产的更新与保护
2024-06-28
所有男生女生,AI 卖货主播来咯!
2024-06-28
AI+医疗专题报告:院内场景丰富,AI 全面赋能医疗健康领域
2024-06-20
AI 背后 B 端设计师的机会
2024-06-20
30 款让教师工作更轻松的 AI 工具
2024-06-13
知识图谱(KG)和大模型(LLMs)双轮驱动的企业级AI平台构建之道暨行业调研
2024-05-03
2024-04-28
2024-05-25
2024-07-18
2023-07-06
2023-06-30
2024-04-30
2023-06-29
2024-10-17
2023-07-03