博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
【Tengine端侧推理框架】——创建graph
阅读量:4304 次
发布时间:2019-05-27

本文共 2793 字,大约阅读时间需要 9 分钟。

code:https://github.com/OAID/Tengine

version: 88b4b7a2
图片,代码都来自以上项目。
在这里插入图片描述

1. 简介

2. 正题

  1. 创建一个空的graph,ir_graph_t* ir_graph = create_ir_graph((struct context*)context);
  2. 找到模型序列化器struct serializer* loader = find_serializer_via_name(model_format);,然后调用load_model方法
  3. load_graph,会初始化graphtensor_list,node_list,输入输出的nodesub_graphes
if (load_graph_tensors(tm2_s, graph, priv) < 0)         // 载入tensor_list    goto error;if (load_graph_nodes(tm2_s, graph, priv) < 0)           // 初始化node_list    goto error;if (set_graph_io_nodes(tm2_s, graph, priv) < 0)    goto error;if (load_graph_sub_info(tm2_s, graph, priv) < 0)    goto error;
  1. load_graph_tensors中会遍历预训练的模型,创建tensor,包括type,data,shape等属性,并添加到graph.tensor_list中。
  2. load_graph_nodes中会遍历预训练模型中的nodes,创建node,设置node的index,input_num, output_num, input_tensors, output_tensors, op等属性,并添加到graph.node_list
  3. source/operator/prototype目录下所有op文件是用于node的op初始化,op->param_mem参数空间申请。
#include "graph/tensor.h"#include "graph/node.h"#include "graph/graph.h"#include "module/module.h"#include "utility/sys_port.h"static int infer_shape(struct node* node)       // static 函数指针的位置不变{
struct graph* ir_graph = node->graph; struct tensor* input = get_ir_graph_tensor(ir_graph, node->input_tensors[0]); struct tensor* output = get_ir_graph_tensor(ir_graph, node->output_tensors[0]); set_ir_tensor_shape(output, input->dims, input->dim_num); // 设置 output tensor 的尺寸 return 0;}static int init_op(struct op* op){
op->param_mem = NULL; op->param_size = 0; op->same_shape = 0; op->infer_shape = infer_shape; return 0;}static void release_op(struct op* op){
sys_free(op->param_mem);}int register_absval_op(){
struct method m; m.version = 1; m.init = init_op; // 函数指针 m.release = release_op; // 函数指针 return register_op(OP_ABSVAL, OP_ABSVAL_NAME, &m); // OP_ABSVAL:enum, OP_ABSVAL_NAME:字符串}int unregister_absval_op(){
return unregister_op(OP_ABSVAL, 1);}
  1. source/serializer/tmfile/op文件夹下的所有op文件,是对node->op->param_mem进行赋值,上一步是申请内存初始化。如果用c++的话,可以把这两步的操作放到一个类中,框架更加清晰
static int batchnorm_op_map(int op){
return OP_BATCHNORM;}static int tm2_load_batchnorm(struct graph* ir_graph, struct node* ir_node, const TM2_Node* tm_node, const TM2_Operator* tm_op){
struct batchnorm_param* batchnorm_param = ( struct batchnorm_param* )ir_node->op.param_mem; const struct tm2_priv* tm2_priv = (struct tm2_priv*)ir_graph->serializer_privacy; const char* mem_base = tm2_priv->base; const TM2_BatchNormParam* tm_param = ( TM2_BatchNormParam* )(mem_base + tm_op->offset_t_param); batchnorm_param->rescale_factor = tm_param->rescale_factor; // op 的参数 batchnorm_param->eps = tm_param->eps; batchnorm_param->caffe_flavor = tm_param->caffe_flavor; return 0;}

转载地址:http://cuhws.baihongyu.com/

你可能感兴趣的文章
规范性附录 属性值代码
查看>>
提取面狭长角
查看>>
Arcsde表空间自动增长
查看>>
Arcsde报ora-29861: 域索引标记为loading/failed/unusable错误
查看>>
记一次断电恢复ORA-01033错误
查看>>
C#修改JPG图片EXIF信息中的GPS信息
查看>>
从零开始的Docker ELK+Filebeat 6.4.0日志管理
查看>>
Sequelize的原始查询的时区问题
查看>>
How it works(1) winston3源码阅读(A)
查看>>
How it works(2) autocannon源码阅读(A)
查看>>
How it works(3) Tilestrata源码阅读(A)
查看>>
How it works(12) Tileserver-GL源码阅读(A) 服务的初始化
查看>>
uni-app 全局变量的几种实现方式
查看>>
echarts 为例讲解 uni-app 如何引用 npm 第三方库
查看>>
uni-app跨页面、跨组件通讯
查看>>
springmvc-helloworld(idea)
查看>>
JDK下载(百度网盘)
查看>>
idea用得溜,代码才能码得快
查看>>
一篇掌握python魔法方法详解
查看>>
数据结构和算法5-非线性-树
查看>>