Skip to content

Commit

Permalink
add openface scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
piliwilliam0306 committed Nov 7, 2016
1 parent 52f48c7 commit 68401b6
Show file tree
Hide file tree
Showing 41 changed files with 4,558 additions and 0 deletions.
77 changes: 77 additions & 0 deletions batch-represent/batch-represent.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
local ffi = require 'ffi'

local batchNumber, nImgs = 0

torch.setdefaulttensortype('torch.FloatTensor')

function batchRepresent()
local loadSize = {3, opt.imgDim, opt.imgDim}
print(opt.data)
local cacheFile = paths.concat(opt.data, 'cache.t7')
print('cache lotation: ', cacheFile)
local dumpLoader
if paths.filep(cacheFile) then
print('Loading metadata from cache.')
print('If your dataset has changed, delete the cache file.')
dumpLoader = torch.load(cacheFile)
else
print('Creating metadata for cache.')
dumpLoader = dataLoader{
paths = {opt.data},
loadSize = loadSize,
sampleSize = loadSize,
split = 0,
verbose = true
}
torch.save(cacheFile, dumpLoader)
end
collectgarbage()
nImgs = dumpLoader:sizeTest()
print('nImgs: ', nImgs)
assert(nImgs > 0, "Failed to get nImgs")

batchNumber = 0

for i=1,math.ceil(nImgs/opt.batchSize) do
local indexStart = (i-1) * opt.batchSize + 1
local indexEnd = math.min(nImgs, indexStart + opt.batchSize - 1)
local batchSz = indexEnd-indexStart+1
local inputs, labels = dumpLoader:get(indexStart, indexEnd)
local paths = {}
for j=indexStart,indexEnd do
table.insert(paths,
ffi.string(dumpLoader.imagePath[dumpLoader.testIndices[j]]:data()))
end
repBatch(paths, inputs, labels, batchSz)
if i % 5 == 0 then
collectgarbage()
end
end

if opt.cuda then
cutorch.synchronize()
end
end

function repBatch(paths, inputs, labels, batchSz)
batchNumber = batchNumber + batchSz

if opt.cuda then
inputs = inputs:cuda()
end
local embeddings = model:forward(inputs):float()
if opt.cuda then
cutorch.synchronize()
end

if batchSz == 1 then
embeddings = embeddings:reshape(1, embeddings:size(1))
end

for i=1,batchSz do
labelsCSV:write({labels[i], paths[i]})
repsCSV:write(embeddings[i]:totable())
end

print(('Represent: %d/%d'):format(batchNumber, nImgs))
end
Loading

0 comments on commit 68401b6

Please sign in to comment.