Skip to content

Commit af555b7

Browse files
committed
Add generic THStorage_newFromFile
1 parent a1c453c commit af555b7

File tree

6 files changed

+46
-31
lines changed

6 files changed

+46
-31
lines changed

include/torch2c.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#ifndef TORCH2C_H
2+
#define TORCH2C_H
3+
4+
#include <TH.h>
5+
#include "torch2c_generic.h"
6+
#include <THGenerateAllTypes.h>
7+
8+
#endif
9+

include/torch2c_generic.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#ifndef TH_GENERIC_FILE
2+
#define TH_GENERIC_FILE "torch2c_generic.h"
3+
#else
4+
5+
// TODO: this only works with little endian for now, we need to make more general
6+
7+
TH_API THStorage *THStorage_(newFromFile)(const char *filename)
8+
{
9+
FILE *f = fopen(filename,"rb");
10+
11+
if (!f) {
12+
THError("cannot open file %s for reading");
13+
return NULL;
14+
}
15+
16+
long size;
17+
size_t result = fread(&size,sizeof(long),1,f);
18+
19+
THStorage *out = THStorage_(newWithSize)(size);
20+
char *bytes = (char *) out->data;
21+
22+
uint64_t remaining = sizeof(real) * out->size;
23+
result = fread(bytes,sizeof(real),out->size,f);
24+
25+
fclose(f);
26+
27+
return out;
28+
}
29+
30+
#endif

scripts/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ project(torch2ctest)
44

55
include_directories(${INSTALL_DIR}/include/TH)
66
include_directories(${INSTALL_DIR}/include/THNN)
7+
include_directories(${INSTALL_DIR}/include/torch2c)
78

89
link_directories(${INSTALL_DIR}/lib)
910

scripts/run_test.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ INSTALL_DIR=$BASE_DIR/tmp/install
1111
OUT_BASE_DIR=$BASE_DIR/out
1212
OUT_BUILD_DIR=$BASE_DIR/tmp/out-build
1313

14+
mkdir -p $INSTALL_DIR/include/torch2c
15+
cp $BASE_DIR/include/*.h $INSTALL_DIR/include/torch2c
16+
1417
rm -rf $OUT_BASE_DIR
1518
rm -rf $OUT_BUILD_DIR
1619
python3 test/$1.py

torch2c/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _generate_c(nodes, out, fnname, out_path):
5252
last_node = nodes[-1]
5353
ifndef = '#ifndef __%s__\n#define __%s__\n' % (2*(fnname.upper(),))
5454
endif = '#endif'
55-
includes = '#include "TH.h"\n#include "THNN.h"\n'
55+
includes = '#include "TH.h"\n#include "THNN.h"\n#include "torch2c.h"'
5656
fndecl = 'void %s(%s)' % (fnname,
5757
', '.join([el.generate_decl() for el in var_nodes + [out_node]]))
5858
calls = [el.generate_call(out_path,'data') for el in nodes]
@@ -103,6 +103,7 @@ def _generate_test(nodes, out, fnname, filename, out_path):
103103

104104

105105
def compile(node, fnname, out_path, compile_test=False):
106+
includedir = os.path.join(os.path.dirname(__file__),'..','include')
106107
nodes = _traverse_graph(node)
107108
if not os.path.isdir(out_path):
108109
os.mkdir(out_path)

torch2c/emitters.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -180,35 +180,6 @@ def persist_tensor(tensor, name, out_path, datadir, size_name='size_$id', stride
180180
return os.path.join(datadir,filename), meta, meta_free
181181

182182

183-
# TODO: add this function to an auxiliary file
184-
# call it something like TH${T}Storage_newFromFile(filename);
185-
def read_storage(storage_name,filepath,numtype):
186-
subs = {
187-
'filepath': filepath,
188-
'storage_name': storage_name,
189-
'real': type_map[numtype],
190-
'T': numtype
191-
}
192-
# TODO: extend past little endian
193-
tpl = '''
194-
TH${T}Storage *${storage_name};
195-
{
196-
FILE *f = fopen("${filepath}","rb");
197-
if (!f) {
198-
THError("cannot open file ${filepath} for reading");
199-
}
200-
long size;
201-
size_t result = fread(&size,sizeof(long),1,f);
202-
${storage_name} = TH${T}Storage_newWithSize(size);
203-
char *bytes = (char *) ${storage_name}->data;
204-
uint64_t remaining = sizeof(${real}) * ${storage_name}->size;
205-
result = fread(bytes,sizeof(${real}),${storage_name}->size,f);
206-
fclose(f);
207-
}
208-
'''
209-
return Template(tpl).substitute(subs)
210-
211-
212183
class PersistedVariable(Variable):
213184

214185
def __init__(self, obj, prevfns):
@@ -219,7 +190,7 @@ def call_tpl(self):
219190
self.out_path,self.datadir)
220191

221192
return '\n'.join([
222-
read_storage('storage_$id',filepath,self.numtype),
193+
'TH${T}Storage *storage_$id = TH${T}Storage_newFromFile("%s");' % filepath,
223194
meta,
224195
'TH${T}Tensor *$id = TH${T}Tensor_newWithStorage(storage_$id,0,size_$id,stride_$id);',
225196
meta_free

0 commit comments

Comments
 (0)