使用 TensorFlow 設計矩陣乘法計算並轉移執行在 Android 上 建國科技大學資管系 饒瑞佶 2017/8
Python 設計 Model import tensorflow as tf from tensorflow.python.tools import freeze_graph from tensorflow.python.tools import optimize_for_inference_lib I=tf.placeholder(tf.float32,shape=[None,3],name="I") # input W=tf.Variable(tf.zeros(shape=[3,2]),dtype=tf.float32,name="W") # weights b=tf.variable(tf.zeros(shape=[2]),dtype=tf.float32,name="b") #bias o=tf.nn.relu(tf.matmul(i,w)+b,name="final_result") # activation / output # o= I * W + b saver=tf.train.saver() init_op=tf.global_variables_initializer()
with tf.session() as sess: sess.run(init_op) #save the graph tf.train.write_graph(sess.graph_def,'.','tfdroid.pbtxt') #normally you would do some training here # but for now we will just assign something to W sess.run(tf.assign(w,[[1,2],[3,4],[5,6]])) sess.run(tf.assign(b,[1,1])) saver.save(sess,'./tfdroid.ckpt')
MODEL_NAME='tfdroid' #freeze the graph input_graph_path= MODEL_NAME + '.pbtxt' checkpoint_path='./' + MODEL_NAME + '.ckpt' input_saver_def_path="" input_binary=false output_node_names="final_result" restore_op_name="save/restore_all" filename_tensor_name="save/const:0" output_frozen_graph_name='frozen_' + MODEL_NAME + '.pb' output_optimized_graph_name='optimized_' + MODEL_NAME + '.pb' clear_devices=true freeze_graph.freeze_graph(input_graph_path,input_saver_def_path,input_binary,checkpoin t_path,output_node_names,restore_op_name,filename_tensor_name,output_frozen_graph _name,clear_devices,"")
#optimize for inference input_graph_def=tf.graphdef() with tf.gfile.open(output_frozen_graph_name,"rb") as f: data=f.read() input_graph_def.parsefromstring(data) output_graph_def=optimize_for_inference_lib.optimize_for_inference(input_graph_def, ["I"],["final_result"],tf.float32.as_datatype_enum) #Save the optimized graph f=tf.gfile.fastgfile(output_optimized_graph_name,"w") f.write(output_graph_def.serializetostring())
轉移到 Android 透過.pb 模型檔案
New Android Project
Empty Activity
加入 assets 目錄與.pb 檔案
加入.pb 檔案
使用 nightly build package https://ci.tensorflow.org/view/nightly/job/nightly-android/ 找 #44 版本
使用方式 1. 複製 libandroid_tensorflow_inference_java.jar 與 libtensorflow_inference.so 目錄內的子目錄到專案的 libs 目錄中
result
使用方式 2. 修改 build.gradle(app), 在 android 段加入 sourcesets { main { jnilibs.srcdirs = ['libs'] } }
Android 專案結構
layout <RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android" xmlns:tools="http://schemas.android.com/tools" android:id="@+id/activity_main" android:layout_width="match_parent" android:layout_height="match_parent" android:paddingbottom="@dimen/activity_vertical_margin" android:paddingleft="@dimen/activity_horizontal_margin" android:paddingright="@dimen/activity_horizontal_margin" android:paddingtop="@dimen/activity_vertical_margin" tools:context="com.example.rueychi.tf_to_android.mainactivity"> <EditText android:layout_width="100dp" android:layout_height="wrap_content" android:inputtype="textpersonname" android:text="2.3" android:ems="10" android:id="@+id/editnum1" android:layout_margintop="24dp" android:layout_marginstart="16dp" android:layout_alignparenttop="true" android:layout_alignparentstart="true" />
<EditText android:layout_height="wrap_content" android:inputtype="textpersonname" android:text="33" android:ems="10" android:id="@+id/editnum2" android:layout_width="100dp" android:layout_alignbaseline="@+id/editnum1" android:layout_alignbottom="@+id/editnum1" android:layout_centerhorizontal="true" /> <Button android:text="run" android:layout_width="wrap_content" android:layout_height="wrap_content" android:id="@+id/button" android:layout_below="@+id/editnum2" android:layout_centerhorizontal="true" android:layout_margintop="50dp" />
<TextView android:layout_width="wrap_content" android:layout_height="wrap_content" android:text="output" android:id="@+id/txtviewresult" android:layout_margintop="85dp" android:textalignment="center" android:layout_aligntop="@+id/button" android:layout_centerhorizontal="true" /> <EditText android:layout_height="wrap_content" android:inputtype="textpersonname" android:text="12" android:ems="10" android:id="@+id/editnum3" android:layout_width="100dp" android:layout_marginstart="11dp" android:layout_alignbaseline="@+id/editnum2" android:layout_alignbottom="@+id/editnum2" android:layout_toendof="@+id/editnum2" /> </RelativeLayout>
MainActivity.java
載入.so 檔案
code // 使用 Python 建立的.pb model private static final String MODEL_FILE = "file:///android_asset/optimized_tfdroid.pb"; private static final String INPUT_NODE = "I"; // 對應 Python 中的名稱 private static final String OUTPUT_NODE = final_result";// 對應 Python 中的名稱 private static final int[] INPUT_SIZE = {1,3};// 設定輸入的大小 // 透過 TensorFlowInferenceInterface 建立 TensorFlow 物件 ( 來自.jar 檔案 ) private TensorFlowInferenceInterface inferenceinterface; // 載入.so 檔案 static { System.loadLibrary("tensorflow_inference"); }
@Override protected void oncreate(bundle savedinstancestate) { super.oncreate(savedinstancestate); setcontentview(r.layout.activity_main); // 建立 tensorflow 物件 inferenceinterface = new TensorFlowInferenceInterface(); // 透過 tensorflow 物件載入.pb 模型 inferenceinterface.initializetensorflow(getassets(), MODEL_FILE);
final Button button = (Button) findviewbyid(r.id.button); button.setonclicklistener(new View.OnClickListener() { public void onclick(view v) { // 輸入的 3 個資料 final EditText editnum1 = (EditText) findviewbyid(r.id.editnum1); final EditText editnum2 = (EditText) findviewbyid(r.id.editnum2); final EditText editnum3 = (EditText) findviewbyid(r.id.editnum3); float num1 = Float.parseFloat(editNum1.getText().toString()); float num2 = Float.parseFloat(editNum2.getText().toString()); float num3 = Float.parseFloat(editNum3.getText().toString()); // Input I tensor float[] inputfloats = {num1, num2, num3}; // 將 I 輸入到 tensorflow inferenceinterface.fillnodefloat(input_node, INPUT_SIZE, inputfloats); // 開始執行 model inferenceinterface.runinference(new String[]{OUTPUT_NODE}); // 輸出 tensor float[] resu = {0, 0}; // 讀出結果 O inferenceinterface.readnodefloat(output_node, resu); // 顯示結果 final TextView textviewr = (TextView) findviewbyid(r.id.txtviewresult); textviewr.settext(float.tostring(resu[0]) + ", " + Float.toString(resu[1])); } });
可以試試回去 Python 改 原來 o=tf.nn.relu(tf.matmul(i,w)+b,name="final_result") 改成 o=tf.nn.relu(tf.matmul(i,w)*b,name="final_result") 重新執行看看結果