Skip to content

Commit

Permalink
automatic model downloads for segNet
Browse files Browse the repository at this point in the history
  • Loading branch information
dusty-nv committed Dec 18, 2022
1 parent 5cb08cf commit f642afc
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 230 deletions.
182 changes: 39 additions & 143 deletions c/segNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "segNet.h"
#include "tensorConvert.h"
#include "modelDownloader.h"

#include "cudaMappedMemory.h"
#include "cudaOverlay.h"
Expand All @@ -44,8 +45,6 @@ segNet::segNet() : tensorNet()
mColorsAlphaSet = NULL;
mClassColors = NULL;
mClassMap = NULL;

mNetworkType = SEGNET_CUSTOM;
}


Expand Down Expand Up @@ -128,147 +127,35 @@ segNet::FilterMode segNet::FilterModeFromStr( const char* str, FilterMode defaul
}


// NetworkTypeFromStr
// Maps a model name (matched case-insensitively) onto the built-in network enum.
// A NULL or unrecognized name yields segNet::SEGNET_CUSTOM.
segNet::NetworkType segNet::NetworkTypeFromStr( const char* modelName )
{
	// without a name there is nothing to look up
	if( modelName == NULL )
		return segNet::SEGNET_CUSTOM;

	// every accepted spelling, paired with the network it selects
	static const struct
	{
		const char*         name;
		segNet::NetworkType type;
	} aliases[] = {
		// ONNX models
		{ "fcn-resnet18-cityscapes-512x256",   segNet::FCN_RESNET18_CITYSCAPES_512x256 },
		{ "fcn-resnet18-cityscapes",           segNet::FCN_RESNET18_CITYSCAPES_512x256 },
		{ "fcn-resnet18-cityscapes-1024x512",  segNet::FCN_RESNET18_CITYSCAPES_1024x512 },
		{ "fcn-resnet18-cityscapes-2048x1024", segNet::FCN_RESNET18_CITYSCAPES_2048x1024 },
		{ "fcn-resnet18-deepscene-576x320",    segNet::FCN_RESNET18_DEEPSCENE_576x320 },
		{ "fcn-resnet18-deepscene",            segNet::FCN_RESNET18_DEEPSCENE_576x320 },
		{ "fcn-resnet18-deepscene-864x480",    segNet::FCN_RESNET18_DEEPSCENE_864x480 },
		{ "fcn-resnet18-mhp-512x320",          segNet::FCN_RESNET18_MHP_512x320 },
		{ "fcn-resnet18-mhp",                  segNet::FCN_RESNET18_MHP_512x320 },
		{ "fcn-resnet18-mhp-640x360",          segNet::FCN_RESNET18_MHP_640x360 },
		{ "fcn-resnet18-voc-320x320",          segNet::FCN_RESNET18_VOC_320x320 },
		{ "fcn-resnet18-pascal-voc-320x320",   segNet::FCN_RESNET18_VOC_320x320 },
		{ "fcn-resnet18-voc",                  segNet::FCN_RESNET18_VOC_320x320 },
		{ "fcn-resnet18-pascal-voc",           segNet::FCN_RESNET18_VOC_320x320 },
		{ "fcn-resnet18-voc-512x320",          segNet::FCN_RESNET18_VOC_512x320 },
		{ "fcn-resnet18-pascal-voc-512x320",   segNet::FCN_RESNET18_VOC_512x320 },
		{ "fcn-resnet18-sun-512x400",          segNet::FCN_RESNET18_SUNRGB_512x400 },
		{ "fcn-resnet18-sun-rgbd-512x400",     segNet::FCN_RESNET18_SUNRGB_512x400 },
		{ "fcn-resnet18-sun",                  segNet::FCN_RESNET18_SUNRGB_512x400 },
		{ "fcn-resnet18-sunrgb",               segNet::FCN_RESNET18_SUNRGB_512x400 },
		{ "fcn-resnet18-sun-640x512",          segNet::FCN_RESNET18_SUNRGB_640x512 },
		{ "fcn-resnet18-sun-rgbd-640x512",     segNet::FCN_RESNET18_SUNRGB_640x512 },

		// legacy models
		{ "fcn-alexnet-cityscapes-sd",         segNet::FCN_ALEXNET_CITYSCAPES_SD },
		{ "fcn-alexnet-cityscapes",            segNet::FCN_ALEXNET_CITYSCAPES_SD },
		{ "fcn-alexnet-cityscapes-hd",         segNet::FCN_ALEXNET_CITYSCAPES_HD },
		{ "fcn-alexnet-pascal-voc",            segNet::FCN_ALEXNET_PASCAL_VOC },
		{ "synthia-cvpr16",                    segNet::FCN_ALEXNET_SYNTHIA_CVPR16 },
		{ "fcn-alexnet-synthia-cvpr16",        segNet::FCN_ALEXNET_SYNTHIA_CVPR16 },
		{ "synthia-summer-sd",                 segNet::FCN_ALEXNET_SYNTHIA_SUMMER_SD },
		{ "fcn-alexnet-synthia-summer-sd",     segNet::FCN_ALEXNET_SYNTHIA_SUMMER_SD },
		{ "synthia-summer-hd",                 segNet::FCN_ALEXNET_SYNTHIA_SUMMER_HD },
		{ "fcn-alexnet-synthia-summer-hd",     segNet::FCN_ALEXNET_SYNTHIA_SUMMER_HD },
		{ "aerial-fpv",                        segNet::FCN_ALEXNET_AERIAL_FPV_720p },
		{ "aerial-fpv-720p",                   segNet::FCN_ALEXNET_AERIAL_FPV_720p },
		{ "fcn-alexnet-aerial-fpv-720p",       segNet::FCN_ALEXNET_AERIAL_FPV_720p },
	};

	const size_t numAliases = sizeof(aliases) / sizeof(aliases[0]);

	for( size_t i = 0; i < numAliases; i++ )
	{
		if( strcasecmp(modelName, aliases[i].name) == 0 )
			return aliases[i].type;
	}

	// no alias matched, so this isn't one of the built-in models
	return segNet::SEGNET_CUSTOM;
}


// NetworkTypeToStr
// Returns the canonical model name for a built-in network enum,
// or "custom segNet" for any value without a dedicated name.
const char* segNet::NetworkTypeToStr( segNet::NetworkType type )
{
	// start from the fallback and overwrite it when the type is a known built-in
	const char* name = "custom segNet";

	switch( type )
	{
		// ONNX models
		case FCN_RESNET18_CITYSCAPES_512x256:
			name = "fcn-resnet18-cityscapes-512x256";
			break;
		case FCN_RESNET18_CITYSCAPES_1024x512:
			name = "fcn-resnet18-cityscapes-1024x512";
			break;
		case FCN_RESNET18_CITYSCAPES_2048x1024:
			name = "fcn-resnet18-cityscapes-2048x1024";
			break;
		case FCN_RESNET18_DEEPSCENE_576x320:
			name = "fcn-resnet18-deepscene-576x320";
			break;
		case FCN_RESNET18_DEEPSCENE_864x480:
			name = "fcn-resnet18-deepscene-864x480";
			break;
		case FCN_RESNET18_MHP_512x320:
			name = "fcn-resnet18-mhp-512x320";
			break;
		case FCN_RESNET18_MHP_640x360:
			name = "fcn-resnet18-mhp-640x360";
			break;
		case FCN_RESNET18_VOC_320x320:
			name = "fcn-resnet18-voc-320x320";
			break;
		case FCN_RESNET18_VOC_512x320:
			name = "fcn-resnet18-voc-512x320";
			break;
		case FCN_RESNET18_SUNRGB_512x400:
			name = "fcn-resnet18-sun-512x400";
			break;
		case FCN_RESNET18_SUNRGB_640x512:
			name = "fcn-resnet18-sun-640x512";
			break;

		// legacy models
		case FCN_ALEXNET_PASCAL_VOC:
			name = "fcn-alexnet-pascal-voc";
			break;
		case FCN_ALEXNET_SYNTHIA_CVPR16:
			name = "fcn-alexnet-synthia-cvpr16";
			break;
		case FCN_ALEXNET_SYNTHIA_SUMMER_HD:
			name = "fcn-alexnet-synthia-summer-hd";
			break;
		case FCN_ALEXNET_SYNTHIA_SUMMER_SD:
			name = "fcn-alexnet-synthia-summer-sd";
			break;
		case FCN_ALEXNET_CITYSCAPES_HD:
			name = "fcn-alexnet-cityscapes-hd";
			break;
		case FCN_ALEXNET_CITYSCAPES_SD:
			name = "fcn-alexnet-cityscapes-sd";
			break;
		case FCN_ALEXNET_AERIAL_FPV_720p:
			name = "fcn-alexnet-aerial-fpv-720p";
			break;

		default:
			break;
	}

	return name;
}


// Create
segNet* segNet::Create( NetworkType networkType, uint32_t maxBatchSize,
segNet* segNet::Create( const char* network, uint32_t maxBatchSize,
precisionType precision, deviceType device, bool allowGPUFallback )
{
segNet* net = NULL;

#define LOAD_ONNX(x) Create(NULL, "networks/" x "/fcn_resnet18.onnx", "networks/" x "/classes.txt", "networks/" x "/colors.txt", "input_0", "output_0", maxBatchSize, precision, device, allowGPUFallback )

// ONNX models
if( networkType == FCN_RESNET18_CITYSCAPES_512x256 )
net = LOAD_ONNX("FCN-ResNet18-Cityscapes-512x256");
else if( networkType == FCN_RESNET18_CITYSCAPES_1024x512 )
net = LOAD_ONNX("FCN-ResNet18-Cityscapes-1024x512");
else if( networkType == FCN_RESNET18_CITYSCAPES_2048x1024 )
net = LOAD_ONNX("FCN-ResNet18-Cityscapes-2048x1024");
else if( networkType == FCN_RESNET18_DEEPSCENE_576x320 )
net = LOAD_ONNX("FCN-ResNet18-DeepScene-576x320");
else if( networkType == FCN_RESNET18_DEEPSCENE_864x480 )
net = LOAD_ONNX("FCN-ResNet18-DeepScene-864x480");
else if( networkType == FCN_RESNET18_MHP_512x320 )
net = LOAD_ONNX("FCN-ResNet18-MHP-512x320");
else if( networkType == FCN_RESNET18_MHP_640x360 )
net = LOAD_ONNX("FCN-ResNet18-MHP-640x360");
else if( networkType == FCN_RESNET18_VOC_320x320 )
net = LOAD_ONNX("FCN-ResNet18-Pascal-VOC-320x320");
else if( networkType == FCN_RESNET18_VOC_512x320 )
net = LOAD_ONNX("FCN-ResNet18-Pascal-VOC-512x320");
else if( networkType == FCN_RESNET18_SUNRGB_512x400 )
net = LOAD_ONNX("FCN-ResNet18-SUN-RGBD-512x400");
else if( networkType == FCN_RESNET18_SUNRGB_640x512 )
net = LOAD_ONNX("FCN-ResNet18-SUN-RGBD-640x512");

// legacy models
else if( networkType == FCN_ALEXNET_PASCAL_VOC )
net = Create("networks/FCN-Alexnet-Pascal-VOC/deploy.prototxt", "networks/FCN-Alexnet-Pascal-VOC/snapshot_iter_146400.caffemodel", "networks/FCN-Alexnet-Pascal-VOC/pascal-voc-classes.txt", "networks/FCN-Alexnet-Pascal-VOC/pascal-voc-colors.txt", SEGNET_DEFAULT_INPUT, SEGNET_DEFAULT_OUTPUT, maxBatchSize, precision, device, allowGPUFallback );
else if( networkType == FCN_ALEXNET_SYNTHIA_CVPR16 )
net = Create("networks/FCN-Alexnet-SYNTHIA-CVPR16/deploy.prototxt", "networks/FCN-Alexnet-SYNTHIA-CVPR16/snapshot_iter_1206700.caffemodel", "networks/FCN-Alexnet-SYNTHIA-CVPR16/synthia-cvpr16-labels.txt", "networks/FCN-Alexnet-SYNTHIA-CVPR16/synthia-cvpr16-train-colors.txt", SEGNET_DEFAULT_INPUT, SEGNET_DEFAULT_OUTPUT, maxBatchSize, precision, device, allowGPUFallback );
else if( networkType == FCN_ALEXNET_SYNTHIA_SUMMER_HD )
net = Create("networks/FCN-Alexnet-SYNTHIA-Summer-HD/deploy.prototxt", "networks/FCN-Alexnet-SYNTHIA-Summer-HD/snapshot_iter_902888.caffemodel", "networks/FCN-Alexnet-SYNTHIA-Summer-HD/synthia-seq-labels.txt", "networks/FCN-Alexnet-SYNTHIA-Summer-HD/synthia-seq-train-colors.txt", SEGNET_DEFAULT_INPUT, SEGNET_DEFAULT_OUTPUT, maxBatchSize, precision, device, allowGPUFallback );
else if( networkType == FCN_ALEXNET_SYNTHIA_SUMMER_SD )
net = Create("networks/FCN-Alexnet-SYNTHIA-Summer-SD/deploy.prototxt", "networks/FCN-Alexnet-SYNTHIA-Summer-SD/snapshot_iter_431816.caffemodel", "networks/FCN-Alexnet-SYNTHIA-Summer-SD/synthia-seq-labels.txt", "networks/FCN-Alexnet-SYNTHIA-Summer-SD/synthia-seq-train-colors.txt", SEGNET_DEFAULT_INPUT, SEGNET_DEFAULT_OUTPUT, maxBatchSize, precision, device, allowGPUFallback );
else if( networkType == FCN_ALEXNET_CITYSCAPES_HD )
net = Create("networks/FCN-Alexnet-Cityscapes-HD/deploy.prototxt", "networks/FCN-Alexnet-Cityscapes-HD/snapshot_iter_367568.caffemodel", "networks/FCN-Alexnet-Cityscapes-HD/cityscapes-labels.txt", "networks/FCN-Alexnet-Cityscapes-HD/cityscapes-deploy-colors.txt", SEGNET_DEFAULT_INPUT, SEGNET_DEFAULT_OUTPUT, maxBatchSize, precision, device, allowGPUFallback );
else if( networkType == FCN_ALEXNET_CITYSCAPES_SD )
net = Create("networks/FCN-Alexnet-Cityscapes-SD/deploy.prototxt", "networks/FCN-Alexnet-Cityscapes-SD/snapshot_iter_2756640.caffemodel", "networks/FCN-Alexnet-Cityscapes-SD/cityscapes-labels.txt", "networks/FCN-Alexnet-Cityscapes-SD/cityscapes-deploy-colors.txt", SEGNET_DEFAULT_INPUT, SEGNET_DEFAULT_OUTPUT, maxBatchSize, precision, device, allowGPUFallback );
//else if( networkType == FCN_ALEXNET_AERIAL_FPV_720p_4ch )
// net = Create("FCN-Alexnet-Aerial-FPV-4ch-720p/deploy.prototxt", "FCN-Alexnet-Aerial-FPV-4ch-720p/snapshot_iter_1777146.caffemodel", "FCN-Alexnet-Aerial-FPV-4ch-720p/fpv-labels.txt", "FCN-Alexnet-Aerial-FPV-4ch-720p/fpv-deploy-colors.txt", "data", "score_fr_4classes", SEGNET_DEFAULT_INPUT, SEGNET_DEFAULT_OUTPUT, maxBatchSize );
else if( networkType == FCN_ALEXNET_AERIAL_FPV_720p )
net = Create("networks/FCN-Alexnet-Aerial-FPV-720p/fcn_alexnet.deploy.prototxt", "networks/FCN-Alexnet-Aerial-FPV-720p/snapshot_iter_10280.caffemodel", "networks/FCN-Alexnet-Aerial-FPV-720p/fpv-labels.txt", "networks/FCN-Alexnet-Aerial-FPV-720p/fpv-deploy-colors.txt", SEGNET_DEFAULT_INPUT, SEGNET_DEFAULT_OUTPUT, maxBatchSize, precision, device, allowGPUFallback );
else
nlohmann::json model;

if( !DownloadModel(SEGNET_MODEL_TYPE, network, model) )
return NULL;

if( net != NULL )
net->mNetworkType = networkType;

return net;

std::string model_dir = "networks/" + model["dir"].get<std::string>() + "/";
std::string model_path = model_dir + model["model"].get<std::string>();
std::string prototxt = JSON_STR(model["prototxt"]);
std::string labels = JSON_STR(model["labels"]);
std::string colors = JSON_STR(model["colors"]);
std::string input = JSON_STR_DEFAULT(model["input"], SEGNET_DEFAULT_INPUT);
std::string output = JSON_STR_DEFAULT(model["output"], SEGNET_DEFAULT_OUTPUT);

if( prototxt.length() > 0 )
prototxt = model_dir + prototxt;

if( locateFile(labels).length() == 0 )
labels = model_dir + labels;

if( locateFile(colors).length() == 0 )
colors = model_dir + colors;

return Create(prototxt.c_str(), model_path.c_str(), labels.c_str(),
colors.c_str(), input.c_str(), output.c_str(),
maxBatchSize, precision, device, allowGPUFallback);
}


Expand All @@ -291,9 +178,7 @@ segNet* segNet::Create( const commandLine& cmdLine )
modelName = cmdLine.GetString("network", "fcn-resnet18-voc-320x320");

// parse the model type
const segNet::NetworkType type = NetworkTypeFromStr(modelName);

if( type == SEGNET_CUSTOM )
if( !FindModel(SEGNET_MODEL_TYPE, modelName) )
{
const char* prototxt = cmdLine.GetString("prototxt");
const char* labels = cmdLine.GetString("labels");
Expand All @@ -314,7 +199,7 @@ segNet* segNet::Create( const commandLine& cmdLine )
else
{
// create segnet from pretrained model
net = segNet::Create(type);
net = segNet::Create(modelName);
}

if( !net )
Expand Down Expand Up @@ -342,6 +227,17 @@ segNet* segNet::Create( const char* prototxt, const char* model, const char* lab
const char* input_blob, const char* output_blob, uint32_t maxBatchSize,
precisionType precision, deviceType device, bool allowGPUFallback )
{
// check for built-in model string
if( FindModel(SEGNET_MODEL_TYPE, model) )
{
return Create(model, maxBatchSize, precision, device, allowGPUFallback);
}
else if( fileExtension(model).length() == 0 )
{
LogError(LOG_TRT "couldn't find built-in segmentation model '%s'\n", model);
return NULL;
}

// create segmentation model
segNet* net = new segNet();

Expand Down
75 changes: 13 additions & 62 deletions c/segNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,26 @@
* Name of default input blob for segmentation model.
* @ingroup segNet
*/
#define SEGNET_DEFAULT_INPUT "data"
#define SEGNET_DEFAULT_INPUT "input_0"

/**
* Name of default output blob for segmentation model.
* @ingroup segNet
*/
#define SEGNET_DEFAULT_OUTPUT "score_fr_21classes"
#define SEGNET_DEFAULT_OUTPUT "output_0"

/**
* Default alpha blending value used during overlay
* @ingroup segNet
*/
#define SEGNET_DEFAULT_ALPHA 150

/**
* The model type for segNet in data/networks/models.json
* @ingroup segNet
*/
#define SEGNET_MODEL_TYPE "segmentation"

/**
* Standard command-line options able to be passed to segNet::Create()
* @ingroup segNet
Expand Down Expand Up @@ -82,36 +88,6 @@
class segNet : public tensorNet
{
public:
/**
* Enumeration of pretrained/built-in network models.
* NetworkTypeFromStr() maps model names onto these values and
* NetworkTypeToStr() maps them back to their canonical names.
*/
enum NetworkType
{
FCN_RESNET18_CITYSCAPES_512x256, /**< FCN-ResNet18 trained on Cityscapes dataset (512x256) */
FCN_RESNET18_CITYSCAPES_1024x512, /**< FCN-ResNet18 trained on Cityscapes dataset (1024x512) */
FCN_RESNET18_CITYSCAPES_2048x1024, /**< FCN-ResNet18 trained on Cityscapes dataset (2048x1024) */
FCN_RESNET18_DEEPSCENE_576x320, /**< FCN-ResNet18 trained on DeepScene Forest dataset (576x320) */
FCN_RESNET18_DEEPSCENE_864x480, /**< FCN-ResNet18 trained on DeepScene Forest dataset (864x480) */
FCN_RESNET18_MHP_512x320, /**< FCN-ResNet18 trained on Multi-Human Parsing dataset (512x320) */
FCN_RESNET18_MHP_640x360, /**< FCN-ResNet18 trained on Multi-Human Parsing dataset (640x360) */
FCN_RESNET18_VOC_320x320, /**< FCN-ResNet18 trained on Pascal VOC dataset (320x320) */
FCN_RESNET18_VOC_512x320, /**< FCN-ResNet18 trained on Pascal VOC dataset (512x320) */
FCN_RESNET18_SUNRGB_512x400, /**< FCN-ResNet18 trained on SUN RGB-D dataset (512x400) */
FCN_RESNET18_SUNRGB_640x512, /**< FCN-ResNet18 trained on SUN RGB-D dataset (640x512) */

/* legacy models (deprecated) */
FCN_ALEXNET_PASCAL_VOC, /**< FCN-Alexnet trained on Pascal VOC dataset. */
FCN_ALEXNET_SYNTHIA_CVPR16, /**< FCN-Alexnet trained on SYNTHIA CVPR16 dataset. @note To save disk space, this model isn't downloaded by default. Enable it in CMakePreBuild.sh */
FCN_ALEXNET_SYNTHIA_SUMMER_HD, /**< FCN-Alexnet trained on SYNTHIA SEQS summer datasets. @note To save disk space, this model isn't downloaded by default. Enable it in CMakePreBuild.sh */
FCN_ALEXNET_SYNTHIA_SUMMER_SD, /**< FCN-Alexnet trained on SYNTHIA SEQS summer datasets. @note To save disk space, this model isn't downloaded by default. Enable it in CMakePreBuild.sh */
FCN_ALEXNET_CITYSCAPES_HD, /**< FCN-Alexnet trained on Cityscapes dataset with 21 classes. */
FCN_ALEXNET_CITYSCAPES_SD, /**< FCN-Alexnet trained on Cityscapes dataset with 21 classes. @note To save disk space, this model isn't downloaded by default. Enable it in CMakePreBuild.sh */
FCN_ALEXNET_AERIAL_FPV_720p, /**< FCN-Alexnet trained on aerial first-person view of the horizon line for drones, 1280x720 and 21 output classes */

/* add new models here */
SEGNET_CUSTOM /**< Not a built-in model; NetworkTypeFromStr() returns this for a NULL or unrecognized name */
};

/**
* Enumeration of mask/overlay filtering modes.
*/
Expand All @@ -126,9 +102,8 @@ class segNet : public tensorNet
*/
enum VisualizationFlags
{
VISUALIZE_OVERLAY = (1 << 0),
VISUALIZE_MASK = (1 << 1),
/*VISUALIZE_LEGEND = (1 << 2)*/ // TODO
VISUALIZE_OVERLAY = (1 << 0), /**< Overlay the segmentation class colors with alpha blending */
VISUALIZE_MASK = (1 << 1), /**< View just the colorized segmentation class mask */
};

/**
Expand All @@ -145,22 +120,10 @@ class segNet : public tensorNet
static FilterMode FilterModeFromStr( const char* str, FilterMode default_value=FILTER_LINEAR );

/**
* Parse a string from one of the built-in pretrained models.
* Valid names are "fcn-resnet18-voc", "fcn-resnet18-cityscapes", "fcn-alexnet-pascal-voc", etc. (matched case-insensitively).
* @returns one of the segNet::NetworkType enums, or segNet::SEGNET_CUSTOM on a NULL or unrecognized string.
*/
static NetworkType NetworkTypeFromStr( const char* model_name );

/**
* Convert a NetworkType enum to a human-readable string.
* @returns stringized version of the provided NetworkType enum.
* Load a pre-trained model.
* @see SEGNET_USAGE_STRING for the models available.
*/
static const char* NetworkTypeToStr( NetworkType networkType );

/**
* Load a new network instance
*/
static segNet* Create( NetworkType networkType=FCN_ALEXNET_CITYSCAPES_SD, uint32_t maxBatchSize=DEFAULT_MAX_BATCH_SIZE,
static segNet* Create( const char* network="fcn-resnet18-voc", uint32_t maxBatchSize=DEFAULT_MAX_BATCH_SIZE,
precisionType precision=TYPE_FASTEST, deviceType device=DEVICE_GPU, bool allowGPUFallback=true );

/**
Expand Down Expand Up @@ -348,16 +311,6 @@ class segNet : public tensorNet
*/
inline uint32_t GetGridHeight() const
{
	// read from the dims of the first output (mOutputs[0]) via DIMS_H
	return DIMS_H(mOutputs[0].dims);
}

/**
 * Retrieve the segNet::NetworkType recorded for this instance.
 * The constructor initializes it to SEGNET_CUSTOM.
 */
inline NetworkType GetNetworkType() const
{
	return mNetworkType;
}

/**
 * Retrieve a string describing the network name.
 * @returns the result of NetworkTypeToStr() applied to this instance's network type.
 */
inline const char* GetNetworkName() const
{
	return NetworkTypeToStr(mNetworkType);
}

protected:
segNet();

Expand All @@ -381,8 +334,6 @@ class segNet : public tensorNet
uint32_t mLastInputWidth; /**< width in pixels of last input image to be processed */
uint32_t mLastInputHeight; /**< height in pixels of last input image to be processed */
imageFormat mLastInputFormat; /**< pixel format of last input image */

NetworkType mNetworkType; /**< Pretrained built-in model type enumeration */
};


Expand Down
Loading

0 comments on commit f642afc

Please sign in to comment.