@@ -1904,10 +1904,14 @@ public class TensorFlow {
1904
1904
/// - outputNames: [String], The names of the function's outputs. Must either have the same length as `outputs` or be null. In the former case, the names should match the regular expression for ArgDef names - "[a-z][a-z0-9_]*". In the latter case, names for outputs will be generated automatically.
1905
1905
/// - options: various options for the function, e.g. XLA's inlining control.
1906
1906
/// - description: optional human-readable description of this function
1907
- public func toFunction( _ name: String , appendHashToFunctionName: Bool = false , operations: [ Operation ] , inputs: [ Output ] , outputs: [ Output ] , outputNames: [ String ] , options: OpaquePointer ? = nil , description: String = " " ) throws -> Function {
1908
- guard outputs. count == outputNames. count else {
1909
- throw Panic . FAULT ( reason: " Output array elements are mismatched with names " )
1910
- }
1907
+ public func toFunction(
1908
+ _ name: String , appendHashToFunctionName: Bool = false ,
1909
+ operations: [ Operation ] = [ ] ,
1910
+ inputs: [ Output ] = [ ] ,
1911
+ outputs: [ Output ] = [ ] ,
1912
+ outputNames: [ String ] = [ ] ,
1913
+ options: OpaquePointer ? = nil ,
1914
+ description: String = " " ) throws -> Function {
1911
1915
let status = try Status ( )
1912
1916
let opera : UnsafePointer < OpaquePointer ? > ? = operations. map { $0. operation }
1913
1917
. withUnsafeBufferPointer { $0. baseAddress }
@@ -1929,7 +1933,8 @@ public class TensorFlow {
1929
1933
Int32 ( outputs. count > 0 ? outputs. count: 0 ) ,
1930
1934
outputs. count > 0 ? pOutpus : nil ,
1931
1935
1932
- outputs. count > 0 && outputNames. count == outputs. count ? pOutputNames : nil ,
1936
+ outputNames. count > 0
1937
+ && outputNames. count == outputs. count ? pOutputNames : nil ,
1933
1938
1934
1939
options, description. isEmpty ? nil : description,
1935
1940
@@ -1941,6 +1946,32 @@ public class TensorFlow {
1941
1946
return Function ( fun)
1942
1947
}
1943
1948
1949
+ /// Adds a copy of function `func` and optionally its gradient function `grad`
1950
+ /// to `g`. Once `func`/`grad` is added to `g`, it can be called by creating
1951
+ /// an operation using the function's name.
1952
+ /// Any changes to `func`/`grad` (including deleting it) done after this method
1953
+ /// returns, won't affect the copy of `func`/`grad` in `g`.
1954
+ /// If `func` or `grad` are already in `g`, TF_GraphCopyFunction has no
1955
+ /// effect on them, but can establish the function->gradient relationship
1956
+ /// between them if `func` does not already have a gradient. If `func` already
1957
+ /// has a gradient different from `grad`, an error is returned.
1958
+ /// If `grad` is null and `func` is not in `g`, `func` is added without a
1959
+ /// gradient.
1960
+ /// If `grad` is null and `func` is in `g`, TF_GraphCopyFunction is a noop.
1961
+ /// `grad` must have appropriate signature as described in the doc of
1962
+ /// GradientDef in tensorflow/core/framework/function.proto.
1963
+ /// - parameters:
1964
+ /// - function: function to add
1965
+ /// - grad: the gradient function to add with.
1966
+ /// - throws: Panic.FAULT
1967
+ public func copy( function: Function , grad: Function ? = nil ) throws {
1968
+ let status = try Status ( )
1969
+ TFLib . GraphCopyFunction ( self . graph, function. ref, grad? . ref, status. status)
1970
+ guard status. code == . OK else {
1971
+ throw Panic . FAULT ( reason: status. message)
1972
+ }
1973
+ }
1974
+
1944
1975
/// Function is a grouping of operations with defined inputs and outputs.
1945
1976
/// Once created and added to graphs, functions can be invoked by creating an
1946
1977
/// operation whose operation type matches the function name.
@@ -2023,17 +2054,17 @@ public class TensorFlow {
2023
2054
return nil
2024
2055
}
2025
2056
}
2026
- }
2027
2057
2028
- /// get definition
2029
- public var def : FunctionDef ? {
2030
- if let buf = self . buffer, let proto = buf. data {
2031
- return try ? FunctionDef ( serializedData: proto)
2032
- } else {
2033
- return nil
2058
+ /// get definition
2059
+ public var definition : FunctionDef ? {
2060
+ if let buf = self . buffer, let proto = buf. data {
2061
+ return try ? FunctionDef ( serializedData: proto)
2062
+ } else {
2063
+ return nil
2064
+ }
2034
2065
}
2035
- }
2036
2066
2067
+ }
2037
2068
} //end graph
2038
2069
2039
2070
/// class wrapper of Graph Definition Options
0 commit comments