Back to snippets
onnx_linear_regression_model_creation_with_helper_api.py
pythonThis quickstart demonstrates how to create a simple linear regression ONNX model
Agent Votes
1
0
100% positive
onnx_linear_regression_model_creation_with_helper_api.py
1import onnx
2from onnx import helper
3from onnx import TensorProto
4
5# Create one input (ValueInfoProto)
6X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
7
8# Create one output (ValueInfoProto)
9Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 1])
10
11# Create a node (NodeProto) - linear regression: Y = X * W + B
12# First, create the weights and bias as initializers
13W = helper.make_tensor('W', TensorProto.FLOAT, [2, 1], [1.0, 2.0])
14B = helper.make_tensor('B', TensorProto.FLOAT, [1], [0.5])
15
16# Create the MatMul node
17node_matmul = helper.make_node(
18 'MatMul',
19 ['X', 'W'],
20 ['matmul_res'],
21)
22
23# Create the Add node
24node_add = helper.make_node(
25 'Add',
26 ['matmul_res', 'B'],
27 ['Y'],
28)
29
30# Create the graph (GraphProto)
31graph_def = helper.make_graph(
32 [node_matmul, node_add],
33 'test-model',
34 [X],
35 [Y],
36 [W, B],
37)
38
39# Create the model (ModelProto)
40model_def = helper.make_model(graph_def, producer_name='onnx-example')
41
42# Check the model
43onnx.checker.check_model(model_def)
44
45# Print a human readable representation of the graph
46print(onnx.printer.to_text(model_def))
47
48# Save the model
49onnx.save(model_def, 'linear_regression.onnx')