Skip to content

Commit

Permalink
Merge pull request dotnet#3 from Oceania2018/tftransferlearning
Browse files Browse the repository at this point in the history
Tftransferlearning
  • Loading branch information
codemzs authored Aug 1, 2019
2 parents 5beab30 + f050f03 commit fd6b8d5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Dnn/Microsoft.ML.Dnn.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<ItemGroup>
<PackageReference Include="System.IO.FileSystem.AccessControl" Version="$(SystemIOFileSystemAccessControl)" />
<PackageReference Include="System.Security.Principal.Windows" Version="$(SystemSecurityPrincipalWindows)" />
<PackageReference Include="TensorFlow.NET" Version="0.10.6" />
<PackageReference Include="TensorFlow.NET" Version="0.10.7.2" />
</ItemGroup>

<ItemGroup>
Expand Down
29 changes: 3 additions & 26 deletions src/Microsoft.ML.Dnn/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -291,34 +291,11 @@ internal static unsafe void FetchStringData<T>(Tensor tensor, Span<T> result)
{
if (tensor == null)
throw Contracts.ExceptEmpty(nameof(tensor));
//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
//
long size = 1;
foreach (var s in tensor.TensorShape.Dimensions)
size *= s;

var buffer = new byte[size][];
var src = c_api.TF_TensorData(tensor);
var srcLen = (IntPtr)(src.ToInt64() + (long)tensor.bytesize);
src += (int)(size * 8);
for (int i = 0; i < buffer.Length; i++)
{
using (var status = new Status())
{
IntPtr dst = IntPtr.Zero;
ulong dstLen = 0;
var read = c_api.TF_StringDecode(src, (ulong)(srcLen.ToInt64() - src.ToInt64()), dst, ref dstLen, status);
status.Check();
buffer[i] = new byte[(int)dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
src += (int)read;
}
}

var buffer = tensor.StringData();

for (int i = 0; i < buffer.Length; i++)
result[i] = (T)(object)Encoding.UTF8.GetString(buffer[i]).AsMemory();
result[i] = (T)(object)buffer[i].AsMemory();
}

internal static bool IsTypeSupported(TF_DataType tfoutput)
Expand Down

0 comments on commit fd6b8d5

Please sign in to comment.