forked from wangxiang1230/OadTR
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
edda3df
commit 4a61bab
Showing
51 changed files
with
3,394 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import argparse | ||
import datetime | ||
import json | ||
import random | ||
import time | ||
import numpy as np | ||
|
||
|
||
def str2bool(string): | ||
return True if string.lower() == 'true' else False | ||
|
||
|
||
def get_args_parser(): | ||
parser = argparse.ArgumentParser('Set IDU Online Detector', add_help=False) | ||
parser.add_argument('--lr', default=1e-4, type=float) # 1e-4 | ||
parser.add_argument('--batch_size', default=128, type=int) | ||
parser.add_argument('--weight_decay', default=1e-4, type=float) | ||
parser.add_argument('--epochs', default=5, type=int) | ||
parser.add_argument('--resize_feature', default=False, type=str2bool, help='run resize prepare_data or not') | ||
parser.add_argument('--lr_drop', default=1, type=int) | ||
parser.add_argument('--clip_max_norm', default=1., type=float, | ||
help='gradient clipping max norm') # dataparallel | ||
parser.add_argument('--dataparallel', action='store_true', help='multi-gpus for training') | ||
parser.add_argument('--removelog', action='store_true', help='remove old log') | ||
|
||
# * Network | ||
parser.add_argument('--version', default='v3', type=str, | ||
help="fixed or learned") # learned fixed | ||
# decoder | ||
parser.add_argument('--query_num', default=8, type=int, | ||
help="Number of query_num (prediction)") | ||
parser.add_argument('--decoder_layers', default=5, type=int, | ||
help="Number of decoder_layers") | ||
parser.add_argument('--decoder_embedding_dim', default=1024, type=int, # 1024 | ||
help="decoder_embedding_dim") | ||
parser.add_argument('--decoder_embedding_dim_out', default=1024, type=int, # 256 512 1024 | ||
help="decoder_embedding_dim_out") | ||
parser.add_argument('--decoder_attn_dropout_rate', default=0.1, type=float, # 0.1=0.2 | ||
help="rate of decoder_attn_dropout_rate") | ||
parser.add_argument('--decoder_num_heads', default=4, type=int, # 8 4 | ||
help="decoder_num_heads") | ||
parser.add_argument('--classification_pred_loss_coef', default=0.5, type=float) # 0.5 | ||
|
||
# encoder | ||
parser.add_argument('--enc_layers', default=64, type=int, | ||
help="Number of enc_layers") | ||
parser.add_argument('--lr_backbone', default=1e-4, type=float, # 2e-4 | ||
help="lr_backbone") | ||
parser.add_argument('--feature', default='Anet2016_feature_v2', type=str, | ||
help="feature type") | ||
parser.add_argument('--dim_feature', default=3072, type=int, | ||
help="input feature dims") | ||
parser.add_argument('--patch_dim', default=1, type=int, | ||
help="input feature dims") | ||
parser.add_argument('--embedding_dim', default=1024, type=int, # 1024 | ||
help="input feature dims") | ||
parser.add_argument('--num_heads', default=8, type=int, | ||
help="input feature dims") | ||
parser.add_argument('--num_layers', default=3, type=int, | ||
help="input feature dims") | ||
parser.add_argument('--attn_dropout_rate', default=0.1, type=float, | ||
help="attn dropout") | ||
parser.add_argument('--positional_encoding_type', default='learned', type=str, | ||
help="fixed or learned") # learned fixed | ||
|
||
parser.add_argument('--hidden_dim', default=1024, type=int, # 512 1024 | ||
help="Size of the embeddings") | ||
parser.add_argument('--dropout_rate', default=0.1, type=float, | ||
help="Dropout applied ") | ||
|
||
parser.add_argument('--numclass', default=22, type=int, | ||
help="Number of class") | ||
|
||
# * Loss coefficients | ||
parser.add_argument('--classification_x_loss_coef', default=0.3, type=float) | ||
parser.add_argument('--classification_h_loss_coef', default=1, type=float) | ||
parser.add_argument('--similar_loss_coef', default=0.1, type=float) # 0.3 | ||
parser.add_argument('--margin', default=1., type=float) | ||
|
||
# dataset parameters | ||
parser.add_argument('--dataset_file', type=str, default='data/data_info_new.json') | ||
parser.add_argument('--frozen_weights', type=str, default=None) | ||
parser.add_argument('--thumos_data_path', type=str, default='/home/dancer/mycode/Temporal.Online.Detection/' | ||
'Online.TRN.Pytorch/preprocess/') | ||
parser.add_argument('--thumos_anno_path', type=str, default='data/thumos_{}_anno.pickle') | ||
parser.add_argument('--remove_difficult', action='store_true') | ||
parser.add_argument('--device', default='cuda', | ||
help='device to use for training / testing') | ||
|
||
parser.add_argument('--output_dir', default='models', | ||
help='path where to save, empty for no saving') | ||
parser.add_argument('--seed', default=20, type=int) | ||
parser.add_argument('--resume', default='', help='resume from checkpoint') | ||
parser.add_argument('--start_epoch', default=1, type=int, metavar='N', | ||
help='start epoch') | ||
|
||
parser.add_argument('--eval', action='store_true') | ||
parser.add_argument('--num_workers', default=8, type=int) | ||
|
||
# distributed training parameters | ||
parser.add_argument('--world_size', default=1, type=int, | ||
help='number of distributed processes') | ||
parser.add_argument('--dist_url', default='tcp://127.0.0.1:12342', help='url used to set up distributed training') | ||
# 'env://' | ||
return parser | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"HDD": {"class_index": ["background", "intersection passing", "left turn", "right turn", "left lane change", "right lane change", "left lane branch", "right lane branch", "crosswalk passing", "railroad passing", "merge", "U-turn"], "train_session_set": ["201702271017", "201702271123", "201702271136", "201702271438", "201702271632", "201702281017", "201702281511", "201702281709", "201703011016", "201703061033", "201703061107", "201703061323", "201703061353", "201703061418", "201703061429", "201703061456", "201703061519", "201703061541", "201703061606", "201703061635", "201703061700", "201703061725", "201703080946", "201703081008", "201703081055", "201703081152", "201703081407", "201703081437", "201703081509", "201703081549", "201703081617", "201703081653", "201703081723", "201703081749", "201704101354", "201704101504", "201704101624", "201704101658", "201704110943", "201704111011", "201704111041", "201704111138", "201704111202", "201704111315", "201704111335", "201704111402", "201704111412", "201704111540", "201706061021", "201706070945", "201706071021", "201706071319", "201706071458", "201706071518", "201706071532", "201706071602", "201706071620", "201706071630", "201706071658", "201706071735", "201706071752", "201706080945", "201706081335", "201706081445", "201706081626", "201706081707", "201706130952", "201706131127", "201706131318", "201706141033", "201706141147", "201706141538", "201706141720", "201706141819", "201709200946", "201709201027", "201709201221", "201709201319", "201709201530", "201709201605", "201709201700", "201709210940", "201709211047", "201709211317", "201709211444", "201709211547", "201709220932", "201709221037", "201709221238", "201709221313", "201709221435", "201709221527", "201710031224", "201710031247", "201710031436", "201710040938", "201710060950", "201710061114", "201710061311", "201710061345"], "test_session_set": ["201704101118", "201704130952", "201704131020", "201704131047", "201704131123", "201704131537", "201704131634", "201704131655", "201704140944", "201704141033", "201704141055", "201704141117", "201704141145", "201704141243", "201704141420", "201704141608", "201704141639", "201704141725", "201704150933", "201704151035", "201704151103", "201704151140", "201704151315", "201704151347", "201704151502", "201706061140", "201706061309", "201706061536", "201706061647", "201706140912", "201710031458", "201710031645", "201710041102", "201710041209", "201710041351", "201710041448"]}, "TVSERIES": {"class_index": ["background", "Pick something up", "Point", "Drink", "Stand up", "Run", "Sit down", "Read", "Smoke", "Drive car", "Open door", "Give something", "Use computer", "Write", "Go down stairway", "Close door", "Throw something", "Go up stairway", "Get in/out of car", "Hang up phone", "Eat", "Answer phone", "Dress up", "Clap", "Undress", "Kiss", "Fall/trip", "Wave", "Pour", "Punch", "Fire weapon"], "train_session_set": ["24_ep1", "24_ep2", "24_ep3", "Breaking_Bad_ep1", "Breaking_Bad_ep2", "How_I_Met_Your_Mother_ep1", "How_I_Met_Your_Mother_ep2", "How_I_Met_Your_Mother_ep3", "How_I_Met_Your_Mother_ep4", "How_I_Met_Your_Mother_ep5", "How_I_Met_Your_Mother_ep6", "Mad_Men_ep1", "Mad_Men_ep2", "Modern_Family_ep1", "Modern_Family_ep2", "Modern_Family_ep3", "Modern_Family_ep4", "Modern_Family_ep6", "Sons_of_Anarchy_ep1", "Sons_of_Anarchy_ep2"], "test_session_set": ["24_ep4", "Breaking_Bad_ep3", "Mad_Men_ep3", "How_I_Met_Your_Mother_ep7", "How_I_Met_Your_Mother_ep8", "Modern_Family_ep5", "Sons_of_Anarchy_ep3"]}, "THUMOS": {"class_index": ["Background", "BaseballPitch", "BasketballDunk", "Billiards", "CleanAndJerk", "CliffDiving", "CricketBowling", "CricketShot", "Diving", "FrisbeeCatch", "GolfSwing", "HammerThrow", "HighJump", "JavelinThrow", "LongJump", "PoleVault", "Shotput", "SoccerPenalty", "TennisSwing", "ThrowDiscus", "VolleyballSpiking", "Ambiguous"], "train_session_set": ["video_validation_0000690", "video_validation_0000288", "video_validation_0000289", "video_validation_0000416", "video_validation_0000282", "video_validation_0000283", "video_validation_0000281", "video_validation_0000286", "video_validation_0000287", "video_validation_0000284", "video_validation_0000285", "video_validation_0000202", "video_validation_0000203", "video_validation_0000201", "video_validation_0000206", "video_validation_0000207", "video_validation_0000204", "video_validation_0000205", "video_validation_0000790", "video_validation_0000208", "video_validation_0000209", "video_validation_0000420", "video_validation_0000364", "video_validation_0000853", "video_validation_0000950", "video_validation_0000937", "video_validation_0000367", "video_validation_0000290", "video_validation_0000210", "video_validation_0000059", "video_validation_0000058", "video_validation_0000057", "video_validation_0000056", "video_validation_0000055", "video_validation_0000054", "video_validation_0000053", "video_validation_0000052", "video_validation_0000051", "video_validation_0000933", "video_validation_0000949", "video_validation_0000948", "video_validation_0000945", "video_validation_0000944", "video_validation_0000947", "video_validation_0000946", "video_validation_0000941", "video_validation_0000940", "video_validation_0000190", "video_validation_0000942", "video_validation_0000261", "video_validation_0000262", "video_validation_0000263", "video_validation_0000264", "video_validation_0000265", "video_validation_0000266", "video_validation_0000267", "video_validation_0000268", "video_validation_0000269", "video_validation_0000989", "video_validation_0000060", "video_validation_0000370", "video_validation_0000938", "video_validation_0000935", "video_validation_0000668", "video_validation_0000669", "video_validation_0000664", "video_validation_0000665", "video_validation_0000932", "video_validation_0000667", "video_validation_0000934", "video_validation_0000661", "video_validation_0000662", "video_validation_0000663", "video_validation_0000181", "video_validation_0000180", "video_validation_0000183", "video_validation_0000182", "video_validation_0000185", "video_validation_0000184", "video_validation_0000187", "video_validation_0000186", "video_validation_0000189", "video_validation_0000188", "video_validation_0000936", "video_validation_0000270", "video_validation_0000854", "video_validation_0000178", "video_validation_0000179", "video_validation_0000174", "video_validation_0000175", "video_validation_0000176", "video_validation_0000177", "video_validation_0000170", "video_validation_0000171", "video_validation_0000172", "video_validation_0000173", "video_validation_0000670", "video_validation_0000419", "video_validation_0000943", "video_validation_0000485", "video_validation_0000369", "video_validation_0000368", "video_validation_0000318", "video_validation_0000319", "video_validation_0000415", "video_validation_0000414", "video_validation_0000413", "video_validation_0000412", "video_validation_0000411", "video_validation_0000311", "video_validation_0000312", "video_validation_0000313", "video_validation_0000314", "video_validation_0000315", "video_validation_0000316", "video_validation_0000317", "video_validation_0000418", "video_validation_0000365", "video_validation_0000482", "video_validation_0000169", "video_validation_0000168", "video_validation_0000167", "video_validation_0000166", "video_validation_0000165", "video_validation_0000164", "video_validation_0000163", "video_validation_0000162", "video_validation_0000161", "video_validation_0000160", "video_validation_0000857", "video_validation_0000856", "video_validation_0000855", "video_validation_0000366", "video_validation_0000488", "video_validation_0000489", "video_validation_0000851", "video_validation_0000484", "video_validation_0000361", "video_validation_0000486", "video_validation_0000487", "video_validation_0000481", "video_validation_0000910", "video_validation_0000483", "video_validation_0000363", "video_validation_0000990", "video_validation_0000939", "video_validation_0000362", "video_validation_0000987", "video_validation_0000859", "video_validation_0000787", "video_validation_0000786", "video_validation_0000785", "video_validation_0000784", "video_validation_0000783", "video_validation_0000782", "video_validation_0000781", "video_validation_0000981", "video_validation_0000983", "video_validation_0000982", "video_validation_0000985", "video_validation_0000984", "video_validation_0000417", "video_validation_0000788", "video_validation_0000152", "video_validation_0000153", "video_validation_0000151", "video_validation_0000156", "video_validation_0000157", "video_validation_0000154", "video_validation_0000155", "video_validation_0000158", "video_validation_0000159", "video_validation_0000901", "video_validation_0000903", "video_validation_0000902", "video_validation_0000905", "video_validation_0000904", "video_validation_0000907", "video_validation_0000906", "video_validation_0000909", "video_validation_0000908", "video_validation_0000490", "video_validation_0000860", "video_validation_0000858", "video_validation_0000988", "video_validation_0000320", "video_validation_0000688", "video_validation_0000689", "video_validation_0000686", "video_validation_0000687", "video_validation_0000684", "video_validation_0000685", "video_validation_0000682", "video_validation_0000683", "video_validation_0000681", "video_validation_0000789", "video_validation_0000986", "video_validation_0000931", "video_validation_0000852", "video_validation_0000666"], "test_session_set": ["video_test_0000292", "video_test_0001078", "video_test_0000896", "video_test_0000897", "video_test_0000950", "video_test_0001159", "video_test_0001079", "video_test_0000807", "video_test_0000179", "video_test_0000173", "video_test_0001072", "video_test_0001075", "video_test_0000767", "video_test_0001076", "video_test_0000007", "video_test_0000006", "video_test_0000556", "video_test_0001307", "video_test_0001153", "video_test_0000718", "video_test_0000716", "video_test_0001309", "video_test_0000714", "video_test_0000558", "video_test_0001267", "video_test_0000367", "video_test_0001324", "video_test_0000085", "video_test_0000887", "video_test_0001281", "video_test_0000882", "video_test_0000671", "video_test_0000964", "video_test_0001164", "video_test_0001114", "video_test_0000771", "video_test_0001163", "video_test_0001118", "video_test_0001201", "video_test_0001040", "video_test_0001207", "video_test_0000723", "video_test_0000569", "video_test_0000672", "video_test_0000673", "video_test_0000278", "video_test_0001162", "video_test_0000405", "video_test_0000073", "video_test_0000560", "video_test_0001276", "video_test_0000270", "video_test_0000273", "video_test_0000374", "video_test_0000372", "video_test_0001168", "video_test_0000379", "video_test_0001446", "video_test_0001447", "video_test_0001098", "video_test_0000873", "video_test_0000039", "video_test_0000442", "video_test_0001219", "video_test_0000762", "video_test_0000611", "video_test_0000617", "video_test_0000615", "video_test_0001270", "video_test_0000740", "video_test_0000293", "video_test_0000504", "video_test_0000505", "video_test_0000665", "video_test_0000664", "video_test_0000577", "video_test_0000814", "video_test_0001369", "video_test_0001194", "video_test_0001195", "video_test_0001512", "video_test_0001235", "video_test_0001459", "video_test_0000691", "video_test_0000765", "video_test_0001452", "video_test_0000188", "video_test_0000591", "video_test_0001268", "video_test_0000593", "video_test_0000864", "video_test_0000601", "video_test_0001135", "video_test_0000004", "video_test_0000903", "video_test_0000285", "video_test_0001174", "video_test_0000046", "video_test_0000045", "video_test_0001223", "video_test_0001358", "video_test_0001134", "video_test_0000698", "video_test_0000461", "video_test_0001182", "video_test_0000450", "video_test_0000602", "video_test_0001229", "video_test_0000989", "video_test_0000357", "video_test_0001039", "video_test_0000355", "video_test_0000353", "video_test_0001508", "video_test_0000981", "video_test_0000242", "video_test_0000854", "video_test_0001484", "video_test_0000635", "video_test_0001129", "video_test_0001339", "video_test_0001483", "video_test_0001123", "video_test_0001127", "video_test_0000689", "video_test_0000756", "video_test_0001431", "video_test_0000129", "video_test_0001433", "video_test_0001343", "video_test_0000324", "video_test_0001064", "video_test_0001531", "video_test_0001532", "video_test_0000413", "video_test_0000991", "video_test_0001255", "video_test_0000464", "video_test_0001202", "video_test_0001080", "video_test_0001081", "video_test_0000847", "video_test_0000028", "video_test_0000844", "video_test_0000622", "video_test_0000026", "video_test_0001325", "video_test_0001496", "video_test_0001495", "video_test_0000624", "video_test_0000724", "video_test_0001409", "video_test_0000131", "video_test_0000448", "video_test_0000444", "video_test_0000443", "video_test_0001038", "video_test_0000238", "video_test_0001527", "video_test_0001522", "video_test_0000051", "video_test_0001058", "video_test_0001391", "video_test_0000429", "video_test_0000426", "video_test_0000785", "video_test_0000786", "video_test_0001314", "video_test_0000392", "video_test_0000423", "video_test_0001146", "video_test_0001313", "video_test_0001008", "video_test_0001247", "video_test_0000737", "video_test_0001319", "video_test_0000308", "video_test_0000730", "video_test_0000058", "video_test_0000538", "video_test_0001556", "video_test_0000113", "video_test_0000626", "video_test_0000839", "video_test_0000220", "video_test_0001389", "video_test_0000437", "video_test_0000940", "video_test_0000211", "video_test_0000946", "video_test_0001558", "video_test_0000796", "video_test_0000062", "video_test_0000793", "video_test_0000987", "video_test_0001066", "video_test_0000412", "video_test_0000798", "video_test_0001549", "video_test_0000011", "video_test_0001257", "video_test_0000541", "video_test_0000701", "video_test_0000250", "video_test_0000254", "video_test_0000549", "video_test_0001209", "video_test_0001463", "video_test_0001460", "video_test_0000319", "video_test_0001468", "video_test_0000846", "video_test_0001292"]}} |
Oops, something went wrong.