本文共 2793 字,大约阅读时间需要 9 分钟。
code:https://github.com/OAID/Tengine
version: 88b4b7a2
图片、代码都来自以上项目。
…
ir_graph_t* ir_graph = create_ir_graph((struct context*)context);
struct serializer* loader = find_serializer_via_name(model_format);
然后调用 load_model 方法,载入并初始化 graph 的 tensor_list、node_list、输入输出的 node 以及 sub_graphes:
if (load_graph_tensors(tm2_s, graph, priv) < 0) // 载入tensor_list
    goto error;
if (load_graph_nodes(tm2_s, graph, priv) < 0) // 初始化node_list
    goto error;
if (set_graph_io_nodes(tm2_s, graph, priv) < 0)
    goto error;
if (load_graph_sub_info(tm2_s, graph, priv) < 0)
    goto error;
load_graph_tensors 中会遍历预训练的模型,创建 tensor(包括 type、data、shape 等属性),并添加到 graph.tensor_list 中。
load_graph_nodes 中会遍历预训练模型中的 nodes,创建 node,设置 node 的 index、input_num、output_num、input_tensors、output_tensors、op 等属性,并添加到 graph.node_list 中。
source/operator/prototype
目录下所有op文件是用于node的op初始化,op->param_mem参数空间申请。#include "graph/tensor.h"#include "graph/node.h"#include "graph/graph.h"#include "module/module.h"#include "utility/sys_port.h"static int infer_shape(struct node* node) // static 函数指针的位置不变{ struct graph* ir_graph = node->graph; struct tensor* input = get_ir_graph_tensor(ir_graph, node->input_tensors[0]); struct tensor* output = get_ir_graph_tensor(ir_graph, node->output_tensors[0]); set_ir_tensor_shape(output, input->dims, input->dim_num); // 设置 output tensor 的尺寸 return 0;}static int init_op(struct op* op){ op->param_mem = NULL; op->param_size = 0; op->same_shape = 0; op->infer_shape = infer_shape; return 0;}static void release_op(struct op* op){ sys_free(op->param_mem);}int register_absval_op(){ struct method m; m.version = 1; m.init = init_op; // 函数指针 m.release = release_op; // 函数指针 return register_op(OP_ABSVAL, OP_ABSVAL_NAME, &m); // OP_ABSVAL:enum, OP_ABSVAL_NAME:字符串}int unregister_absval_op(){ return unregister_op(OP_ABSVAL, 1);}
source/serializer/tmfile/op
文件夹下的所有op文件,是对node->op->param_mem进行赋值,上一步是申请内存初始化。如果用c++的话,可以把这两步的操作放到一个类中,框架更加清晰static int batchnorm_op_map(int op){ return OP_BATCHNORM;}static int tm2_load_batchnorm(struct graph* ir_graph, struct node* ir_node, const TM2_Node* tm_node, const TM2_Operator* tm_op){ struct batchnorm_param* batchnorm_param = ( struct batchnorm_param* )ir_node->op.param_mem; const struct tm2_priv* tm2_priv = (struct tm2_priv*)ir_graph->serializer_privacy; const char* mem_base = tm2_priv->base; const TM2_BatchNormParam* tm_param = ( TM2_BatchNormParam* )(mem_base + tm_op->offset_t_param); batchnorm_param->rescale_factor = tm_param->rescale_factor; // op 的参数 batchnorm_param->eps = tm_param->eps; batchnorm_param->caffe_flavor = tm_param->caffe_flavor; return 0;}
转载地址:http://cuhws.baihongyu.com/