11// Copyright (c) Microsoft Corporation. All rights reserved.
22// Licensed under the MIT License.
33#include <assert.h>
4- #include <png.h>
54#include <stdio.h>
65
76#include "onnxruntime_c_api.h"
1110#endif
1211#include <objbase.h>
1312#endif
13+ #include "image_file.h"
1414
1515#ifdef _WIN32
1616#define tcscmp wcscmp
@@ -39,10 +39,11 @@ const OrtApi* g_ort = NULL;
3939 * \param output A float array. should be freed by caller after use
4040 * \param output_count Array length of the `output` param
4141 */
42- static void hwc_to_chw (const png_byte * input , size_t h , size_t w , float * * output , size_t * output_count ) {
42+ void hwc_to_chw (const uint8_t * input , size_t h , size_t w , float * * output , size_t * output_count ) {
4343 size_t stride = h * w ;
4444 * output_count = stride * 3 ;
4545 float * output_data = (float * )malloc (* output_count * sizeof (float ));
46+ assert (output_data != NULL );
4647 for (size_t i = 0 ; i != stride ; ++ i ) {
4748 for (size_t c = 0 ; c != 3 ; ++ c ) {
4849 output_data [c * stride + i ] = input [i * 3 + c ];
@@ -58,121 +59,30 @@ static void hwc_to_chw(const png_byte* input, size_t h, size_t w, float** output
5859 * \param w image width
5960 * \param output A byte array. should be freed by caller after use
6061 */
61- static void chw_to_hwc (const float * input , size_t h , size_t w , png_bytep * output ) {
62+ static void chw_to_hwc (const float * input , size_t h , size_t w , uint8_t * * output ) {
6263 size_t stride = h * w ;
63- png_bytep output_data = (png_bytep )malloc (stride * 3 );
64- for (int c = 0 ; c != 3 ; ++ c ) {
64+ uint8_t * output_data = (uint8_t * )malloc (stride * 3 );
65+ assert (output_data != NULL );
66+ for (size_t c = 0 ; c != 3 ; ++ c ) {
6567 size_t t = c * stride ;
6668 for (size_t i = 0 ; i != stride ; ++ i ) {
6769 float f = input [t + i ];
6870 if (f < 0.f || f > 255.0f ) f = 0 ;
69- output_data [i * 3 + c ] = (png_byte )f ;
71+ output_data [i * 3 + c ] = (uint8_t )f ;
7072 }
7173 }
7274 * output = output_data ;
7375}
7476
75- /**
76- * \param out should be freed by caller after use
77- * \param output_count Array length of the `out` param
78- */
79- static int read_png_file (const char * input_file , size_t * height , size_t * width , float * * out , size_t * output_count ) {
80- png_image image ; /* The control structure used by libpng */
81- /* Initialize the 'png_image' structure. */
82- memset (& image , 0 , (sizeof image ));
83- image .version = PNG_IMAGE_VERSION ;
84- if (png_image_begin_read_from_file (& image , input_file ) == 0 ) {
85- return -1 ;
86- }
87- png_bytep buffer ;
88- image .format = PNG_FORMAT_BGR ;
89- size_t input_data_length = PNG_IMAGE_SIZE (image );
90- if (input_data_length != 720 * 720 * 3 ) {
91- printf ("input_data_length:%zd\n" , input_data_length );
92- return -1 ;
93- }
94- buffer = (png_bytep )malloc (input_data_length );
95- memset (buffer , 0 , input_data_length );
96- if (png_image_finish_read (& image , NULL /*background*/ , buffer , 0 /*row_stride*/ , NULL /*colormap*/ ) == 0 ) {
97- return -1 ;
98- }
99- hwc_to_chw (buffer , image .height , image .width , out , output_count );
100- free (buffer );
101- * width = image .width ;
102- * height = image .height ;
103- return 0 ;
104- }
105-
106- /**
107- * \param tensor should be a float tensor in [N,C,H,W] format
108- */
109- static int write_tensor_to_png_file (OrtValue * tensor , const char * output_file ) {
110- struct OrtTensorTypeAndShapeInfo * shape_info ;
111- ORT_ABORT_ON_ERROR (g_ort -> GetTensorTypeAndShape (tensor , & shape_info ));
112- size_t dim_count ;
113- ORT_ABORT_ON_ERROR (g_ort -> GetDimensionsCount (shape_info , & dim_count ));
114- if (dim_count != 4 ) {
115- printf ("output tensor must have 4 dimensions" );
116- return -1 ;
117- }
118- int64_t dims [4 ];
119- ORT_ABORT_ON_ERROR (g_ort -> GetDimensions (shape_info , dims , sizeof (dims ) / sizeof (dims [0 ])));
120- if (dims [0 ] != 1 || dims [1 ] != 3 ) {
121- printf ("output tensor shape error" );
122- return -1 ;
123- }
124- float * f ;
125- ORT_ABORT_ON_ERROR (g_ort -> GetTensorMutableData (tensor , (void * * )& f ));
126- png_bytep model_output_bytes ;
127- png_image image ;
128- memset (& image , 0 , (sizeof image ));
129- image .version = PNG_IMAGE_VERSION ;
130- image .format = PNG_FORMAT_BGR ;
131- image .height = (png_uint_32 )dims [2 ];
132- image .width = (png_uint_32 )dims [3 ];
133- chw_to_hwc (f , image .height , image .width , & model_output_bytes );
134- int ret = 0 ;
135- if (png_image_write_to_file (& image , output_file , 0 /*convert_to_8bit*/ , model_output_bytes , 0 /*row_stride*/ ,
136- NULL /*colormap*/ ) == 0 ) {
137- printf ("write to '%s' failed:%s\n" , output_file , image .message );
138- ret = -1 ;
139- }
140- free (model_output_bytes );
141- return ret ;
142- }
143-
14477static void usage () { printf ("usage: <model_path> <input_file> <output_file> [cpu|cuda|dml] \n" ); }
14578
146- #ifdef _WIN32
147- static char * convert_string (const wchar_t * input ) {
148- size_t src_len = wcslen (input ) + 1 ;
149- if (src_len > INT_MAX ) {
150- printf ("size overflow\n" );
151- abort ();
152- }
153- const int len = WideCharToMultiByte (CP_ACP , 0 , input , (int )src_len , NULL , 0 , NULL , NULL );
154- assert (len > 0 );
155- char * ret = (char * )malloc (len );
156- assert (ret != NULL );
157- const int r = WideCharToMultiByte (CP_ACP , 0 , input , (int )src_len , ret , len , NULL , NULL );
158- assert (len == r );
159- return ret ;
160- }
161- #endif
162-
16379int run_inference (OrtSession * session , const ORTCHAR_T * input_file , const ORTCHAR_T * output_file ) {
16480 size_t input_height ;
16581 size_t input_width ;
16682 float * model_input ;
16783 size_t model_input_ele_count ;
168- #ifdef _WIN32
169- const char * output_file_p = convert_string (output_file );
170- const char * input_file_p = convert_string (input_file );
171- #else
172- const char * output_file_p = output_file ;
173- const char * input_file_p = input_file ;
174- #endif
175- if (read_png_file (input_file_p , & input_height , & input_width , & model_input , & model_input_ele_count ) != 0 ) {
84+
85+ if (read_image_file (input_file , & input_height , & input_width , & model_input , & model_input_ele_count ) != 0 ) {
17686 return -1 ;
17787 }
17888 if (input_height != 720 || input_width != 720 ) {
@@ -204,16 +114,16 @@ int run_inference(OrtSession* session, const ORTCHAR_T* input_file, const ORTCHA
204114 ORT_ABORT_ON_ERROR (g_ort -> IsTensor (output_tensor , & is_tensor ));
205115 assert (is_tensor );
206116 int ret = 0 ;
207- if (write_tensor_to_png_file (output_tensor , output_file_p ) != 0 ) {
117+ float * output_tensor_data = NULL ;
118+ ORT_ABORT_ON_ERROR (g_ort -> GetTensorMutableData (output_tensor , (void * * )& output_tensor_data ));
119+ uint8_t * output_image_data = NULL ;
120+ chw_to_hwc (output_tensor_data , 720 , 720 , & output_image_data );
121+ if (write_image_file (output_image_data , 720 , 720 , output_file ) != 0 ) {
208122 ret = -1 ;
209123 }
210124 g_ort -> ReleaseValue (output_tensor );
211125 g_ort -> ReleaseValue (input_tensor );
212126 free (model_input );
213- #ifdef _WIN32
214- free (input_file_p );
215- free (output_file_p );
216- #endif // _WIN32
217127 return ret ;
218128}
219129
0 commit comments