Skip to content

Commit 399ddf5

Browse files
committed
Adding basic test for Functions.
1 parent 3b94f5a commit 399ddf5

File tree

3 files changed

+112
-40
lines changed

3 files changed

+112
-40
lines changed

Sources/PerfectTensorFlow/APILoader.swift

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -785,33 +785,26 @@ public class TFLib {
785785
/// in this address space.
786786
public static var GetAllOpList: @convention(c) () -> UnsafeMutablePointer<TF_Buffer>? = { return nil }
787787

788-
// Adds a copy of function `func` and optionally its gradient function `grad`
789-
// to `g`. Once `func`/`grad` is added to `g`, it can be called by creating
788+
/// Adds a copy of function `func` and optionally its gradient function `grad`
789+
/// to `g`. Once `func`/`grad` is added to `g`, it can be called by creating
790790
// an operation using the function's name.
791-
// Any changes to `func`/`grad` (including deleting it) done after this method
792-
// returns, won't affect the copy of `func`/`grad` in `g`.
793-
// If `func` or `grad` are already in `g`, TF_GraphCopyFunction has no
794-
// effect on them, but can establish the function->gradient relationship
795-
// between them if `func` does not already have a gradient. If `func` already
796-
// has a gradient different from `grad`, an error is returned.
797-
//
798-
// `func` must not be null.
799-
// If `grad` is null and `func` is not in `g`, `func` is added without a
800-
// gradient.
801-
// If `grad` is null and `func` is in `g`, TF_GraphCopyFunction is a noop.
802-
// `grad` must have appropriate signature as described in the doc of
803-
// GradientDef in tensorflow/core/framework/function.proto.
804-
//
805-
// If successful, status is set to OK and `func` and `grad` are added to `g`.
806-
// Otherwise, status is set to the encountered error and `g` is unmodified.
807-
/*
808-
TF_CAPI_EXPORT extern void TF_GraphCopyFunction(TF_Graph* g,
809-
const TF_Function* func,
810-
const TF_Function* grad,
811-
TF_Status* status);
812-
*/
813-
814-
public static var GraphCopyFunction: @convention(c) (OpaquePointer?, OpaquePointer?, OpaquePointer?, OpaquePointer?) -> Void = { _, _, _, _ in }
791+
/// Any changes to `func`/`grad` (including deleting it) done after this method
792+
/// returns, won't affect the copy of `func`/`grad` in `g`.
793+
/// If `func` or `grad` are already in `g`, TF_GraphCopyFunction has no
794+
/// effect on them, but can establish the function->gradient relationship
795+
/// between them if `func` does not already have a gradient. If `func` already
796+
/// has a gradient different from `grad`, an error is returned.
797+
///
798+
/// `func` must not be null.
799+
/// If `grad` is null and `func` is not in `g`, `func` is added without a
800+
/// gradient.
801+
/// If `grad` is null and `func` is in `g`, TF_GraphCopyFunction is a noop.
802+
/// `grad` must have appropriate signature as described in the doc of
803+
/// GradientDef in tensorflow/core/framework/function.proto.
804+
///
805+
/// If successful, status is set to OK and `func` and `grad` are added to `g`.
806+
/// Otherwise, status is set to the encountered error and `g` is unmodified.
807+
public static var GraphCopyFunction: @convention(c) (OpaquePointer, OpaquePointer, OpaquePointer?, OpaquePointer) -> Void = { _, _, _, _ in }
815808

816809
/// Create a TF_Function from a TF_Graph
817810
///

Sources/PerfectTensorFlow/PerfectTensorFlow.swift

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1904,10 +1904,14 @@ public class TensorFlow {
19041904
/// - outputNames: [String], The names of the function's outputs. Must either have the same length as `outputs` or be null. In the former case, the names should match the regular expression for ArgDef names - "[a-z][a-z0-9_]*". In the latter case, names for outputs will be generated automatically.
19051905
/// - options: various options for the function, e.g. XLA's inlining control.
19061906
/// - description: optional human-readable description of this function
1907-
public func toFunction(_ name: String, appendHashToFunctionName: Bool = false, operations: [Operation], inputs: [Output], outputs: [Output], outputNames: [String], options: OpaquePointer? = nil, description: String = "") throws -> Function {
1908-
guard outputs.count == outputNames.count else {
1909-
throw Panic.FAULT(reason: "Output array elements are mismatched with names")
1910-
}
1907+
public func toFunction(
1908+
_ name: String, appendHashToFunctionName: Bool = false,
1909+
operations: [Operation] = [],
1910+
inputs: [Output] = [],
1911+
outputs: [Output] = [],
1912+
outputNames: [String] = [],
1913+
options: OpaquePointer? = nil,
1914+
description: String = "") throws -> Function {
19111915
let status = try Status()
19121916
let opera:UnsafePointer<OpaquePointer?>? = operations.map { $0.operation }
19131917
.withUnsafeBufferPointer { $0.baseAddress }
@@ -1929,7 +1933,8 @@ public class TensorFlow {
19291933
Int32(outputs.count > 0 ? outputs.count: 0),
19301934
outputs.count > 0 ? pOutpus : nil,
19311935

1932-
outputs.count > 0 && outputNames.count == outputs.count ? pOutputNames : nil,
1936+
outputNames.count > 0
1937+
&& outputNames.count == outputs.count ? pOutputNames : nil,
19331938

19341939
options, description.isEmpty ? nil: description,
19351940

@@ -1941,6 +1946,32 @@ public class TensorFlow {
19411946
return Function(fun)
19421947
}
19431948

1949+
/// Adds a copy of function `func` and optionally its gradient function `grad`
1950+
/// to `g`. Once `func`/`grad` is added to `g`, it can be called by creating
1951+
/// an operation using the function's name.
1952+
/// Any changes to `func`/`grad` (including deleting it) done after this method
1953+
/// returns, won't affect the copy of `func`/`grad` in `g`.
1954+
/// If `func` or `grad` are already in `g`, TF_GraphCopyFunction has no
1955+
/// effect on them, but can establish the function->gradient relationship
1956+
/// between them if `func` does not already have a gradient. If `func` already
1957+
/// has a gradient different from `grad`, an error is returned.
1958+
/// If `grad` is null and `func` is not in `g`, `func` is added without a
1959+
/// gradient.
1960+
/// If `grad` is null and `func` is in `g`, TF_GraphCopyFunction is a noop.
1961+
/// `grad` must have appropriate signature as described in the doc of
1962+
/// GradientDef in tensorflow/core/framework/function.proto.
1963+
/// - parameters:
1964+
/// - function: function to add
1965+
/// - grad: the gradient function to add with.
1966+
/// - throws: Panic.FAULT
1967+
public func copy(function: Function, grad: Function? = nil) throws {
1968+
let status = try Status()
1969+
TFLib.GraphCopyFunction(self.graph, function.ref, grad?.ref, status.status)
1970+
guard status.code == .OK else {
1971+
throw Panic.FAULT(reason: status.message)
1972+
}
1973+
}
1974+
19441975
/// Function is a grouping of operations with defined inputs and outputs.
19451976
/// Once created and added to graphs, functions can be invoked by creating an
19461977
/// operation whose operation type matches the function name.
@@ -2023,17 +2054,17 @@ public class TensorFlow {
20232054
return nil
20242055
}
20252056
}
2026-
}
20272057

2028-
/// get definition
2029-
public var def: FunctionDef? {
2030-
if let buf = self.buffer, let proto = buf.data {
2031-
return try? FunctionDef(serializedData: proto)
2032-
} else {
2033-
return nil
2058+
/// get definition
2059+
public var definition: FunctionDef? {
2060+
if let buf = self.buffer, let proto = buf.data {
2061+
return try? FunctionDef(serializedData: proto)
2062+
} else {
2063+
return nil
2064+
}
20342065
}
2035-
}
20362066

2067+
}
20372068
}//end graph
20382069

20392070
/// class wrapper of Graph Definition Options

Tests/PerfectTensorFlowTests/PerfectTensorFlowTests.swift

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ public extension Data {
114114
}
115115
}
116116

117+
extension Array where Element == TF.Output {
118+
func useHelper(graph: TF.Graph, _ operationType: String, _ name: String)
119+
throws -> TF.Operation {
120+
var desc = try TF.OperationBuilder(graph: graph, name: name, type: operationType)
121+
self.forEach { inp in
122+
desc = desc.add(input: inp)
123+
}
124+
return try desc.set(device: "/cpu:0").build()
125+
}
126+
}
127+
117128
class PerfectTensorFlowTests: XCTestCase {
118129

119130
static var allTests = [
@@ -148,9 +159,46 @@ class PerfectTensorFlowTests: XCTestCase {
148159
("testMatrix", testMatrix),
149160
("testBasicImproved",testBasicImproved),
150161
("testDevices", testDevices),
151-
("testEventAndSummary", testEventAndSummary)
162+
("testEventAndSummary", testEventAndSummary),
163+
("testFunctionBasic", testFunctionBasic)
152164
]
153165

166+
func testFunctionBasic() {
167+
do {
168+
let funcName = "MyFunc"
169+
let nodeName = "MyFunc_0"
170+
let funcGraph = try TF.Graph()
171+
let hostGraph = try TF.Graph()
172+
let c = try funcGraph.scalar(10, name: "scalar10")
173+
let function = try funcGraph.toFunction(funcName, outputs: [c.asOutput(0)])
174+
try hostGraph.copy(function: function)
175+
let nullInput: [TF.Output] = []
176+
let funOp = try nullInput.useHelper(graph: hostGraph, funcName, nodeName)
177+
let s = try hostGraph.runner().fetch(funOp).run()
178+
XCTAssertEqual(s.count, 1)
179+
let res:[Int32] = try s[0].asArray()
180+
XCTAssertEqual(res[0], 10)
181+
guard let def = function.definition else {
182+
XCTFail("function definition failure")
183+
return
184+
}
185+
let node = def.nodeDef[0]
186+
XCTAssertEqual(node.name, "scalar10_0")
187+
guard let value = node.attr["value"] else {
188+
XCTFail("function invalid value")
189+
return
190+
}
191+
XCTAssertEqual( value.tensor.intVal, [10])
192+
guard let ret = def.ret["scalar10"] else {
193+
XCTFail("function return fault")
194+
return
195+
}
196+
XCTAssertEqual(ret, "scalar10_0:output:0")
197+
}catch {
198+
XCTFail("functions: \(error)")
199+
}
200+
}
201+
154202
func testDevices() {
155203
do {
156204
let g = try TF.Graph()

0 commit comments

Comments
 (0)