Skip to content

Commit 59822a4

Browse files
committed
Merge remote-tracking branch 'origin/master'
2 parents 1b3af40 + bd15b2a commit 59822a4

File tree

3 files changed

+85
-75
lines changed

3 files changed

+85
-75
lines changed

src/graph.rs

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,10 @@ impl Graph {
153153

154154
/// Returns the graph definition as a protobuf.
155155
pub fn graph_def(&self) -> Result<Vec<u8>> {
156-
let status = Status::new();
156+
let mut status = Status::new();
157157
unsafe {
158158
let c_buffer = tf::TF_NewBuffer();
159-
tf::TF_GraphToGraphDef(self.gimpl.inner, c_buffer, status.inner);
159+
tf::TF_GraphToGraphDef(self.gimpl.inner, c_buffer, status.inner());
160160
if status.is_ok() {
161161
Ok(Buffer::from_c(c_buffer, true).into())
162162
} else {
@@ -173,9 +173,9 @@ impl Graph {
173173
/// Returns an error if:
174174
/// * `output` is not in `graph`.
175175
pub fn num_dims(&self, output: Output) -> Result<c_int> {
176-
let status = Status::new();
176+
let mut status = Status::new();
177177
unsafe {
178-
let val = tf::TF_GraphGetTensorNumDims(self.gimpl.inner, output.to_c(), status.inner);
178+
let val = tf::TF_GraphGetTensorNumDims(self.gimpl.inner, output.to_c(), status.inner());
179179
if status.is_ok() { Ok(val) } else { Err(status) }
180180
}
181181
}
@@ -185,7 +185,7 @@ impl Graph {
185185
/// Returns an error if:
186186
/// * `output` is not in `graph`.
187187
pub fn tensor_shape(&self, output: Output) -> Result<Shape> {
188-
let status = Status::new();
188+
let mut status = Status::new();
189189
let n = try!(self.num_dims(output));
190190
if n == -1 {
191191
return Ok(Shape(None));
@@ -196,7 +196,7 @@ impl Graph {
196196
output.to_c(),
197197
dims.as_mut_ptr(),
198198
dims.len() as c_int,
199-
status.inner);
199+
status.inner());
200200
if status.is_ok() {
201201
dims.set_len(n as usize);
202202
Ok(Shape(Some(dims.iter().map(|x| if *x < 0 { None } else { Some(*x) }).collect())))
@@ -212,10 +212,13 @@ impl Graph {
212212
options: &ImportGraphDefOptions)
213213
-> Result<()> {
214214
let buf = Buffer::from(graph_def);
215-
let status = Status::new();
215+
let mut status = Status::new();
216216
unsafe {
217-
tf::TF_GraphImportGraphDef(self.gimpl.inner, buf.inner(), options.inner, status.inner);
218-
status.as_result()
217+
tf::TF_GraphImportGraphDef(self.gimpl.inner,
218+
buf.inner(),
219+
options.inner,
220+
status.inner());
221+
status.into_result()
219222
}
220223
}
221224
}
@@ -310,9 +313,9 @@ impl Operation {
310313
#[allow(missing_docs)]
311314
pub fn output_list_length(&self, arg_name: &str) -> Result<usize> {
312315
let c_arg_name = try!(CString::new(arg_name));
313-
let status = Status::new();
316+
let mut status = Status::new();
314317
let length = unsafe {
315-
tf::TF_OperationOutputListLength(self.inner, c_arg_name.as_ptr(), status.inner)
318+
tf::TF_OperationOutputListLength(self.inner, c_arg_name.as_ptr(), status.inner())
316319
};
317320
if status.is_ok() {
318321
Ok(length as usize)
@@ -340,9 +343,9 @@ impl Operation {
340343
#[allow(missing_docs)]
341344
pub fn input_list_length(&self, arg_name: &str) -> Result<usize> {
342345
let c_arg_name = try!(CString::new(arg_name));
343-
let status = Status::new();
346+
let mut status = Status::new();
344347
let length = unsafe {
345-
tf::TF_OperationInputListLength(self.inner, c_arg_name.as_ptr(), status.inner)
348+
tf::TF_OperationInputListLength(self.inner, c_arg_name.as_ptr(), status.inner())
346349
};
347350
if status.is_ok() {
348351
Ok(length as usize)
@@ -548,8 +551,8 @@ impl<'a> OperationDescription<'a> {
548551
/// Builds the operation and adds it to the graph.
549552
pub fn finish(mut self) -> Result<Operation> {
550553
self.finished = true; // used by the drop code
551-
let status = Status::new();
552-
let operation = unsafe { tf::TF_FinishOperation(self.inner, status.inner) };
554+
let mut status = Status::new();
555+
let operation = unsafe { tf::TF_FinishOperation(self.inner, status.inner()) };
553556
if status.is_ok() {
554557
Ok(Operation {
555558
inner: operation,
@@ -818,15 +821,15 @@ impl<'a> OperationDescription<'a> {
818821
#[allow(trivial_numeric_casts)]
819822
pub fn set_attr_tensor_shape_proto(&mut self, attr_name: &str, value: &[u8]) -> Result<()> {
820823
let c_attr_name = try!(CString::new(attr_name));
821-
let status = Status::new();
824+
let mut status = Status::new();
822825
unsafe {
823826
tf::TF_SetAttrTensorShapeProto(self.inner,
824827
c_attr_name.as_ptr(),
825828
value.as_ptr() as *const c_void,
826829
value.len() as size_t,
827-
status.inner);
830+
status.inner());
828831
}
829-
status.as_result()
832+
status.into_result()
830833
}
831834

832835
/// Sets an attribute with an array of `TensorShapeProto` protobufs.
@@ -840,16 +843,16 @@ impl<'a> OperationDescription<'a> {
840843
.map(|x| x.as_ref().as_ptr() as *const c_void)
841844
.collect();
842845
let lens: Vec<size_t> = value.iter().map(|x| x.as_ref().len() as size_t).collect();
843-
let status = Status::new();
846+
let mut status = Status::new();
844847
unsafe {
845848
tf::TF_SetAttrTensorShapeProtoList(self.inner,
846849
c_attr_name.as_ptr(),
847850
ptrs.as_ptr(),
848851
lens.as_ptr(),
849852
ptrs.len() as c_int,
850-
status.inner);
853+
status.inner());
851854
}
852-
status.as_result()
855+
status.into_result()
853856
}
854857

855858
/// Sets a tensor-valued attribute.
@@ -858,14 +861,14 @@ impl<'a> OperationDescription<'a> {
858861
value: Tensor<T>)
859862
-> Result<()> {
860863
let c_attr_name = try!(CString::new(attr_name));
861-
let status = Status::new();
864+
let mut status = Status::new();
862865
unsafe {
863866
tf::TF_SetAttrTensor(self.inner,
864867
c_attr_name.as_ptr(),
865868
value.into_ptr(),
866-
status.inner);
869+
status.inner());
867870
}
868-
status.as_result()
871+
status.into_result()
869872
}
870873

871874
/// Sets an attribute which holds an array of tensors.
@@ -874,33 +877,33 @@ impl<'a> OperationDescription<'a> {
874877
value: T)
875878
-> Result<()> {
876879
let c_attr_name = try!(CString::new(attr_name));
877-
let status = Status::new();
880+
let mut status = Status::new();
878881
unsafe {
879882
let ptrs: Vec<*mut tf::TF_Tensor> = value.into_iter().map(|x| x.into_ptr()).collect();
880883
tf::TF_SetAttrTensorList(self.inner,
881884
c_attr_name.as_ptr(),
882885
ptrs.as_ptr(),
883886
ptrs.len() as c_int,
884-
status.inner);
887+
status.inner());
885888
}
886-
status.as_result()
889+
status.into_result()
887890
}
888891

889892
/// Sets an attribute with an `AttrValue` proto.
890893
#[allow(trivial_numeric_casts)]
891894
pub fn set_attr_to_attr_value_proto(&mut self, attr_name: &str, value: &[u8]) -> Result<()> {
892895
let c_attr_name = try!(CString::new(attr_name));
893-
let status = Status::new();
896+
let mut status = Status::new();
894897
unsafe {
895898
tf::TF_SetAttrValueProto(self.inner,
896899
c_attr_name.as_ptr(),
897900
value.as_ptr() as *const c_void,
898901
// Allow trivial_numeric_casts because usize is not
899902
// necessarily size_t.
900903
value.len() as size_t,
901-
status.inner);
904+
status.inner());
902905
}
903-
status.as_result()
906+
status.into_result()
904907
}
905908
}
906909

src/lib.rs

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,9 @@ c_enum!("Type of a single tensor element.", TF_DataType, DataType {
348348

349349
////////////////////////
350350

351-
/// Holds error information. It either has an OK code, or else an error code with an
352-
/// associated error message.
351+
/// Holds error information when communicating with back and forth with `tensorflow`.
352+
///
353+
/// It either has an `Code::Ok` code, or otherwise an error code with an associated message.
353354
pub struct Status {
354355
inner: *mut tf::TF_Status,
355356
}
@@ -375,7 +376,8 @@ impl Status {
375376
self.code() == Code::Ok
376377
}
377378

378-
fn as_result(self) -> Result<()> {
379+
/// Turns the current `Status` into a `Result`.
380+
fn into_result(self) -> Result<()> {
379381
if self.is_ok() { Ok(()) } else { Err(self) }
380382
}
381383

@@ -387,34 +389,39 @@ impl Status {
387389
}
388390
Ok(())
389391
}
392+
393+
/// Returns a mutable pointer to the inner tensorflow Status `TF_Status`.
394+
fn inner(&mut self) -> *mut tf::TF_Status {
395+
self.inner
396+
}
390397
}
391398

392399
impl Display for Status {
393400
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
394-
unsafe {
395-
try!(write!(f, "{}: ", self.code()));
396-
let msg = match CStr::from_ptr(tf::TF_Message(self.inner)).to_str() {
401+
try!(write!(f, "{}: ", self.code()));
402+
let msg = unsafe {
403+
match CStr::from_ptr(tf::TF_Message(self.inner)).to_str() {
397404
Ok(s) => s,
398405
Err(_) => "<invalid UTF-8 in message>",
399-
};
400-
f.write_str(msg)
401-
}
406+
}
407+
};
408+
f.write_str(msg)
402409
}
403410
}
404411

405412
impl Debug for Status {
406413
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
407-
unsafe {
408-
try!(write!(f, "{{inner:{:?}, ", self.inner));
409-
try!(write!(f, "{}: ", self.code()));
410-
let msg = match CStr::from_ptr(tf::TF_Message(self.inner)).to_str() {
414+
try!(write!(f, "{{inner:{:?}, ", self.inner));
415+
try!(write!(f, "{}: ", self.code()));
416+
let msg = unsafe {
417+
match CStr::from_ptr(tf::TF_Message(self.inner)).to_str() {
411418
Ok(s) => s,
412419
Err(_) => "<invalid UTF-8 in message>",
413-
};
414-
try!(f.write_str(msg));
415-
try!(write!(f, "}}"));
416-
Ok(())
417-
}
420+
}
421+
};
422+
try!(f.write_str(msg));
423+
try!(write!(f, "}}"));
424+
Ok(())
418425
}
419426
}
420427

@@ -469,12 +476,12 @@ impl SessionOptions {
469476
/// `config` should be a serialized brain.ConfigProto proto.
470477
/// Returns an error if config was not parsed successfully as a ConfigProto.
471478
pub fn set_config(&mut self, config: &[u8]) -> Result<()> {
472-
let status = Status::new();
479+
let mut status = Status::new();
473480
unsafe {
474481
tf::TF_SetConfig(self.inner,
475482
config.as_ptr() as *const _,
476483
config.len(),
477-
status.inner);
484+
status.inner());
478485
}
479486
if status.is_ok() { Ok(()) } else { Err(status) }
480487
}
@@ -497,8 +504,8 @@ pub struct DeprecatedSession {
497504
impl DeprecatedSession {
498505
/// Creates a session.
499506
pub fn new(options: &SessionOptions) -> Result<Self> {
500-
let status = Status::new();
501-
let inner = unsafe { tf::TF_NewDeprecatedSession(options.inner, status.inner) };
507+
let mut status = Status::new();
508+
let inner = unsafe { tf::TF_NewDeprecatedSession(options.inner, status.inner()) };
502509
if inner.is_null() {
503510
Err(status)
504511
} else {
@@ -508,24 +515,24 @@ impl DeprecatedSession {
508515

509516
/// Closes the session.
510517
pub fn close(&mut self) -> Result<()> {
511-
let status = Status::new();
518+
let mut status = Status::new();
512519
unsafe {
513-
tf::TF_CloseDeprecatedSession(self.inner, status.inner);
520+
tf::TF_CloseDeprecatedSession(self.inner, status.inner());
514521
}
515-
status.as_result()
522+
status.into_result()
516523
}
517524

518525
/// Treat `proto` as a serialized `GraphDef` and add the operations in that `GraphDef`
519526
/// to the graph for the session.
520527
pub fn extend_graph(&mut self, proto: &[u8]) -> Result<()> {
521-
let status = Status::new();
528+
let mut status = Status::new();
522529
unsafe {
523530
tf::TF_ExtendGraph(self.inner,
524531
proto.as_ptr() as *const _,
525532
proto.len(),
526-
status.inner);
533+
status.inner());
527534
}
528-
status.as_result()
535+
status.into_result()
529536
}
530537

531538
/// Runs the graph, feeding the inputs and then fetching the outputs requested in the step.
@@ -551,7 +558,7 @@ impl DeprecatedSession {
551558
// In case we're running it a second time and not all outputs were taken out.
552559
step.drop_output_tensors();
553560

554-
let status = Status::new();
561+
let mut status = Status::new();
555562
unsafe {
556563
tf::TF_Run(self.inner,
557564
std::ptr::null(),
@@ -564,18 +571,18 @@ impl DeprecatedSession {
564571
step.target_name_ptrs.as_mut_ptr(),
565572
step.target_name_ptrs.len() as c_int,
566573
std::ptr::null_mut(),
567-
status.inner);
574+
status.inner());
568575
};
569-
status.as_result()
576+
status.into_result()
570577
}
571578
}
572579

573580
#[allow(deprecated)]
574581
impl Drop for DeprecatedSession {
575582
fn drop(&mut self) {
576-
let status = Status::new();
583+
let mut status = Status::new();
577584
unsafe {
578-
tf::TF_DeleteDeprecatedSession(self.inner, status.inner);
585+
tf::TF_DeleteDeprecatedSession(self.inner, status.inner());
579586
}
580587
// TODO: What do we do with the status?
581588
}
@@ -967,8 +974,8 @@ impl Library {
967974
/// Loads a library.
968975
pub fn load(library_filename: &str) -> Result<Self> {
969976
let c_filename = try!(CString::new(library_filename));
970-
let status = Status::new();
971-
let inner = unsafe { tf::TF_LoadLibrary(c_filename.as_ptr(), status.inner) };
977+
let mut status = Status::new();
978+
let inner = unsafe { tf::TF_LoadLibrary(c_filename.as_ptr(), status.inner()) };
972979
if inner.is_null() {
973980
Err(status)
974981
} else {

0 commit comments

Comments
 (0)