Technical analysis test: inserting preprocessing nodes into an ONNX model

Separes · 2024-03-02

Method 1: inject the scale and bias as Constant nodes.
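Before editing, it can help to confirm the input tensor name (the script below assumes it is "input_image") and see which node comes first in the graph. A minimal inspection sketch, assuming the same test.onnx used below:

import onnx

m = onnx.load('test.onnx')
# Graph inputs: the new preprocessing nodes will be wired to the first of these.
for inp in m.graph.input:
    print('graph input:', inp.name)
# First few nodes: node[0] is the one whose input gets rewired to "add".
for i, n in enumerate(m.graph.node[:5]):
    print(i, n.op_type, list(n.input), '->', list(n.output))

The full script for the first method follows.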

import numpy as np
import onnx
import onnxruntime
import onnxruntime.backend as backend

model = onnx.load('test.onnx')
node = model.graph.node
graph = model.graph
 
# Optional: locate a target node (e.g. the Conv whose output is '203')
# for i in range(len(node)):
#     if node[i].op_type == 'Conv':
#         node_rise = node[i]
#         if node_rise.output[0] == '203':
#             print(i)
# print(node[159])

# Mul node: scales the graph input by the constant tensor "1"
new_node_0 = onnx.helper.make_node(
    "Mul",
    inputs=["input_image", "1"],
    outputs=["mutiply"],
)

# Constant node producing the scale value 2.0 as tensor "1"
mutiply_node = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=["1"],
    value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [], [2.0])
)

# Add node: adds the constant tensor "2" to the scaled input
new_node_1 = onnx.helper.make_node(
    "Add",
    inputs=["mutiply", "2"],
    outputs=["add"],
)

# Constant node producing the bias value -1.0 as tensor "2"
add_node = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=["2"],
    value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [], [-1.0])
)

# Rewire the original first node to consume "add", then temporarily remove it
old_squeeze_node = model.graph.node[0]
old_squeeze_node.input[0] = "add"
model.graph.node.remove(old_squeeze_node)

# Re-insert in topological order: Constant -> Mul -> Constant -> Add -> original first node
graph.node.insert(0, mutiply_node)
graph.node.insert(1, new_node_0)
graph.node.insert(2, add_node)
graph.node.insert(3, new_node_1)
graph.node.insert(4, old_squeeze_node)
onnx.checker.check_model(model)
onnx.save(model, 'out.onnx')

# session = onnxruntime.InferenceSession("out.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# out = session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: np.ones([1, 1, 128, 128], dtype=np.float32)})[0]
# print(out)

# Quick smoke test on CPU via the onnxruntime backend API
print(onnxruntime.get_device())
rt = backend.prepare(model, "CPU")
out = rt.run(np.ones([1, 1, 128, 128], dtype=np.float32))
print(out)
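One way to sanity-check the edit (a sketch, not part of the original script, and assuming the original model is deterministic) is to note that out.onnx should behave exactly like test.onnx fed with 2*x - 1, since that is the preprocessing the new Mul/Add pair implements:

import numpy as np
import onnxruntime

x = np.ones([1, 1, 128, 128], dtype=np.float32)

orig = onnxruntime.InferenceSession('test.onnx', providers=['CPUExecutionProvider'])
new = onnxruntime.InferenceSession('out.onnx', providers=['CPUExecutionProvider'])

# out.onnx applies y = 2*x - 1 internally, so feeding it x should match
# feeding the pre-scaled input to the original model.
ref = orig.run(None, {orig.get_inputs()[0].name: 2 * x - 1})[0]
got = new.run(None, {new.get_inputs()[0].name: x})[0]
print(np.allclose(ref, got, atol=1e-5))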


Method 2: supply the scale and bias as trainable initializer parameters.

import numpy as np
import onnx
import onnxruntime
import onnxruntime.backend as backend

model = onnx.load('test.onnx')
node = model.graph.node
graph = model.graph
 
# Optional: locate a target node (e.g. the Conv whose output is '203')
# for i in range(len(node)):
#     if node[i].op_type == 'Conv':
#         node_rise = node[i]
#         if node_rise.output[0] == '203':
#             print(i)
# print(node[159])

# Scale value 2.0 stored as an initializer tensor named "1"
mutiply_node = onnx.helper.make_tensor(
    name='1',
    data_type=onnx.TensorProto.FLOAT,
    dims=[1],
    vals=np.array([2.0], dtype=np.float32),
)

graph.initializer.append(mutiply_node)

# Mul node: scales the graph input by the initializer "1"
new_node_0 = onnx.helper.make_node(
    "Mul",
    inputs=["input_image", "1"],
    outputs=["mutiply"],
)

# Bias value -1.0 stored as an initializer tensor named "2"
add_node = onnx.helper.make_tensor(
    name='2',
    data_type=onnx.TensorProto.FLOAT,
    dims=[1],
    vals=np.array([-1.0], dtype=np.float32),
)

graph.initializer.append(add_node)

# Add node: adds the initializer "2" to the scaled input
new_node_1 = onnx.helper.make_node(
    "Add",
    inputs=["mutiply", "2"],
    outputs=["add"],
)

# Rewire the original first node to consume "add", then temporarily remove it
old_squeeze_node = model.graph.node[0]
old_squeeze_node.input[0] = "add"
model.graph.node.remove(old_squeeze_node)

# Re-insert in topological order: Mul -> Add -> original first node
graph.node.insert(0, new_node_0)
graph.node.insert(1, new_node_1)
graph.node.insert(2, old_squeeze_node)
onnx.checker.check_model(model)
onnx.save(model, 'out.onnx')

# session = onnxruntime.InferenceSession("out.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# out = session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: np.ones([1, 1, 128, 128], dtype=np.float32)})[0]
# print(out)

# Quick smoke test on CPU via the onnxruntime backend API
print(onnxruntime.get_device())
rt = backend.prepare(model, "CPU")
out = rt.run(np.ones([1, 1, 128, 128], dtype=np.float32))
print(out)
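What makes this second variant "trainable" is that the values live in graph.initializer, so they can be read back and overwritten like ordinary weights instead of being baked into Constant nodes. A minimal sketch, assuming the out.onnx saved by the script above; the 1/127.5 rescale and the out_rescaled.onnx filename are only illustrative:

import numpy as np
import onnx
from onnx import numpy_helper

m = onnx.load('out.onnx')
# Read the appended scale ("1") and bias ("2") back as numpy arrays.
for init in m.graph.initializer:
    if init.name in ('1', '2'):
        print(init.name, numpy_helper.to_array(init))

# Overwrite the scale in place, e.g. switching to a 1/127.5 normalization.
for init in m.graph.initializer:
    if init.name == '1':
        init.CopyFrom(numpy_helper.from_array(
            np.array([1 / 127.5], dtype=np.float32), name='1'))
onnx.save(m, 'out_rescaled.onnx')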
