Skip to content

Commit 0faac1f

Browse files
committed
Finishing Function for TensorFlow 1.4.0
1 parent 399ddf5 commit 0faac1f

File tree

5 files changed

+92
-4
lines changed

5 files changed

+92
-4
lines changed

GUIDE.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,4 +504,42 @@ let dev = try g.runner().session.devices
504504
print(dev)
505505
// sample output:
506506
// ["/job:localhost/replica:0/task:0/cpu:0": (type: "CPU", memory: 268435456)]
507+
```
508+
509+
### Functions of a Graph
510+
511+
Since TensorFlow 1.4.0+, a new feature named `Function` has been introduced as a nested class of Graph.
512+
513+
A `Function` object can be either built by `graph.toFunction()` or imported by its definition protocol buffer:
514+
515+
``` swift
516+
// build a function from the current graph
517+
let function = try graph.toFunction("funcName",
518+
operations:[], inputs:[] outputs: [operation.asOutput(0)],
519+
outputNames: [], description: "myFunc")
520+
521+
// get the function definition buffer:
522+
guard let def = function.definition else {
523+
// something wrong
524+
}
525+
526+
// import a function from its definition buffer
527+
let function2 = try TF.Graph.Function(importDefinition: def)
528+
// now function == function2
529+
```
530+
531+
The purpose of `Function` is to copy its gradient function from one graph to another:
532+
533+
``` swift
534+
let function = try graph1.toFunction(...)
535+
try graph.copy(function: function)
536+
```
537+
538+
You can also get and set a function object's attribute:
539+
540+
``` swift
541+
// assuming value is a TF.AttrValue protocol buffer
542+
try function.setAttributeFor("foo_attr", value: value)
543+
let value2 = try function.getAttributeFor("foo_attr")
544+
// now value2 should be the same as value
507545
```

GUIDE.zh_CN.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,4 +498,42 @@ let dev = try g.runner().session.devices
498498
print(dev)
499499
// 样本输出:
500500
// ["/job:localhost/replica:0/task:0/cpu:0": (type: "CPU", memory: 268435456)]
501+
```
502+
503+
### 运算流程图中的函数
504+
505+
自 TensorFlow 1.4.0 版本开始,本函数库新增了一个内置对象类 `Function` 用于表达流程图中的运算函数
506+
507+
一个函数对象可以通过运算流程图`graph.toFunction()`方法创建,或者从其定义缓冲字节中导入:
508+
509+
``` swift
510+
// 从当前流程图中创建一个函数
511+
let function = try graph.toFunction("函数名",
512+
operations:[], inputs:[] outputs: [operation.asOutput(0)],
513+
outputNames: [], description: "我的函数")
514+
515+
// 获取函数的协议缓冲字节
516+
guard let def = function.definition else {
517+
// something wrong
518+
}
519+
520+
// 或者从已经保存的协议缓冲字节中导入一个函数
521+
let function2 = try TF.Graph.Function(importDefinition: def)
522+
// now function == function2
523+
```
524+
525+
运算流程图中函数对象的设置是为了方便地把一组运算(的梯度函数)从一个流程图复制到另外一个流程图上去:
526+
527+
``` swift
528+
let function = try graph1.toFunction(...)
529+
try graph.copy(function: function)
530+
```
531+
532+
同时还可以随时读取或者设置函数对象的属性:
533+
534+
``` swift
535+
// 假设value是一个 TF.AttrValue 协议缓冲字节
536+
try function.setAttributeFor("某属性", value: value)
537+
let value2 = try function.getAttributeFor("某属性")
538+
// 现在 value 和 value2 应该是完全一致的
501539
```

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ These files are the key part of Perfect-TensorFlow:
5656
```
5757
Sources
5858
├── PerfectTensorFlow
59-
│   ├── APILoader.swift (928 lines, translated from tensorflow/c/c_api.h)
60-
│   ├── PerfectTensorFlow.swift (2436 lines)
59+
│   ├── APILoader.swift (1099 lines, translated from tensorflow/c/c_api.h)
60+
│   ├── PerfectTensorFlow.swift (2701 lines)
6161
└── TensorFlowAPI
6262
├── TensorFlowAPI.c (72 lines)
6363
└── include

README.zh_CN.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@
5656
```
5757
Sources
5858
├── PerfectTensorFlow
59-
│   ├── APILoader.swift (928 行代码,直接从tensorflow/c/c_api.h翻译而来)
60-
│   ├── PerfectTensorFlow.swift (2436 行代码)
59+
│   ├── APILoader.swift (1099 行代码,直接从tensorflow/c/c_api.h翻译而来)
60+
│   ├── PerfectTensorFlow.swift (2701 行代码)
6161
└── TensorFlowAPI
6262
├── TensorFlowAPI.c (72 行代码)
6363
└── include

Tests/PerfectTensorFlowTests/PerfectTensorFlowTests.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ class PerfectTensorFlowTests: XCTestCase {
167167
do {
168168
let funcName = "MyFunc"
169169
let nodeName = "MyFunc_0"
170+
let attrName = "foo_attr"
170171
let funcGraph = try TF.Graph()
171172
let hostGraph = try TF.Graph()
172173
let c = try funcGraph.scalar(10, name: "scalar10")
@@ -194,6 +195,17 @@ class PerfectTensorFlowTests: XCTestCase {
194195
return
195196
}
196197
XCTAssertEqual(ret, "scalar10_0:output:0")
198+
199+
let function2 = try TF.Graph.Function(importDefinition: def)
200+
guard let def2 = function2.definition else {
201+
XCTFail("function import / export failure")
202+
return
203+
}
204+
let node2 = def2.nodeDef[0]
205+
XCTAssertEqual(node, node2)
206+
try function2.setAttributeFor(attrName, value: value)
207+
let value2 = try function2.getAttributeFor(attrName)
208+
XCTAssertEqual(value, value2)
197209
}catch {
198210
XCTFail("functions: \(error)")
199211
}

0 commit comments

Comments
 (0)