Skip to content
This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

Commit

Permalink
Lazy tensor: automatically promote constants to inputs based on histo…
Browse files Browse the repository at this point in the history
…ry (#476)
  • Loading branch information
bgogul authored Aug 30, 2019
1 parent 5a8ac7b commit 33fe7f3
Show file tree
Hide file tree
Showing 10 changed files with 306 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Sources/TensorFlow/Bindings/TFTensorOperation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

/// Opaque reference to a function that has been made callable by loading it
/// into the runtime.
public struct _TensorFunctionPointer {
public struct _TensorFunctionPointer: Equatable {
public var name: String
public init(name: String) {
self.name = name
Expand Down
2 changes: 1 addition & 1 deletion Sources/TensorFlow/Core/DataTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import CTensorFlow
// This simply wraps a `TF_DataType` and allows user code to handle
// `TF_DataType` without importing CTensorFlow, which pollutes the namespace
// with TensorFlow C API declarations.
public struct TensorDataType {
public struct TensorDataType: Equatable {
public var _cDataType: TF_DataType

@usableFromInline
Expand Down
3 changes: 3 additions & 0 deletions Sources/TensorFlow/Core/LazyTensorContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class LazyTensorOperationsTracker {
struct LazyTensorContext {
var operationsTracker = LazyTensorOperationsTracker()
var isShapeTrackingEnabled = true
/// Should constants in trace be heuristically promoted to inputs automatically?
/// (See `LazyTensorTraceCache`)
var shouldPromoteConstants = true

static var local: LazyTensorContext {
_read { yield _ThreadLocalState.local.lazyTensorContext }
Expand Down
4 changes: 2 additions & 2 deletions Sources/TensorFlow/Core/LazyTensorOperation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class LazyTensorOperation: TensorOperation {
case list([LazyTensorHandle])
}

enum Attribute {
enum Attribute: Equatable {
case boolValue(Bool)
case intValue(Int)
case floatValue(Float)
Expand All @@ -199,7 +199,7 @@ class LazyTensorOperation: TensorOperation {
case optionalTensorShapeArray([TensorShape?])
}

let name: String
var name: String
let outputCount: Int
var inputs: [Input]
var attributes: [String: Attribute]
Expand Down
5 changes: 4 additions & 1 deletion Sources/TensorFlow/Core/LazyTensorTrace.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,13 @@ class LazyTensorTraceBuilder {
inputs: builder.inputs,
operations: builder.operations,
outputs: builder.outputs)
return MaterializationTraceInfo(
let materializationTraceInfo = MaterializationTraceInfo(
lazyOperations: builder.originalOutputs,
trace: trace,
concreteInputs: builder.inputValues)
return LazyTensorContext.local.shouldPromoteConstants
? LazyTensorTraceCache.traceWithPromotedConstants(materializationTraceInfo)
: materializationTraceInfo
}

static func materializationTraceInfo(
Expand Down
158 changes: 158 additions & 0 deletions Sources/TensorFlow/Core/LazyTensorTraceCache.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import CTensorFlow

extension TFETensorHandle: Equatable {}

public func ==(_ lhs: TFETensorHandle, _ rhs: TFETensorHandle) -> Bool {
return lhs._cTensorHandle == rhs._cTensorHandle
}

extension TFETensorHandle {
/// Returns true if the underlying tensors are equal.
func elementsEqual(_ other: TFETensorHandle) -> Bool {
let selfDtype = TFE_TensorHandleDataType(self._cTensorHandle)
let otherDtype = TFE_TensorHandleDataType(other._cTensorHandle)
precondition(
selfDtype == otherDtype && selfDtype != TF_VARIANT && selfDtype != TF_RESOURCE,
"Datatypes of tensor handles don't match.")
let op = TFE_Op("Equal", 1)
op.updateAttribute("T", TensorDataType(selfDtype))
op.addInput(self)
op.addInput(other)
let result: Tensor<Bool> = op.execute(Int(1))
return result.scalars.allSatisfy { $0 }
}
}

extension LazyTensorHandle {
func isEquivalent(to other: LazyTensorHandle) -> Bool {
switch (self.handle, other.handle) {
case let (.concrete(x, _), .concrete(y, _)):
return x == y
case let (.symbolic(x, xi, _), .symbolic(y, yi, _)):
return xi == yi && x.id == y.id
default: return false
}
}
}

extension LazyTensorOperation.Input {
/// Returns true if these inputs are equivalent when comparing lazy tensor traces.
func isEquivalent(to other: LazyTensorOperation.Input) -> Bool {
switch (self, other) {
case let (.single(l), .single(r)):
return l.isEquivalent(to: r)
case let (.list(l), .list(r)):
return l.elementsEqual(r, by: { $0.isEquivalent(to: $1) })
default:
return false
}
}
}

extension LazyTensorOperation {
/// Returns true if these operations are equivalent when comparing lazy tensor traces.
func isEquivalent(to other: LazyTensorOperation) -> Bool {
return self.name == other.name &&
self.outputCount == other.outputCount &&
self.deviceName == other.deviceName &&
self.inputs.elementsEqual(other.inputs, by: { $0.isEquivalent(to: $1) }) &&
self.attributes == other.attributes
}
}

// TODO(TF-693): This is not thread safe!
struct LazyTensorTraceCache {
/// Cache from signature to traces that match signature.
static private var cache: [String: [LazyTensorTrace]] = [:]
static func clearCache() { cache.removeAll() }

/// Returns a `MaterializationTraceInfo` with possibly some constants promoted to inputs.
static func traceWithPromotedConstants(
_ traceInfo: MaterializationTraceInfo
) -> MaterializationTraceInfo {
let trace = traceInfo.trace
guard var traces = cache[trace.signature] else {
cache[trace.signature] = [trace]
return traceInfo
}
for cachedTrace in traces {
if let promotedTrace = traceInfo.withPromotedConstants(cachedTrace: cachedTrace) {
debugLog("Promoted: \(promotedTrace)\n")
return promotedTrace
}
}
// No match found; cache and return the input `traceInfo` itself.
traces.append(trace)
return traceInfo
}
}

private extension MaterializationTraceInfo {
func withPromotedConstants(cachedTrace: LazyTensorTrace) -> MaterializationTraceInfo? {
let currentTrace = self.trace
if currentTrace.operations.count != cachedTrace.operations.count { return nil }
var promotableConstants: [(Int, TFETensorHandle)] = []
for (i, current) in currentTrace.operations.enumerated() {
let cached = cachedTrace.operations[i]
if let (currentTensor, cachedTensor) = Self.promotableConstants(current, cached) {
if currentTensor.elementsEqual(cachedTensor) { continue }
promotableConstants.append((i, currentTensor))
continue
}
// TODO: we might avoid running the following check based on results of promotableConstant
if current.isEquivalent(to: cached) { continue }
return nil
}

let newConcreteInputs: [TFETensorHandle] = promotableConstants.map { return $0.1 }
let newOperations = currentTrace.operations
let newInputs = promotableConstants.map {
(promotableConstant: (Int, TFETensorHandle)) -> LazyTensorOperation in
let constantOp = newOperations[promotableConstant.0]
constantOp.name = "Placeholder"
constantOp.attributes.removeValue(forKey: "value")
return constantOp
}
let newTrace = LazyTensorTrace(
inputs: currentTrace.inputs + newInputs,
operations: newOperations,
outputs: currentTrace.outputs)
return MaterializationTraceInfo(
lazyOperations: self.lazyOperations,
trace: newTrace,
concreteInputs: self.concreteInputs + newConcreteInputs)
}

/// If `current` and `cached` are compatible constants, returns the constant tensors.
static private func promotableConstants(
_ current: LazyTensorOperation,
_ cached: LazyTensorOperation
) -> (TFETensorHandle, TFETensorHandle)? {
if current.name != "Const" || cached.name != "Const" { return nil }
let currentValue = current.attributes["value"]!
let cachedValue = cached.attributes["value"]!
guard case let .constTensor(currentTensor) = currentValue,
case let .constTensor(cachedTensor) = cachedValue
else { return nil }
let currentDtype = TFE_TensorHandleDataType(currentTensor._cTensorHandle)
let cachedDtype = TFE_TensorHandleDataType(cachedTensor._cTensorHandle)
if currentDtype == TF_VARIANT || currentDtype == TF_RESOURCE { return nil }
if cachedDtype == TF_VARIANT || cachedDtype == TF_RESOURCE { return nil }
return currentTensor.shape == cachedTensor.shape && currentDtype == cachedDtype
? (currentTensor, cachedTensor)
: nil
}
}
4 changes: 4 additions & 0 deletions Tests/TensorFlowTests/LazyTensorTestHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@ import XCTest
@testable import TensorFlow

class LazyTensorTestCase: XCTestCase {
static var shouldPromoteConstants = true
override class func setUp() {
super.setUp()
_ThreadLocalState.useLazyTensor = true
shouldPromoteConstants = LazyTensorContext.local.shouldPromoteConstants
LazyTensorContext.local.shouldPromoteConstants = false
}

override class func tearDown() {
super.tearDown()
_ThreadLocalState.useLazyTensor = false
LazyTensorContext.local.shouldPromoteConstants = shouldPromoteConstants
}
}

Expand Down
131 changes: 131 additions & 0 deletions Tests/TensorFlowTests/LazyTensorTraceCacheTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest

@testable import TensorFlow
import CTensorFlow

final class LazyTensorTraceCacheTests: LazyTensorTestCase {
override class func setUp() {
super.setUp()
LazyTensorContext.local.shouldPromoteConstants = true
}

override class func tearDown() {
super.tearDown()
LazyTensorTraceCache.clearCache()
}

func testConstPromotion() {
LazyTensorTraceCache.clearCache()
let a = Tensor<Float>(1.0)
let b = Tensor<Float>(2.0)
let c = Tensor<Float>(3.0)
let d = Tensor<Float>(4.0)
let w = a * b
let x = c * d
// Trigger materialization for `w` so that a trace with constants and mul is added to cache.
XCTAssertEqual(
lazyTrace(w).description,
"""
lazyTrace_3() -> (%2) {
%0 = Const[dtype: float, value: 1.0]()
%1 = Const[dtype: float, value: 2.0]()
%2 = Mul[T: float](%0, %1)
}
""")
XCTAssertEqual(w.scalars, [2.0])

// The trace for `x` should have the inputs to Mul as arguments instead of constants.
XCTAssertEqual(
lazyTrace(x).description,
"""
lazyTrace_3(%0: float, %1: float) -> (%2) {
%2 = Mul[T: float](%0, %1)
}
""")
XCTAssertEqual(x.scalarized(), 12.0)

let e = Tensor<Float>(shape: [1,3], scalars: [1, 2, 3])
let f = Tensor<Float>(5.0)
let y = e * f
// We won't promote constants in 'y' as shape of constants is different.
XCTAssertEqual(
lazyTrace(y).description,
"""
lazyTrace_3() -> (%2) {
%0 = Const[dtype: float, value: [[1.0, 2.0, 3.0]]]()
%1 = Const[dtype: float, value: 5.0]()
%2 = Mul[T: float](%0, %1)
}
""")
XCTAssertEqual(y.scalars, [5.0, 10.0, 15.0])
}

func testDoNotPromoteEqualConstants() {
LazyTensorTraceCache.clearCache()
let a = Tensor<Float>(1.0)
let b = Tensor<Float>(2.0)
let c = Tensor<Float>(3.0)
let w = a * b
let x = a * c
XCTAssertEqual(
lazyTrace(w).description,
"""
lazyTrace_3() -> (%2) {
%0 = Const[dtype: float, value: 1.0]()
%1 = Const[dtype: float, value: 2.0]()
%2 = Mul[T: float](%0, %1)
}
""")
XCTAssertEqual(w.scalars, [2.0])
// Const 1.0 is not promoted.
XCTAssertEqual(
lazyTrace(x).description,
"""
lazyTrace_3(%1: float) -> (%2) {
%0 = Const[dtype: float, value: 1.0]()
%2 = Mul[T: float](%0, %1)
}
""")
}

private func lazyTensorOperation<T: TensorFlowScalar>(
_ input: Tensor<T>
) -> LazyTensorOperation? {
let tensor = input.handle.handle
guard let lazyTensor = tensor as? LazyTensorHandle else {
XCTFail("Trying to get lazy trace for a non-lazy tensor.")
return nil
}
guard case let .symbolic(lazyOp, _, _) = lazyTensor.handle else {
XCTFail("Cannot get lazy trace for a concrete tensor.")
return nil
}
return lazyOp
}

private func lazyTrace<T: TensorFlowScalar>(
_ input: Tensor<T>
) -> LazyTensorTrace {
let lazyOperation = lazyTensorOperation(input)!
return LazyTensorTraceBuilder.materializationTraceInfo(lazyOperation).trace
}

static var allTests = [
("testConstPromotion", testConstPromotion),
("testDoNotPromoteEqualConstants", testDoNotPromoteEqualConstants)
]
}
6 changes: 0 additions & 6 deletions Tests/TensorFlowTests/TensorGroupTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@ import XCTest
@testable import TensorFlow
import CTensorFlow

extension TensorDataType : Equatable {
public static func == (lhs: TensorDataType, rhs: TensorDataType) -> Bool {
return Int(lhs._cDataType.rawValue) == Int(rhs._cDataType.rawValue)
}
}

struct Empty : TensorGroup {}

struct Simple : TensorGroup, Equatable {
Expand Down
2 changes: 2 additions & 0 deletions Tests/TensorFlowTests/XCTestManifests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public func allTests() -> [XCTestCaseEntry] {
testCase(InitializerTests.allTests),
testCase(LayerTests.allTests),
testCase(LazyTensorEvaluationTests.allTests),
testCase(LazyTensorTraceTests.allTests),
testCase(LazyTensorTraceCacheTests.allTests),
testCase(LazyTensorExplicitTraceTests.allTests),
testCase(LazyTensorHandleTests.allTests),
testCase(LazyTensorOperationTests.allTests),
Expand Down

0 comments on commit 33fe7f3

Please sign in to comment.