Skip to content

Commit e5dfe90

Browse files
committed
More SafeHandles.
1 parent 016294d commit e5dfe90

File tree

69 files changed

+1524
-1826
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+1524
-1826
lines changed

src/TensorFlowNET.Console/MemoryMonitor.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,9 @@ public void WarmUp()
2323
var x = tf.placeholder(tf.float64, shape: (1024, 1024));
2424
var log = tf.log(x);
2525

26-
using (var sess = tf.Session())
27-
{
28-
var ones = np.ones((1024, 1024), dtype: np.float64);
29-
var o = sess.run(log, new FeedItem(x, ones));
30-
}
26+
var sess = tf.Session();
27+
var ones = np.ones((1024, 1024), dtype: np.float64);
28+
var o = sess.run(log, new FeedItem(x, ones));
3129
// Thread.Sleep(1);
3230
}
3331

src/TensorFlowNET.Core/Buffers/Buffer.cs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ namespace Tensorflow
2525
/// <summary>
2626
/// Represents a TF_Buffer that can be passed to Tensorflow.
2727
/// </summary>
28-
public sealed class Buffer : IDisposable
28+
public sealed class Buffer
2929
{
30-
public SafeBufferHandle Handle { get; }
30+
SafeBufferHandle _handle;
3131

3232
/// <remarks>
3333
/// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/>
3434
/// </remarks>
3535
private unsafe ref readonly TF_Buffer DangerousBuffer
36-
=> ref Unsafe.AsRef<TF_Buffer>(Handle.DangerousGetHandle().ToPointer());
36+
=> ref Unsafe.AsRef<TF_Buffer>(_handle.DangerousGetHandle().ToPointer());
3737

3838
/// <summary>
3939
/// The memory block representing this buffer.
@@ -59,21 +59,21 @@ public ulong Length
5959
{
6060
get
6161
{
62-
using (Handle.Lease())
62+
using (_handle.Lease())
6363
{
6464
return DangerousBuffer.length;
6565
}
6666
}
6767
}
6868

6969
public Buffer()
70-
=> Handle = TF_NewBuffer();
70+
=> _handle = TF_NewBuffer();
7171

7272
public Buffer(SafeBufferHandle handle)
73-
=> Handle = handle;
73+
=> _handle = handle;
7474

7575
public Buffer(byte[] data)
76-
=> Handle = _toBuffer(data);
76+
=> _handle = _toBuffer(data);
7777

7878
private static SafeBufferHandle _toBuffer(byte[] data)
7979
{
@@ -92,7 +92,7 @@ private static SafeBufferHandle _toBuffer(byte[] data)
9292
/// </summary>
9393
public unsafe byte[] ToArray()
9494
{
95-
using (Handle.Lease())
95+
using (_handle.Lease())
9696
{
9797
ref readonly TF_Buffer buffer = ref DangerousBuffer;
9898

@@ -107,7 +107,12 @@ public unsafe byte[] ToArray()
107107
}
108108
}
109109

110-
public void Dispose()
111-
=> Handle.Dispose();
110+
public override string ToString()
111+
=> $"0x{_handle.DangerousGetHandle():x16}";
112+
113+
public static implicit operator SafeBufferHandle(Buffer buffer)
114+
{
115+
return buffer._handle;
116+
}
112117
}
113118
}

src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public CheckpointReader(string filename)
1111
Status status = new Status();
1212
VariableToDataTypeMap = new Dictionary<string, TF_DataType>();
1313
VariableToShapeMap = new Dictionary<string, Shape>();
14-
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
14+
_handle = c_api.TF_NewCheckpointReader(filename, status);
1515
status.Check(true);
1616
ReadAllShapeAndType();
1717
}
@@ -38,7 +38,7 @@ public Shape GetVariableShape(string name)
3838
int num_dims = GetVariableNumDims(name);
3939
long[] dims = new long[num_dims];
4040
Status status = new Status();
41-
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
41+
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status);
4242
status.Check(true);
4343
return new Shape(dims);
4444
}
@@ -49,7 +49,7 @@ public int GetVariableNumDims(string name)
4949
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
5050
{
5151
Status status = new Status();
52-
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
52+
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status);
5353
status.Check(true);
5454
return new Tensor(tensor);
5555
}

src/TensorFlowNET.Core/Contexts/Context.Device.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public sealed partial class Context
3737
public void log_device_placement(bool enable)
3838
{
3939
if (_handle != null)
40-
c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status.Handle);
40+
c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status);
4141
_log_device_placement = enable;
4242
// _thread_local_data.function_call_options = null;
4343
}
@@ -60,23 +60,23 @@ public void set_memory_growth(PhysicalDevice device, bool enable)
6060
public PhysicalDevice[] list_physical_devices(string device_type = null)
6161
{
6262
using var opts = c_api.TFE_NewContextOptions();
63-
using var ctx = c_api.TFE_NewContext(opts, tf.Status.Handle);
64-
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status.Handle);
63+
using var ctx = c_api.TFE_NewContext(opts, tf.Status);
64+
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status);
6565
tf.Status.Check(true);
6666

6767
int num_devices = c_api.TF_DeviceListCount(devices);
6868
var results = new List<PhysicalDevice>();
6969
for (int i = 0; i < num_devices; ++i)
7070
{
71-
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status.Handle));
71+
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status));
7272
tf.Status.Check(true);
7373

7474
if (dev_type.StartsWith("XLA"))
7575
continue;
7676

7777
if (device_type == null || dev_type == device_type)
7878
{
79-
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status.Handle);
79+
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status);
8080
tf.Status.Check(true);
8181

8282
results.Add(new PhysicalDevice

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace Tensorflow.Contexts
2828
/// <summary>
2929
/// Environment in which eager operations execute.
3030
/// </summary>
31-
public sealed partial class Context : IDisposable
31+
public sealed partial class Context
3232
{
3333
public const int GRAPH_MODE = 0;
3434
public const int EAGER_MODE = 1;
@@ -41,15 +41,7 @@ public sealed partial class Context : IDisposable
4141
public FunctionCallOptions FunctionCallOptions { get; }
4242

4343
SafeContextHandle _handle;
44-
public SafeContextHandle Handle
45-
{
46-
get
47-
{
48-
if (_handle == null)
49-
ensure_initialized();
50-
return _handle;
51-
}
52-
}
44+
5345
int? _seed;
5446
Random _rng;
5547

@@ -59,6 +51,7 @@ public Context()
5951
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false);
6052
initialized = false;
6153
FunctionCallOptions = new FunctionCallOptions();
54+
ensure_initialized();
6255
}
6356

6457
/// <summary>
@@ -72,12 +65,12 @@ public void ensure_initialized()
7265
Config = MergeConfig();
7366
FunctionCallOptions.Config = Config;
7467
var config_str = Config.ToByteArray();
75-
using var opts = new ContextOptions();
76-
using var status = new Status();
77-
c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle);
68+
var opts = new ContextOptions();
69+
var status = new Status();
70+
c_api.TFE_ContextOptionsSetConfig(opts, config_str, (ulong)config_str.Length, status);
7871
status.Check(true);
79-
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts.Handle, _device_policy);
80-
_handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
72+
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy);
73+
_handle = c_api.TFE_NewContext(opts, status);
8174
status.Check(true);
8275
initialized = true;
8376
}
@@ -178,10 +171,14 @@ public void reset_context()
178171
tf.Context.ensure_initialized();
179172

180173
if (_handle != null)
174+
{
181175
c_api.TFE_ContextClearCaches(_handle);
176+
}
182177
}
183178

184-
public void Dispose()
185-
=> _handle.Dispose();
179+
public static implicit operator SafeContextHandle(Context ctx)
180+
{
181+
return ctx._handle;
182+
}
186183
}
187184
}

src/TensorFlowNET.Core/Contexts/ContextOptions.cs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17-
using System;
1817
using Tensorflow.Eager;
1918

20-
namespace Tensorflow.Contexts
19+
namespace Tensorflow.Contexts;
20+
21+
public sealed class ContextOptions
2122
{
22-
public sealed class ContextOptions : IDisposable
23-
{
24-
public SafeContextOptionsHandle Handle { get; }
23+
SafeContextOptionsHandle _handle { get; }
2524

26-
public ContextOptions()
27-
{
28-
Handle = c_api.TFE_NewContextOptions();
29-
}
25+
public ContextOptions()
26+
{
27+
_handle = c_api.TFE_NewContextOptions();
28+
}
3029

31-
public void Dispose()
32-
=> Handle.Dispose();
30+
public static implicit operator SafeContextOptionsHandle(ContextOptions opt)
31+
{
32+
return opt._handle;
3333
}
3434
}

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public Tensor[] TFE_ExecuteCancelable(Context ctx,
4343
{
4444
var status = tf.Status;
4545
var op = GetOp(ctx, op_name, status);
46-
c_api.TFE_OpSetDevice(op, device_name, status.Handle);
46+
c_api.TFE_OpSetDevice(op, device_name, status);
4747
if (status.ok())
4848
{
4949
for (int i = 0; i < inputs.Length; ++i)
@@ -54,7 +54,7 @@ public Tensor[] TFE_ExecuteCancelable(Context ctx,
5454
Tensor nd => nd.EagerTensorHandle,
5555
_ => throw new NotImplementedException("Eager tensor handle has not been allocated.")
5656
};
57-
c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
57+
c_api.TFE_OpAddInput(op, tensor_handle, status);
5858
status.Check(true);
5959
}
6060
}
@@ -64,7 +64,7 @@ public Tensor[] TFE_ExecuteCancelable(Context ctx,
6464
var outputs = new SafeEagerTensorHandle[num_outputs];
6565
if (status.ok())
6666
{
67-
c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);
67+
c_api.TFE_Execute(op, outputs, out num_outputs, status);
6868
status.Check(true);
6969
}
7070
return outputs.Select(x => new EagerTensor(x)).ToArray();

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info)
104104
var eager_tensor = ops.convert_to_tensor(fast_input_array[j]);
105105
attr_values[j] = eager_tensor.dtype;
106106

107-
c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle);
107+
c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status);
108108

109109
if (op_exec_info.run_callbacks)
110110
{
@@ -142,7 +142,7 @@ public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info)
142142
}
143143

144144
var retVals = new SafeEagerTensorHandle[num_retvals];
145-
c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
145+
c_api.TFE_Execute(op, retVals, out num_retvals, status);
146146
status.Check(true);
147147

148148
var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray();
@@ -160,10 +160,10 @@ public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info)
160160
SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status)
161161
{
162162
if (thread_local_eager_operation_map.find(op_or_function_name, out var op))
163-
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle);
163+
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status);
164164
else
165165
{
166-
op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle);
166+
op = c_api.TFE_NewOp(ctx, op_or_function_name, status);
167167
thread_local_eager_operation_map[op_or_function_name] = op;
168168
}
169169

@@ -219,7 +219,7 @@ bool AddInputToOp(object inputs,
219219
flattened_attrs.Add(dtype);
220220
}
221221

222-
c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status.Handle);
222+
c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status);
223223
status.Check(true);
224224

225225
return true;
@@ -235,7 +235,7 @@ public void SetOpAttrs(SafeEagerOpHandle op, params object[] attrs)
235235
var value = attrs[i + 1];
236236

237237
byte is_list = 0;
238-
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status.Handle);
238+
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status);
239239
if (!status.ok()) return;
240240
if (is_list != 0)
241241
SetOpAttrList(tf.Context, op, key, value as object[], type, null, status);
@@ -264,7 +264,7 @@ void SetOpAttrWithDefaults(Context ctx, SafeEagerOpHandle op, AttrDef attr,
264264
Status status)
265265
{
266266
byte is_list = 0;
267-
var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status.Handle);
267+
var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status);
268268
if (status.Code != TF_Code.TF_OK) return;
269269

270270
if (attr_value == null)
@@ -305,7 +305,7 @@ bool SetOpAttrList(Context ctx, SafeEagerOpHandle op,
305305
tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long));
306306
}
307307

308-
c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle);
308+
c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status);
309309
Array.ForEach(dims, x => Marshal.FreeHGlobal(x));
310310
}
311311
else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2)
@@ -353,7 +353,7 @@ bool SetOpAttrScalar(Context ctx, SafeEagerOpHandle op,
353353
break;
354354
case TF_AttrType.TF_ATTR_SHAPE:
355355
var dims = (value as long[]).ToArray();
356-
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle);
356+
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
357357
status.Check(true);
358358
break;
359359
case TF_AttrType.TF_ATTR_FUNC:

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public EagerTensor(IntPtr data_ptr, Shape shape, TF_DataType dtype) : base(data_
5454
void NewEagerTensorHandle(SafeTensorHandle h)
5555
{
5656
_id = ops.uid();
57-
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle);
57+
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status);
5858
#if TRACK_TENSOR_LIFE
5959
Console.WriteLine($"New EagerTensor {_eagerTensorHandle}");
6060
#endif
@@ -65,7 +65,7 @@ public void Resolve()
6565
{
6666
if (_handle != null)
6767
return;
68-
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle);
68+
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status);
6969
tf.Status.Check(true);
7070
}
7171

0 commit comments

Comments
 (0)