diff --git a/.gitignore b/.gitignore index 88f397f48..578f420bd 100644 --- a/.gitignore +++ b/.gitignore @@ -8,8 +8,8 @@ gen .settings/ .classpath .vscode/ +**/sdk.zip .factorypath - # 排除编译输出 *.class target/ @@ -44,5 +44,8 @@ build/ *.properties venv/ /docker/ -**/assembly/ -/assembly/ + +# 排除环境配置文件 +.env.local +/.env.local +.env.*.local diff --git a/README.md b/README.md index e9fcd3efd..2527605ec 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,16 @@ WeFe ( WeLab Federated Learning ) 是 Welab 汇立集团子公司[天冕](https: `Documentation:` https://tianmiantech.github.io/WeFe/ or http://tianmiantech.gitee.io/wefe +## 在线体验 + +WeFe 不仅支持本地部署运行测试,并且提供了一套完整的线上体验环境; + +用户可以通过线上体验环境,模拟联邦中三位成员间的建模操作; + +体验环境的联邦成员角色有 DemoMember1、DemoMember2、DemoMember3; + +详情访问[在线体验平台](https://tianmiantech.com/federal)体验。 + # 项目特点 混合联邦,纵向联邦学习与横向联邦学习结合的行业解决方案; @@ -49,16 +59,6 @@ WeFe 提供了一套完整的在线体验环境 Demo ENV。 单机部署,详见 [release/docker/README.md](./release/docker) -## 在线体验 - -WeFe 不仅支持本地部署运行测试,并且提供了一套完整的线上体验环境; - -用户可以通过线上体验环境,模拟联邦中三位成员间的建模操作; - -体验环境的联邦成员角色有 DemoMember1、DemoMember2、DemoMember3; - -详情访问[在线体验平台](https://tianmiantech.com/federal)体验。 - # 系统架构 WeFe 系统由两大模块 union 与 member 组成; diff --git a/VisualFL/config.properties b/VisualFL/config.properties new file mode 100644 index 000000000..af13820a8 --- /dev/null +++ b/VisualFL/config.properties @@ -0,0 +1,13 @@ + +# The configuration for business database mysql +# mysql is used to save modeling processes, modeling information, member information, and more +# ************************************************ +db.mysql.url=jdbc:mysql://127.0.0.1:3306/wefe_board?serverTimezone=GMT%2B8 +db.mysql.host=127.0.0.1 +db.mysql.port=3306 +db.mysql.database=wefe_board +db.mysql.username=xxx +db.mysql.password=xxxxxxx +is_local=false +# ************************************************ + diff --git a/VisualFL/data/coco/download_coco.py b/VisualFL/data/coco/download_coco.py new file mode 100644 index 000000000..29b8db7cd --- /dev/null +++ 
b/VisualFL/data/coco/download_coco.py @@ -0,0 +1,42 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import os.path as osp +import logging +# add python path of PadleDetection to sys.path +parent_path = osp.abspath(osp.join(__file__, *(['..'] * 3))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +from ppdet.utils.download import download_dataset + +logging.basicConfig(level=logging.INFO) + +download_path = osp.split(osp.realpath(sys.argv[0]))[0] +download_dataset(download_path, 'coco') diff --git a/VisualFL/data/fddb/download.sh b/VisualFL/data/fddb/download.sh new file mode 100755 index 000000000..29375d791 --- /dev/null +++ b/VisualFL/data/fddb/download.sh @@ -0,0 +1,31 @@ +# All rights `PaddleDetection` reserved +# References: +# @TechReport{fddbTech, +# author = {Vidit Jain and Erik Learned-Miller}, +# title = {FDDB: A Benchmark for Face Detection in Unconstrained Settings}, +# institution = {University of Massachusetts, Amherst}, +# year = {2010}, +# number = {UM-CS-2010-009} +# } + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd "$DIR" + +# Download the data. +echo "Downloading..." +# external link to the Faces in the Wild data set and annotations file +wget http://tamaraberg.com/faceDataset/originalPics.tar.gz +wget http://vis-www.cs.umass.edu/fddb/FDDB-folds.tgz +wget http://vis-www.cs.umass.edu/fddb/evaluation.tgz + +# Extract the data. +echo "Extracting..." +tar -zxf originalPics.tar.gz +tar -zxf FDDB-folds.tgz +tar -zxf evaluation.tgz + +# Generate full image path list and groundtruth in FDDB-folds: +cd FDDB-folds +cat `ls|grep -v"ellipse"` > filePath.txt && cat *ellipse* > fddb_annotFile.txt +cd .. +echo "------------- All done! --------------" diff --git a/VisualFL/data/fruit/download_fruit.py b/VisualFL/data/fruit/download_fruit.py new file mode 100644 index 000000000..99434d0f9 --- /dev/null +++ b/VisualFL/data/fruit/download_fruit.py @@ -0,0 +1,42 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import os.path as osp +import logging +# add python path of PadleDetection to sys.path +parent_path = osp.abspath(osp.join(__file__, *(['..'] * 3))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +from ppdet.utils.download import download_dataset + +logging.basicConfig(level=logging.INFO) + +download_path = osp.split(osp.realpath(sys.argv[0]))[0] +download_dataset(download_path, 'fruit') diff --git a/VisualFL/data/fruit/label_list.txt b/VisualFL/data/fruit/label_list.txt new file mode 100644 index 000000000..1f60d62c3 --- /dev/null +++ b/VisualFL/data/fruit/label_list.txt @@ -0,0 +1,3 @@ +apple +banana +orange diff --git a/VisualFL/data/roadsign_voc/download_roadsign_voc.py b/VisualFL/data/roadsign_voc/download_roadsign_voc.py new file mode 100644 index 000000000..e6560019d --- /dev/null +++ b/VisualFL/data/roadsign_voc/download_roadsign_voc.py @@ -0,0 +1,42 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os.path as osp +import logging +# add python path of PadleDetection to sys.path +parent_path = osp.abspath(osp.join(__file__, *(['..'] * 3))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +from ppdet.utils.download import download_dataset + +logging.basicConfig(level=logging.INFO) + +download_path = osp.split(osp.realpath(sys.argv[0]))[0] +download_dataset(download_path, 'roadsign_voc') diff --git a/VisualFL/data/roadsign_voc/label_list.txt b/VisualFL/data/roadsign_voc/label_list.txt new file mode 100644 index 000000000..1be460f45 --- /dev/null +++ b/VisualFL/data/roadsign_voc/label_list.txt @@ -0,0 +1,4 @@ +speedlimit +crosswalk +trafficlight +stop \ No newline at end of file diff --git a/VisualFL/data/voc/create_list.py b/VisualFL/data/voc/create_list.py new file mode 100644 index 000000000..65607dd7a --- /dev/null +++ b/VisualFL/data/voc/create_list.py @@ -0,0 +1,59 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os.path as osp +import logging +import argparse + +# add python path of PadleDetection to sys.path +parent_path = osp.abspath(osp.join(__file__, *(['..'] * 3))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +from ppdet.utils.download import create_voc_list +logging.basicConfig(level=logging.INFO) + + +def main(config): + voc_path = config.dataset_dir + create_voc_list(voc_path) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + default_voc_path = osp.split(osp.realpath(sys.argv[0]))[0] + parser.add_argument( + "-d", + "--dataset_dir", + default=default_voc_path, + type=str, + help="VOC dataset directory, default is current directory.") + config = parser.parse_args() + + main(config) diff --git a/VisualFL/data/voc/download_voc.py b/VisualFL/data/voc/download_voc.py new file mode 100644 index 000000000..7555c1f27 --- /dev/null +++ b/VisualFL/data/voc/download_voc.py @@ -0,0 +1,43 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import os.path as osp +import logging +# add python path of PadleDetection to sys.path +parent_path = osp.abspath(osp.join(__file__, *(['..'] * 3))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +from ppdet.utils.download import download_dataset, create_voc_list + +logging.basicConfig(level=logging.INFO) + +download_path = osp.split(osp.realpath(sys.argv[0]))[0] +download_dataset(download_path, 'voc') +create_voc_list(download_path) diff --git a/VisualFL/data/voc/generic_det_label_list.txt b/VisualFL/data/voc/generic_det_label_list.txt new file mode 100644 index 000000000..410f9ae59 --- /dev/null +++ b/VisualFL/data/voc/generic_det_label_list.txt @@ -0,0 +1,676 @@ +Infant bed +Rose +Flag +Flashlight +Sea turtle +Camera +Animal +Glove +Crocodile +Cattle +House +Guacamole +Penguin +Vehicle registration plate +Bench +Ladybug +Human nose +Watermelon +Flute +Butterfly +Washing machine +Raccoon +Segway +Taco +Jellyfish +Cake +Pen +Cannon +Bread +Tree +Shellfish +Bed +Hamster +Hat +Toaster +Sombrero +Tiara +Bowl +Dragonfly +Moths and butterflies +Antelope +Vegetable +Torch +Building +Power plugs and sockets +Blender +Billiard table +Cutting board +Bronze sculpture +Turtle +Broccoli +Tiger +Mirror +Bear +Zucchini +Dress +Volleyball +Guitar +Reptile +Golf cart +Tart +Fedora +Carnivore +Car +Lighthouse +Coffeemaker +Food processor +Truck +Bookcase +Surfboard +Footwear +Bench +Necklace +Flower +Radish +Marine mammal +Frying pan +Tap +Peach +Knife +Handbag +Laptop +Tent +Ambulance +Christmas tree +Eagle +Limousine +Kitchen & dining room table +Polar bear +Tower +Football +Willow +Human head +Stop sign +Banana +Mixer +Binoculars +Dessert +Bee +Chair +Wood-burning stove +Flowerpot +Beaker +Oyster +Woodpecker +Harp +Bathtub +Wall clock +Sports uniform +Rhinoceros +Beehive +Cupboard +Chicken +Man +Blue jay +Cucumber +Balloon +Kite +Fireplace +Lantern +Missile +Book +Spoon +Grapefruit +Squirrel +Orange +Coat +Punching bag +Zebra +Billboard 
+Bicycle +Door handle +Mechanical fan +Ring binder +Table +Parrot +Sock +Vase +Weapon +Shotgun +Glasses +Seahorse +Belt +Watercraft +Window +Giraffe +Lion +Tire +Vehicle +Canoe +Tie +Shelf +Picture frame +Printer +Human leg +Boat +Slow cooker +Croissant +Candle +Pancake +Pillow +Coin +Stretcher +Sandal +Woman +Stairs +Harpsichord +Stool +Bus +Suitcase +Human mouth +Juice +Skull +Door +Violin +Chopsticks +Digital clock +Sunflower +Leopard +Bell pepper +Harbor seal +Snake +Sewing machine +Goose +Helicopter +Seat belt +Coffee cup +Microwave oven +Hot dog +Countertop +Serving tray +Dog bed +Beer +Sunglasses +Golf ball +Waffle +Palm tree +Trumpet +Ruler +Helmet +Ladder +Office building +Tablet computer +Toilet paper +Pomegranate +Skirt +Gas stove +Cookie +Cart +Raven +Egg +Burrito +Goat +Kitchen knife +Skateboard +Salt and pepper shakers +Lynx +Boot +Platter +Ski +Swimwear +Swimming pool +Drinking straw +Wrench +Drum +Ant +Human ear +Headphones +Fountain +Bird +Jeans +Television +Crab +Microphone +Home appliance +Snowplow +Beetle +Artichoke +Jet ski +Stationary bicycle +Human hair +Brown bear +Starfish +Fork +Lobster +Corded phone +Drink +Saucer +Carrot +Insect +Clock +Castle +Tennis racket +Ceiling fan +Asparagus +Jaguar +Musical instrument +Train +Cat +Rifle +Dumbbell +Mobile phone +Taxi +Shower +Pitcher +Lemon +Invertebrate +Turkey +High heels +Bust +Elephant +Scarf +Barrel +Trombone +Pumpkin +Box +Tomato +Frog +Bidet +Human face +Houseplant +Van +Shark +Ice cream +Swim cap +Falcon +Ostrich +Handgun +Whiteboard +Lizard +Pasta +Snowmobile +Light bulb +Window blind +Muffin +Pretzel +Computer monitor +Horn +Furniture +Sandwich +Fox +Convenience store +Fish +Fruit +Earrings +Curtain +Grape +Sofa bed +Horse +Luggage and bags +Desk +Crutch +Bicycle helmet +Tick +Airplane +Canary +Spatula +Watch +Lily +Kitchen appliance +Filing cabinet +Aircraft +Cake stand +Candy +Sink +Mouse +Wine +Wheelchair +Goldfish +Refrigerator +French fries +Drawer +Treadmill +Picnic basket +Dice 
+Cabbage +Football helmet +Pig +Person +Shorts +Gondola +Honeycomb +Doughnut +Chest of drawers +Land vehicle +Bat +Monkey +Dagger +Tableware +Human foot +Mug +Alarm clock +Pressure cooker +Human hand +Tortoise +Baseball glove +Sword +Pear +Miniskirt +Traffic sign +Girl +Roller skates +Dinosaur +Porch +Human beard +Submarine sandwich +Screwdriver +Strawberry +Wine glass +Seafood +Racket +Wheel +Sea lion +Toy +Tea +Tennis ball +Waste container +Mule +Cricket ball +Pineapple +Coconut +Doll +Coffee table +Snowman +Lavender +Shrimp +Maple +Cowboy hat +Goggles +Rugby ball +Caterpillar +Poster +Rocket +Organ +Saxophone +Traffic light +Cocktail +Plastic bag +Squash +Mushroom +Hamburger +Light switch +Parachute +Teddy bear +Winter melon +Deer +Musical keyboard +Plumbing fixture +Scoreboard +Baseball bat +Envelope +Adhesive tape +Briefcase +Paddle +Bow and arrow +Telephone +Sheep +Jacket +Boy +Pizza +Otter +Office supplies +Couch +Cello +Bull +Camel +Ball +Duck +Whale +Shirt +Tank +Motorcycle +Accordion +Owl +Porcupine +Sun hat +Nail +Scissors +Swan +Lamp +Crown +Piano +Sculpture +Cheetah +Oboe +Tin can +Mango +Tripod +Oven +Mouse +Barge +Coffee +Snowboard +Common fig +Salad +Marine invertebrates +Umbrella +Kangaroo +Human arm +Measuring cup +Snail +Loveseat +Suit +Teapot +Bottle +Alpaca +Kettle +Trousers +Popcorn +Centipede +Spider +Sparrow +Plate +Bagel +Personal care +Apple +Brassiere +Bathroom cabinet +studio couch +Computer keyboard +Table tennis racket +Sushi +Cabinetry +Street light +Towel +Nightstand +Rabbit +Dolphin +Dog +Jug +Wok +Fire hydrant +Human eye +Skyscraper +Backpack +Potato +Paper towel +Lifejacket +Bicycle wheel +Toilet +tuba +carpet +trolley +tv +fan +llama +stapler +tricycle +head_phone +air_conditioner +cookies +towel/napkin +boots +sausage +suv +bar_soap +baseball +luggage +poker_card +shovel +marker +earphone +projector +pencil_case +french_horn +tangerine +router/modem +folder +donut +durian +sailboat +nuts +coffee_machine +meat_balls +basket 
+extension_cord +green_beans +avocado +soccer +egg_tart +clutch +slide +fishing_rod +hanger +bread/bun +surveillance_camera +globe +blackboard/whiteboard +life_saver +pigeon +red_cabbage +cymbal +faucet +steak +swing +mangosteen +cheese +urinal +lettuce +hurdle +ring +basketball +potted_plant +rickshaw +target +race_car +bow_tie +iron +toiletries +donkey +saw +hammer +billiards +cutting/chopping_board +power_outlet +hair_drier +baozi +medal +liquid_soap +wild_bird +leather_shoes +dining_table +game_board +barbell +radio +street_lights +tape +hockey +spring_rolls +rice +golf_club +lighter +chips +microscope +cell_phone +fire_truck +noodles +cabinet/shelf +electronic_stove_and_gas_stove +key +comb +trash_bin/can +toothbrush +dates +electric_drill +cow +eggplant +broom +vent +tong +green_onion +scallop +facial_cleanser +toothpaste +hamimelon +eraser +shampoo/shower_gel +CD +skating_and_skiing_shoes +american_football +slippers +pitaya +pot/pan +calculator +tissue +table_tennis_paddle +board_eraser +speaker +papaya +cigar +notepaper +garlic +rice_cooker +canned +parking_meter +flashlight +paint_brush +cup +cue +crosswalk_sign +kiwi_fruit +radiator +mop +chainsaw +sandals +storage_box +onion +bracelet +fire_extinguisher +scale +okra +microwave +sneakers +pepper +corn +pomelo +computer_box +pliers +trophy +plum +brush +machinery_vehicle +yak +crane +converter +facial_mask +carriage +pickup_truck +traffic_cone +pie +pen/pencil +sports_car +frisbee +cleaning_products +remote +stroller diff --git a/VisualFL/data/voc/generic_det_label_list_zh.txt b/VisualFL/data/voc/generic_det_label_list_zh.txt new file mode 100644 index 000000000..0012d759d --- /dev/null +++ b/VisualFL/data/voc/generic_det_label_list_zh.txt @@ -0,0 +1,676 @@ +婴儿床 +玫瑰 +旗 +手电筒 +海龟 +照相机 +动物 +手套 +鳄鱼 +牛 +房子 +鳄梨酱 +企鹅 +车辆牌照 +凳子 +瓢虫 +人鼻 +西瓜 +长笛 +蝴蝶 +洗衣机 +浣熊 +赛格威 +墨西哥玉米薄饼卷 +海蜇 +蛋糕 +笔 +加农炮 +面包 +树 +贝类 +床 +仓鼠 +帽子 +烤面包机 +帽帽 +冠状头饰 +碗 +蜻蜓 +飞蛾和蝴蝶 +羚羊 +蔬菜 +火炬 +建筑物 +电源插头和插座 +搅拌机 +台球桌 +切割板 +青铜雕塑 +乌龟 +西兰花 +老虎 +镜子 +熊 +西葫芦 +礼服 
+排球 +吉他 +爬行动物 +高尔夫球车 +蛋挞 +费多拉 +食肉动物 +小型车 +灯塔 +咖啡壶 +食品加工厂 +卡车 +书柜 +冲浪板 +鞋类 +凳子 +项链 +花 +萝卜 +海洋哺乳动物 +煎锅 +水龙头 +桃 +刀 +手提包 +笔记本电脑 +帐篷 +救护车 +圣诞树 +鹰 +豪华轿车 +厨房和餐桌 +北极熊 +塔楼 +足球 +柳树 +人头 +停车标志 +香蕉 +搅拌机 +双筒望远镜 +甜点 +蜜蜂 +椅子 +烧柴炉 +花盆 +烧杯 +牡蛎 +啄木鸟 +竖琴 +浴缸 +挂钟 +运动服 +犀牛 +蜂箱 +橱柜 +鸡 +人 +冠蓝鸦 +黄瓜 +气球 +风筝 +壁炉 +灯笼 +导弹 +书 +勺子 +葡萄柚 +松鼠 +橙色 +外套 +打孔袋 +斑马 +广告牌 +自行车 +门把手 +机械风扇 +环形粘结剂 +桌子 +鹦鹉 +袜子 +花瓶 +武器 +猎枪 +玻璃杯 +海马 +腰带 +船舶 +窗口 +长颈鹿 +狮子 +轮胎 +车辆 +独木舟 +领带 +架子 +相框 +打印机 +人腿 +小船 +慢炖锅 +牛角包 +蜡烛 +煎饼 +枕头 +硬币 +担架 +凉鞋 +女人 +楼梯 +拨弦键琴 +凳子 +公共汽车 +手提箱 +人口学 +果汁 +颅骨 +门 +小提琴 +筷子 +数字时钟 +向日葵 +豹 +甜椒 +海港海豹 +蛇 +缝纫机 +鹅 +直升机 +座椅安全带 +咖啡杯 +微波炉 +热狗 +台面 +服务托盘 +狗床 +啤酒 +太阳镜 +高尔夫球 +华夫饼干 +棕榈树 +小号 +尺子 +头盔 +梯子 +办公楼 +平板电脑 +厕纸 +石榴 +裙子 +煤气炉 +曲奇饼干 +大车 +掠夺 +鸡蛋 +墨西哥煎饼 +山羊 +菜刀 +滑板 +盐和胡椒瓶 +猞猁 +靴子 +大浅盘 +滑雪板 +泳装 +游泳池 +吸管 +扳手 +鼓 +蚂蚁 +人耳 +耳机 +喷泉 +鸟 +牛仔裤 +电视机 +蟹 +话筒 +家用电器 +除雪机 +甲虫 +朝鲜蓟 +喷气式滑雪板 +固定自行车 +人发 +棕熊 +海星 +叉子 +龙虾 +有线电话 +饮料 +碟 +胡萝卜 +昆虫 +时钟 +城堡 +网球拍 +吊扇 +芦笋 +美洲虎 +乐器 +火车 +猫 +来复枪 +哑铃 +手机 +出租车 +淋浴 +投掷者 +柠檬 +无脊椎动物 +火鸡 +高跟鞋 +打破 +大象 +围巾 +枪管 +长号 +南瓜 +盒子 +番茄 +蛙 +坐浴盆 +人脸 +室内植物 +厢式货车 +鲨鱼 +冰淇淋 +游泳帽 +隼 +鸵鸟 +手枪 +白板 +蜥蜴 +面食 +雪车 +灯泡 +窗盲 +松饼 +椒盐脆饼 +计算机显示器 +喇叭 +家具 +三明治 +福克斯 +便利店 +鱼 +水果 +耳环 +帷幕 +葡萄 +沙发床 +马 +行李和行李 +书桌 +拐杖 +自行车头盔 +滴答声 +飞机 +金丝雀 +铲 +手表 +莉莉 +厨房用具 +文件柜 +飞机 +蛋糕架 +糖果 +水槽 +鼠标 +葡萄酒 +轮椅 +金鱼 +冰箱 +炸薯条 +抽屉 +单调的工作 +野餐篮子 +骰子 +甘蓝 +足球头盔 +猪 +人 +短裤 +贡多拉 +蜂巢 +炸圈饼 +抽屉柜 +陆地车辆 +蝙蝠 +猴子 +匕首 +餐具 +人足 +马克杯 +闹钟 +高压锅 +人手 +乌龟 +棒球手套 +剑 +梨 +迷你裙 +交通标志 +女孩 +旱冰鞋 +恐龙 +门廊 +胡须 +潜艇三明治 +螺丝起子 +草莓 +酒杯 +海鲜 +球拍 +车轮 +海狮 +玩具 +茶叶 +网球 +废物容器 +骡子 +板球 +菠萝 +椰子 +娃娃 +咖啡桌 +雪人 +薰衣草 +小虾 +枫树 +牛仔帽 +护目镜 +橄榄球 +毛虫 +海报 +火箭 +器官 +萨克斯 +交通灯 +鸡尾酒 +塑料袋 +壁球 +蘑菇 +汉堡包 +电灯开关 +降落伞 +泰迪熊 +冬瓜 +鹿 +音乐键盘 +卫生器具 +记分牌 +棒球棒 +包络线 +胶带 +公文包 +桨 +弓箭 +电话 +羊 +夹克 +男孩 +披萨 +水獭 +办公用品 +沙发 +大提琴 +公牛 +骆驼 +球 +鸭子 +鲸鱼 +衬衫 +坦克 +摩托车 +手风琴 +猫头鹰 +豪猪 +太阳帽 +钉子 +剪刀 +天鹅 +灯 +皇冠 +钢琴 +雕塑 +猎豹 +双簧管 +罐头罐 +芒果 +三脚架 +烤箱 +鼠标 +驳船 +咖啡 +滑雪板 +普通无花果 +沙拉 +无脊椎动物 +雨伞 +袋鼠 +人手臂 +量杯 +蜗牛 +相思 +西服 +茶壶 +瓶 +羊驼 +水壶 +裤子 +爆米花 +蜈蚣 +蜘蛛 +麻雀 +盘子 +百吉饼 +个人护理 +苹果 +胸罩 +浴室柜 +演播室沙发 +电脑键盘 +乒乓球拍 +寿司 +橱柜 +路灯 +毛巾 +床头柜 +兔 +海豚 +狗 +大罐 +炒锅 +消火栓 +人眼 +摩天大楼 +背包 +马铃薯 +纸巾 +小精灵 +自行车车轮 +卫生间 +大号 +地毯 +手推车 
+电视 +风扇 +美洲驼 +订书机 +三轮车 +耳机 +空调器 +饼干 +毛巾/餐巾 +靴子 +香肠 +运动型多用途汽车 +肥皂 +棒球 +行李 +扑克牌 +铲子 +标记笔 +耳机 +投影机 +铅笔盒 +法国圆号 +橘子 +路由器 +文件夹 +甜甜圈 +榴莲 +帆船 +坚果 +咖啡机 +肉丸 +篮子 +插线板 +青豆 +鳄梨 +英式足球 +蛋挞 +离合器 +滑梯 +鱼竿 +衣架 +面包 +监控摄像头 +地球仪 +黑板/白板 +救生员 +鸽子 +红卷心菜 +铜钹 +水龙头 +牛排 +秋千 +山竹 +奶酪 +小便池 +生菜 +跨栏 +戒指 +篮球 +盆栽植物 +人力车 +目标 +赛车 +蝴蝶结 +熨斗 +化妆品 +驴 +锯 +铁锤 +台球 +切割/砧板 +电源插座 +吹风机 +包子 +奖章/奖牌 +液体肥皂 +野鸟 +皮鞋 +餐桌 +游戏板 +杠铃 +收音机 +路灯 +磁带 +曲棍球 +春卷 +大米 +高尔夫俱乐部 +打火机 +炸薯条 +显微镜 +手机 +消防车 +面条 +橱柜/架子 +电磁炉和煤气炉 +钥匙 +梳子 +垃圾箱/罐 +牙刷 +枣子 +电钻 +奶牛 +茄子 +扫帚 +抽油烟机 +钳子 +大葱 +扇贝 +洁面乳 +牙膏 +哈密瓜 +橡皮擦 +洗发水/沐浴露 +光盘 +溜冰鞋和滑雪鞋 +美式足球 +拖鞋 +火龙果 +锅/平底锅 +计算器 +纸巾 +乒乓球拍 +板擦 +扬声器 +木瓜 +雪茄 +信纸 +大蒜 +电饭锅 +罐装的 +停车计时器 +手电筒 +画笔 +杯子 +球杆 +人行横道标志 +奇异果/猕猴桃 +散热器 +拖把 +电锯 +凉鞋拖鞋 +储物箱 +洋葱 +手镯 +灭火器 +秤 +秋葵 +微波炉 +运动鞋 +胡椒 +玉米 +柚子 +主机 +钳子 +奖杯 +李子/梅子 +刷子/画笔 +机械车辆 +牦牛 +起重机 +转换器 +面膜 +马车 +皮卡车 +交通锥 +馅饼 +钢笔/铅笔 +跑车 +飞盘 +清洁用品/洗涤剂/洗衣液 +遥控器 +婴儿车/手推车 diff --git a/VisualFL/data/voc/label_list.txt b/VisualFL/data/voc/label_list.txt new file mode 100644 index 000000000..8420ab35e --- /dev/null +++ b/VisualFL/data/voc/label_list.txt @@ -0,0 +1,20 @@ +aeroplane +bicycle +bird +boat +bottle +bus +car +cat +chair +cow +diningtable +dog +horse +motorbike +person +pottedplant +sheep +sofa +train +tvmonitor diff --git a/VisualFL/data/wider_face/download.sh b/VisualFL/data/wider_face/download.sh new file mode 100755 index 000000000..59a2054de --- /dev/null +++ b/VisualFL/data/wider_face/download.sh @@ -0,0 +1,21 @@ +# All rights `PaddleDetection` reserved +# References: +# @inproceedings{yang2016wider, +# Author = {Yang, Shuo and Luo, Ping and Loy, Chen Change and Tang, Xiaoou}, +# Booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, +# Title = {WIDER FACE: A Face Detection Benchmark}, +# Year = {2016}} + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd "$DIR" + +# Download the data. +echo "Downloading..." 
+wget https://dataset.bj.bcebos.com/wider_face/WIDER_train.zip +wget https://dataset.bj.bcebos.com/wider_face/WIDER_val.zip +wget https://dataset.bj.bcebos.com/wider_face/wider_face_split.zip +# Extract the data. +echo "Extracting..." +unzip -q WIDER_train.zip +unzip -q WIDER_val.zip +unzip -q wider_face_split.zip diff --git a/VisualFL/depends/PaddleDetection/ppdet/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/__init__.py new file mode 100755 index 000000000..d0c32e260 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/VisualFL/depends/PaddleDetection/ppdet/core/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/core/__init__.py new file mode 100755 index 000000000..f8561f944 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/core/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import ppdet.modeling +import ppdet.optimizer +import ppdet.data diff --git a/VisualFL/depends/PaddleDetection/ppdet/core/config/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/core/config/__init__.py new file mode 100755 index 000000000..d0c32e260 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/core/config/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/VisualFL/depends/PaddleDetection/ppdet/core/config/schema.py b/VisualFL/depends/PaddleDetection/ppdet/core/config/schema.py new file mode 100755 index 000000000..0d2b0dabf --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/core/config/schema.py @@ -0,0 +1,248 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import inspect +import importlib +import re + +try: + from docstring_parser import parse as doc_parse +except Exception: + + def doc_parse(*args): + pass + + +try: + from typeguard import check_type +except Exception: + + def check_type(*args): + pass + + +__all__ = ['SchemaValue', 'SchemaDict', 'SharedConfig', 'extract_schema'] + + +class SchemaValue(object): + def __init__(self, name, doc='', type=None): + super(SchemaValue, self).__init__() + self.name = name + self.doc = doc + self.type = type + + def set_default(self, value): + self.default = value + + def has_default(self): + return hasattr(self, 'default') + + +class SchemaDict(dict): + def __init__(self, **kwargs): + super(SchemaDict, self).__init__() + self.schema = {} + self.strict = False + self.doc = "" + self.update(kwargs) + + def __setitem__(self, key, value): + # XXX also update regular dict to SchemaDict?? 
        # Merge nested SchemaDict values recursively instead of replacing.
        if isinstance(value, dict) and key in self and isinstance(
                self[key], SchemaDict):
            self[key].update(value)
        else:
            super(SchemaDict, self).__setitem__(key, value)

    def __missing__(self, key):
        # Fall back to the schema default (or the schema entry itself)
        # when a key has no explicit value.
        if self.has_default(key):
            return self.schema[key].default
        elif key in self.schema:
            return self.schema[key]
        else:
            raise KeyError(key)

    def copy(self):
        """Return a shallow copy preserving schema and instance attributes."""
        newone = SchemaDict()
        newone.__dict__.update(self.__dict__)
        newone.update(self)
        return newone

    def set_schema(self, key, value):
        assert isinstance(value, SchemaValue)
        self.schema[key] = value

    def set_strict(self, strict):
        self.strict = strict

    def has_default(self, key):
        return key in self.schema and self.schema[key].has_default()

    def is_default(self, key):
        """Return True if `key` still holds its schema default value."""
        if not self.has_default(key):
            return False
        if hasattr(self[key], '__dict__'):
            # object instances are always treated as "default" here
            return True
        else:
            return key not in self or self[key] == self.schema[key].default

    def find_default_keys(self):
        """Return all keys that currently hold default values."""
        return [
            k for k in list(self.keys()) + list(self.schema.keys())
            if self.is_default(k)
        ]

    def mandatory(self):
        """Return True if any schema entry lacks a default value."""
        return any([k for k in self.schema.keys() if not self.has_default(k)])

    def find_missing_keys(self):
        """Return required keys that are absent or still set to placeholders."""
        missing = [
            k for k in self.schema.keys()
            if k not in self and not self.has_default(k)
        ]
        # NOTE(review): the imported copy read `in ('', '')` — a tuple with a
        # duplicated empty string, which looks like a transcription artifact
        # of upstream PaddleDetection's literal placeholders. The upstream
        # values are restored below; confirm against the original source.
        placeholders = [k for k in self if self[k] in ('<missing>', '<value>')]
        return missing + placeholders

    def find_extra_keys(self):
        """Return keys present in the dict but absent from the schema."""
        return list(set(self.keys()) - set(self.schema.keys()))

    def find_mismatch_keys(self):
        """Return keys whose values do not match their annotated types."""
        mismatch_keys = []
        for arg in self.schema.values():
            if arg.type is not None:
                try:
                    check_type("{}.{}".format(self.name, arg.name),
                               self[arg.name], arg.type)
                except Exception:
                    mismatch_keys.append(arg.name)
        return mismatch_keys

    def validate(self):
        """Raise if any key is missing, extraneous (strict mode) or mistyped.

        Raises:
            ValueError: on missing or (in strict mode) extraneous keys.
            TypeError: on type-mismatched values.
        """
        missing_keys = self.find_missing_keys()
        if missing_keys:
            raise ValueError("Missing param for class<{}>: {}".format(
                self.name, ", ".join(missing_keys)))
        extra_keys = self.find_extra_keys()
        if extra_keys and self.strict:
            raise ValueError("Extraneous param for class<{}>: {}".format(
                self.name, ", ".join(extra_keys)))
        mismatch_keys = self.find_mismatch_keys()
        if mismatch_keys:
            raise TypeError("Wrong param type for class<{}>: {}".format(
                self.name, ", ".join(mismatch_keys)))


class SharedConfig(object):
    """
    Representation class for `__shared__` annotations, which work as follows:

    - if `key` is set for the module in config file, its value will take
      precedence
    - if `key` is not set for the module but present in the config file, its
      value will be used
    - otherwise, use the provided `default_value` as fallback

    Args:
        key: config[key] will be injected
        default_value: fallback value
    """

    def __init__(self, key, default_value=None):
        super(SharedConfig, self).__init__()
        self.key = key
        self.default_value = default_value


def extract_schema(cls):
    """
    Extract schema from a given class.

    Args:
        cls (type): Class from which to extract.

    Returns:
        schema (SchemaDict): Extracted schema.
    """
    ctor = cls.__init__
    # python 2 compatibility
    if hasattr(inspect, 'getfullargspec'):
        argspec = inspect.getfullargspec(ctor)
        annotations = argspec.annotations
        has_kwargs = argspec.varkw is not None
    else:
        argspec = inspect.getargspec(ctor)
        # python 2 type hinting workaround, see pep-3107
        # however, since `typeguard` does not support python 2, type checking
        # is still python 3 only for now
        annotations = getattr(ctor, '__annotations__', {})
        has_kwargs = argspec.keywords is not None

    names = [arg for arg in argspec.args if arg != 'self']
    defaults = argspec.defaults
    num_defaults = argspec.defaults is not None and len(argspec.defaults) or 0
    num_required = len(names) - num_defaults

    docs = cls.__doc__
    if docs is None and getattr(cls, '__category__', None) == 'op':
        docs = cls.__call__.__doc__
    try:
        docstring = doc_parse(docs)
    except Exception:
        docstring = None

    # map parameter name -> description parsed from the class docstring
    if docstring is None:
        comments = {}
    else:
        comments = {}
        for p in docstring.params:
            match_obj = re.match('^([a-zA-Z_]+[a-zA-Z_0-9]*).*', p.arg_name)
            if match_obj is not None:
                comments[match_obj.group(1)] = p.description

    schema = SchemaDict()
    schema.name = cls.__name__
    schema.doc = ""
    if docs is not None:
        start_pos = docs[0] == '\n' and 1 or 0
        schema.doc = docs[start_pos:].split("\n")[0].strip()
        # XXX handle paddle's weird doc convention
        if '**' == schema.doc[:2] and '**' == schema.doc[-2:]:
            schema.doc = schema.doc[2:-2].strip()
    schema.category = hasattr(cls, '__category__') and getattr(
        cls, '__category__') or 'module'
    # classes accepting **kwargs are not strict about extra keys
    schema.strict = not has_kwargs
    schema.pymodule = importlib.import_module(cls.__module__)
    schema.inject = getattr(cls, '__inject__', [])
    schema.shared = getattr(cls, '__shared__', [])
    for idx, name in enumerate(names):
        comment = name in comments and comments[name] or name
        if name in schema.inject:
            type_ = None
        else:
            type_ = name in annotations and annotations[name] or None
        value_schema = SchemaValue(name, comment, type_)
        if name in schema.shared:
            assert idx >= num_required, "shared config must have default value"
            default = defaults[idx - num_required]
            value_schema.set_default(SharedConfig(name, default))
        elif idx >= num_required:
            default = defaults[idx - num_required]
            value_schema.set_default(default)
        schema.set_schema(name, value_schema)

    return schema


# ---------------------------------------------------------------------------
# File: VisualFL/depends/PaddleDetection/ppdet/core/config/yaml_helpers.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0;
# see http://www.apache.org/licenses/LICENSE-2.0
# ---------------------------------------------------------------------------
import importlib
import inspect

import yaml
from .schema import SharedConfig

__all__ = ['serializable', 'Callable']


def represent_dictionary_order(self, dict_data):
    """Represent a mapping preserving insertion order (no key sorting)."""
    return self.represent_mapping('tag:yaml.org,2002:map', dict_data.items())


def setup_orderdict():
    """Register a YAML representer so OrderedDict dumps as a plain map."""
    from collections import OrderedDict
    yaml.add_representer(OrderedDict, represent_dictionary_order)


def _make_python_constructor(cls):
    """Build a YAML constructor that instantiates `cls` from a tag node."""

    def python_constructor(loader, node):
        if isinstance(node, yaml.SequenceNode):
            # sequence node -> positional arguments
            args = loader.construct_sequence(node, deep=True)
            return cls(*args)
        else:
            # mapping node -> keyword arguments
            kwargs = loader.construct_mapping(node, deep=True)
            try:
                return cls(**kwargs)
            except Exception as ex:
                print("Error when construct {} instance from yaml config".
                      format(cls.__name__))
                raise ex

    return python_constructor


def _make_python_representer(cls):
    """Build a YAML representer dumping `cls` instances as `!Name` maps."""
    # python 2 compatibility
    if hasattr(inspect, 'getfullargspec'):
        argspec = inspect.getfullargspec(cls)
    else:
        argspec = inspect.getargspec(cls.__init__)
    argnames = [arg for arg in argspec.args if arg != 'self']

    def python_representer(dumper, obj):
        if argnames:
            data = {name: getattr(obj, name) for name in argnames}
        else:
            data = obj.__dict__
        if '_id' in data:
            del data['_id']
        return dumper.represent_mapping(u'!{}'.format(cls.__name__), data)

    return python_representer


def serializable(cls):
    """
    Add loader and dumper for given class, which must be
    "trivially serializable"

    Args:
        cls: class to be serialized

    Returns: cls
    """
    yaml.add_constructor(u'!{}'.format(cls.__name__),
                         _make_python_constructor(cls))
    yaml.add_representer(cls, _make_python_representer(cls))
    return cls


# SharedConfig placeholders are dumped as their fallback default value.
yaml.add_representer(SharedConfig,
                     lambda d, o: d.represent_data(o.default_value))


@serializable
class Callable(object):
    """
    Helper to be used in Yaml for creating arbitrary class objects

    Args:
        full_type (str): the full module path to target function
    """

    def __init__(self, full_type, args=None, kwargs=None):
        super(Callable, self).__init__()
        self.full_type = full_type
        # BUGFIX: the original used mutable default arguments (args=[],
        # kwargs={}), which would be shared by every instance created
        # without explicit args/kwargs; `None` sentinels are equivalent
        # for callers but allocate fresh containers per instance.
        self.args = [] if args is None else args
        self.kwargs = {} if kwargs is None else kwargs

    def __call__(self):
        """Import the target and call it with the stored args/kwargs."""
        if '.' in self.full_type:
            idx = self.full_type.rfind('.')
            module = importlib.import_module(self.full_type[:idx])
            func_name = self.full_type[idx + 1:]
        else:
            # bare name: look it up on the builtins module (py2/py3 names)
            try:
                module = importlib.import_module('builtins')
            except Exception:
                module = importlib.import_module('__builtin__')
            func_name = self.full_type

        func = getattr(module, func_name)
        return func(*self.args, **self.kwargs)


# ---------------------------------------------------------------------------
# File: VisualFL/depends/PaddleDetection/ppdet/core/workspace.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0;
# see http://www.apache.org/licenses/LICENSE-2.0
# ---------------------------------------------------------------------------
+ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import importlib +import os +import sys + +import yaml +import copy +import collections + +from .config.schema import SchemaDict, SharedConfig, extract_schema +from .config.yaml_helpers import serializable + +__all__ = [ + 'global_config', + 'load_config', + 'merge_config', + 'get_registered_modules', + 'create', + 'register', + 'serializable', + 'dump_value', +] + + +def dump_value(value): + # XXX this is hackish, but collections.abc is not available in python 2 + if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)): + value = yaml.dump(value, default_flow_style=True) + value = value.replace('\n', '') + value = value.replace('...', '') + return "'{}'".format(value) + else: + # primitive types + return str(value) + + +class AttrDict(dict): + """Single level attribute dict, NOT recursive""" + + def __init__(self, **kwargs): + super(AttrDict, self).__init__() + super(AttrDict, self).update(kwargs) + + def __getattr__(self, key): + if key in self: + return self[key] + raise AttributeError("object has no attribute '{}'".format(key)) + + +global_config = AttrDict() + +READER_KEY = '_READER_' + + +def load_config(file_path): + """ + Load config from file. + + Args: + file_path (str): Path of the config file to be loaded. 
+ + Returns: global config + """ + _, ext = os.path.splitext(file_path) + assert ext in ['.yml', '.yaml'], "only support yaml files for now" + + cfg = AttrDict() + with open(file_path) as f: + cfg = merge_config(yaml.load(f, Loader=yaml.Loader), cfg) + + if READER_KEY in cfg: + reader_cfg = cfg[READER_KEY] + if reader_cfg.startswith("~"): + reader_cfg = os.path.expanduser(reader_cfg) + if not reader_cfg.startswith('/'): + reader_cfg = os.path.join(os.path.dirname(file_path), reader_cfg) + + with open(reader_cfg) as f: + merge_config(yaml.load(f, Loader=yaml.Loader)) + del cfg[READER_KEY] + + merge_config(cfg) + + return global_config + + +def dict_merge(dct, merge_dct): + """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of + updating only top-level keys, dict_merge recurses down into dicts nested + to an arbitrary depth, updating keys. The ``merge_dct`` is merged into + ``dct``. + + Args: + dct: dict onto which the merge is executed + merge_dct: dct merged into dct + + Returns: dct + """ + for k, v in merge_dct.items(): + if (k in dct and isinstance(dct[k], dict) and + isinstance(merge_dct[k], collections.Mapping)): + dict_merge(dct[k], merge_dct[k]) + else: + dct[k] = merge_dct[k] + return dct + + +def merge_config(config, another_cfg=None): + """ + Merge config into global config or another_cfg. + + Args: + config (dict): Config to be merged. + + Returns: global config + """ + global global_config + dct = another_cfg if another_cfg is not None else global_config + dct = dict_merge(dct, config) + + # NOTE: training batch size defined only in TrainReader, sychornized + # batch size config to global, models can get batch size config + # from global config when building model. 
+ # batch size in evaluation or inference can also be added here + if 'TrainReader' in dct and 'batch_size' in dct['TrainReader']: + dct['train_batch_size'] = dct['TrainReader']['batch_size'] + + return dct + + +def get_registered_modules(): + return {k: v for k, v in global_config.items() if isinstance(v, SchemaDict)} + + +def make_partial(cls): + if isinstance(cls.__op__, str): + sep = cls.__op__.split('.') + op_name = sep[-1] + op_module = importlib.import_module('.'.join(sep[:-1])) + else: + op_name = cls.__op__.__name__ + op_module = importlib.import_module(cls.__op__.__module__) + + if not hasattr(op_module, op_name): + import logging + logger = logging.getLogger(__name__) + logger.warn('{} OP not found, maybe a newer version of paddle ' + 'is required.'.format(cls.__op__)) + return cls + + op = getattr(op_module, op_name) + cls.__category__ = getattr(cls, '__category__', None) or 'op' + + def partial_apply(self, *args, **kwargs): + kwargs_ = self.__dict__.copy() + kwargs_.update(kwargs) + return op(*args, **kwargs_) + + if getattr(cls, '__append_doc__', True): # XXX should default to True? + if sys.version_info[0] > 2: + cls.__doc__ = "Wrapper for `{}` OP".format(op.__name__) + cls.__init__.__doc__ = op.__doc__ + cls.__call__ = partial_apply + cls.__call__.__doc__ = op.__doc__ + else: + # XXX work around for python 2 + partial_apply.__doc__ = op.__doc__ + cls.__call__ = partial_apply + return cls + + +def register(cls): + """ + Register a given module class. + + Args: + cls (type): Module class to be registered. + + Returns: cls + """ + if cls.__name__ in global_config: + raise ValueError("Module class already registered: {}".format( + cls.__name__)) + if hasattr(cls, '__op__'): + cls = make_partial(cls) + global_config[cls.__name__] = extract_schema(cls) + return cls + + +def create(cls_or_name, **kwargs): + """ + Create an instance of given module class. + + Args: + cls_or_name (type or str): Class of which to create instance. 
+ + Returns: instance of type `cls_or_name` + """ + assert type(cls_or_name) in [type, str + ], "should be a class or name of a class" + name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__ + assert name in global_config and \ + isinstance(global_config[name], SchemaDict), \ + "the module {} is not registered".format(name) + config = global_config[name] + config.update(kwargs) + config.validate() + cls = getattr(config.pymodule, name) + + kwargs = {} + kwargs.update(global_config[name]) + + # parse `shared` annoation of registered modules + if getattr(config, 'shared', None): + for k in config.shared: + target_key = config[k] + shared_conf = config.schema[k].default + assert isinstance(shared_conf, SharedConfig) + if target_key is not None and not isinstance(target_key, + SharedConfig): + continue # value is given for the module + elif shared_conf.key in global_config: + # `key` is present in config + kwargs[k] = global_config[shared_conf.key] + else: + kwargs[k] = shared_conf.default_value + + # parse `inject` annoation of registered modules + if getattr(config, 'inject', None): + for k in config.inject: + target_key = config[k] + # optional dependency + if target_key is None: + continue + # also accept dictionaries and serialized objects + if isinstance(target_key, dict) or hasattr(target_key, '__dict__'): + continue + elif isinstance(target_key, str): + if target_key not in global_config: + raise ValueError("Missing injection config:", target_key) + target = global_config[target_key] + if isinstance(target, SchemaDict): + kwargs[k] = create(target_key) + elif hasattr(target, '__dict__'): # serialized object + kwargs[k] = target + else: + raise ValueError("Unsupported injection type:", target_key) + # prevent modification of global config values of reference types + # (e.g., list, dict) from within the created module instances + kwargs = copy.deepcopy(kwargs) + return cls(**kwargs) diff --git 
a/VisualFL/depends/PaddleDetection/ppdet/data/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/data/__init__.py new file mode 100644 index 000000000..1a6576e78 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from .reader import * +from .source import * +from .transform import * diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/parallel_map.py b/VisualFL/depends/PaddleDetection/ppdet/data/parallel_map.py new file mode 100644 index 000000000..789fda1f2 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/parallel_map.py @@ -0,0 +1,311 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# function: +# transform samples in 'source' using 'worker' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import six +if six.PY3: + from queue import Empty +else: + from Queue import Empty + +import uuid +import logging +import signal +import threading +import traceback + +logger = logging.getLogger(__name__) + +main_pid = os.getpid() +worker_set = set() + + +class EndSignal(object): + """ signal used to notify worker to exit + """ + + def __init__(self, id, errno=0, errmsg=''): + self.id = id + self.errno = errno + self.errmsg = errmsg + + +class ParallelMap(object): + """ + Transform samples to mapped samples which is similar to + 'basic.MappedDataset', but multiple workers (threads or processes) + will be used + + Notes: + this class is not thread-safe + """ + + def __init__(self, + source, + worker, + worker_num, + bufsize=100, + use_process=False, + memsize='3G'): + self._worker_num = worker_num + self._bufsize = bufsize + self._use_process = use_process + if self._use_process and sys.platform == "win32": + logger.debug("Use multi-thread reader instead of " + "multi-process reader on Windows.") + self._use_process = False + if self._use_process and type(memsize) is str: + assert memsize[-1].lower() in ['g', 'm'], \ + "invalid param for memsize[%s], should be " \ + "ended with 'G' or 'g' or 'M' or 'm'" % (memsize) + power = 3 if memsize[-1].lower() == 'g' else 2 + self._memsize = int(memsize[:-1]) * (1024**power) + self._started = False + self._source = source + self._worker = worker + self._exit = False + self._setup() + self._souce_drained = False + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def _setup(self): + """setup input/output queues and workers """ + use_process = self._use_process + + bufsize = self._bufsize + if use_process: + from .shared_queue import SharedQueue as Queue + from multiprocessing import Process as 
Worker + from multiprocessing import Event + memsize = self._memsize + self._inq = Queue(bufsize, memsize=memsize) + self._outq = Queue(bufsize, memsize=memsize) + else: + if six.PY3: + from queue import Queue + else: + from Queue import Queue + from threading import Thread as Worker + from threading import Event + self._inq = Queue(bufsize) + self._outq = Queue(bufsize) + + consumer_num = self._worker_num + id = str(uuid.uuid4())[-3:] + self._producer = threading.Thread( + target=self._produce, + args=('producer-' + id, self._source, self._inq)) + self._producer.daemon = True + + self._consumers = [] + self._consumer_endsig = {} + global worker_set + for i in range(consumer_num): + consumer_id = 'consumer-' + id + '-' + str(i) + p = Worker( + target=self._consume, + args=(consumer_id, self._inq, self._outq, self._worker)) + self._consumers.append(p) + p.daemon = True + setattr(p, 'id', consumer_id) + if use_process: + worker_set.add(p) + + self._epoch = -1 + self._feeding_ev = Event() + self._produced = 0 # produced sample in self._produce + self._consumed = 0 # consumed sample in self.next + + def _produce(self, id, source, inq): + """Fetch data from source and feed it to 'inq' queue""" + endsig = EndSignal(id) + while True: + self._feeding_ev.wait() + if self._exit: + break + try: + s = source.next() + inq.put(s) + self._produced += 1 + except StopIteration: + self._souce_drained = True + self._feeding_ev.clear() + self._feeding_ev.wait() + except Exception as e: + endsig.errno = -1 + endsig.errmsg = "producer[{}] failed with error: {}" \ + .format(id, str(e)) + inq.put(endsig) + break + + def _consume(self, id, inq, outq, worker): + """Fetch data from 'inq', process it and put result to 'outq'""" + if self._use_process: + # handle SIGTERM signal to exit to prevent print stack frame + signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit()) + + endsig = EndSignal(id) + while True: + sample = inq.get() + if isinstance(sample, EndSignal): + endsig.errno = 
sample.errno + endsig.errmsg = "consumer[{}] exits for reason[{}]" \ + .format(id, sample.errmsg) + outq.put(endsig) + break + + try: + result = worker(sample) + outq.put(result) + except Exception as e: + endsig.errno = -2 + endsig.errmsg = "consumer[{}] failed to map with error:[{}]" \ + .format(id, str(e)) + outq.put(endsig) + break + + def drained(self): + assert self._epoch >= 0, "first epoch has not started yet" + return self._source.drained() and self._produced == self._consumed + + def stop(self): + """ notify to exit + """ + self._exit = True + self._feeding_ev.set() + for _ in range(len(self._consumers)): + self._inq.put(EndSignal(0, "notify consumers to exit")) + + def _consumer_healthy(self): + abnormal_num = 0 + for w in self._consumers: + if not w.is_alive() and w.id not in self._consumer_endsig: + abnormal_num += 1 + if self._use_process: + errmsg = "consumer[{}] exit abnormally with exitcode[{}]" \ + .format(w.pid, w.exitcode) + else: + errmsg = "consumer[{}] exit abnormally".format(w.ident) + + logger.warn(errmsg) + + if abnormal_num > 0: + logger.warn("{} consumers have exited abnormally!!!" 
\ + .format(abnormal_num)) + + return abnormal_num == 0 + + def next(self): + """ get next transformed sample + """ + if self._epoch < 0: + self.reset() + + if self.drained(): + raise StopIteration() + + while not self._exit: + try: + sample = self._outq.get(timeout=3) + except Empty as e: + if not self._consumer_healthy(): + raise StopIteration() + else: + continue + + if isinstance(sample, EndSignal): + self._consumer_endsig[sample.id] = sample + logger.warn("recv endsignal from outq with errmsg[{}]" \ + .format(sample.errmsg)) + + if len(self._consumer_endsig.keys()) < len(self._consumers): + self._inq.put(sample) + else: + self._exit = True + raise StopIteration("all consumers exited, no more samples") + else: + self._consumed += 1 + return sample + + raise StopIteration() + + def reset(self): + """ reset for a new epoch of samples + """ + assert not self._exit, "cannot reset for already stopped dataset" + + if self._epoch < 0: + self._epoch = 0 + for w in self._consumers: + w.start() + self._producer.start() + else: + assert self._consumer_healthy(), "cannot start another pass of data" \ + " for some consumers exited abnormally before!!!" + + if not self.drained(): + logger.warn("reset before epoch[{}] finishes".format( + self._epoch)) + self._produced = self._produced - self._consumed + else: + self._produced = 0 + + self._epoch += 1 + + assert len(self._consumer_endsig.keys()) == 0, "some consumers already exited," \ + + " cannot start another epoch" + + self._source.reset() + self._souce_drained = False + self._consumed = 0 + self._feeding_ev.set() + + +# FIXME: fix me if you have better impliment +# handle terminate reader process, do not print stack frame +signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit()) + + +# FIXME(dkp): KeyboardInterrupt should be handled inside ParallelMap +# and do such as: 1. exit workers 2. close queues 3. 
release shared +# memory, HACK KeyboardInterrupt with global signal.SIGINT handler +# here, should be refined later +def _term_workers(sig_num, frame): + global worker_set, main_pid + # only do subporcess killing in main process + if os.getpid() != main_pid: + return + + logger.info("KeyboardInterrupt: main proc {} exit, kill subprocess {}" \ + .format(os.getpid(), [w.pid for w in worker_set])) + for w in worker_set: + if w.pid is not None: + os.kill(w.pid, signal.SIGINT) + sys.exit() + + +signal.signal(signal.SIGINT, _term_workers) diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/reader.py b/VisualFL/depends/PaddleDetection/ppdet/data/reader.py new file mode 100644 index 000000000..d19653078 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/reader.py @@ -0,0 +1,457 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import copy +import functools +import collections +import traceback +import numpy as np +import logging + +from ppdet.core.workspace import register, serializable + +from .parallel_map import ParallelMap +from .transform.batch_operators import Gt2YoloTarget + +__all__ = ['Reader', 'create_reader'] + +logger = logging.getLogger(__name__) + + +class Compose(object): + def __init__(self, transforms, ctx=None): + self.transforms = transforms + self.ctx = ctx + + def __call__(self, data): + ctx = self.ctx if self.ctx else {} + for f in self.transforms: + try: + data = f(data, ctx) + except Exception as e: + stack_info = traceback.format_exc() + logger.warn("fail to map op [{}] with error: {} and stack:\n{}". + format(f, e, str(stack_info))) + raise e + return data + + +def _calc_img_weights(roidbs): + """ calculate the probabilities of each sample + """ + imgs_cls = [] + num_per_cls = {} + img_weights = [] + for i, roidb in enumerate(roidbs): + img_cls = set([k for cls in roidbs[i]['gt_class'] for k in cls]) + imgs_cls.append(img_cls) + for c in img_cls: + if c not in num_per_cls: + num_per_cls[c] = 1 + else: + num_per_cls[c] += 1 + + for i in range(len(roidbs)): + weights = 0 + for c in imgs_cls[i]: + weights += 1 / num_per_cls[c] + img_weights.append(weights) + # probabilities sum to 1 + img_weights = img_weights / np.sum(img_weights) + return img_weights + + +def _has_empty(item): + def empty(x): + if isinstance(x, np.ndarray) and x.size == 0: + return True + elif isinstance(x, collections.Sequence) and len(x) == 0: + return True + else: + return False + + if isinstance(item, collections.Sequence) and len(item) == 0: + return True + if item is None: + return True + if empty(item): + return True + return False + + +def _segm(samples): + assert 'gt_poly' in samples + segms = samples['gt_poly'] + if 'is_crowd' in samples: + is_crowd = 
samples['is_crowd'] + if len(segms) != 0: + assert len(segms) == is_crowd.shape[0] + + gt_masks = [] + valid = True + for i in range(len(segms)): + segm = segms[i] + gt_segm = [] + if 'is_crowd' in samples and is_crowd[i]: + gt_segm.append([[0, 0]]) + else: + for poly in segm: + if len(poly) == 0: + valid = False + break + gt_segm.append(np.array(poly).reshape(-1, 2)) + if (not valid) or len(gt_segm) == 0: + break + gt_masks.append(gt_segm) + return gt_masks + + +def batch_arrange(batch_samples, fields): + def im_shape(samples, dim=3): + # hard code + assert 'h' in samples + assert 'w' in samples + if dim == 3: # RCNN, .. + return np.array((samples['h'], samples['w'], 1), dtype=np.float32) + else: # YOLOv3, .. + return np.array((samples['h'], samples['w']), dtype=np.int32) + + arrange_batch = [] + for samples in batch_samples: + one_ins = () + for i, field in enumerate(fields): + if field == 'gt_mask': + one_ins += (_segm(samples), ) + elif field == 'im_shape': + one_ins += (im_shape(samples), ) + elif field == 'im_size': + one_ins += (im_shape(samples, 2), ) + else: + if field == 'is_difficult': + field = 'difficult' + assert field in samples, '{} not in samples'.format(field) + one_ins += (samples[field], ) + arrange_batch.append(one_ins) + return arrange_batch + + +@register +@serializable +class Reader(object): + """ + Args: + dataset (DataSet): DataSet object + sample_transforms (list of BaseOperator): a list of sample transforms + operators. + batch_transforms (list of BaseOperator): a list of batch transforms + operators. + batch_size (int): batch size. + shuffle (bool): whether shuffle dataset or not. Default False. + drop_last (bool): whether drop last batch or not. Default False. + drop_empty (bool): whether drop sample when it's gt is empty or not. + Default True. + mixup_epoch (int): mixup epoc number. Default is -1, meaning + not use mixup. + cutmix_epoch (int): cutmix epoc number. Default is -1, meaning + not use cutmix. 
+ class_aware_sampling (bool): whether use class-aware sampling or not. + Default False. + worker_num (int): number of working threads/processes. + Default -1, meaning not use multi-threads/multi-processes. + use_process (bool): whether use multi-processes or not. + It only works when worker_num > 1. Default False. + bufsize (int): buffer size for multi-threads/multi-processes, + please note, one instance in buffer is one batch data. + memsize (str): size of shared memory used in result queue when + use_process is true. Default 3G. + inputs_def (dict): network input definition use to get input fields, + which is used to determine the order of returned data. + devices_num (int): number of devices. + num_trainers (int): number of trainers. Default 1. + """ + + def __init__(self, + dataset=None, + sample_transforms=None, + batch_transforms=None, + batch_size=1, + shuffle=False, + drop_last=False, + drop_empty=True, + mixup_epoch=-1, + cutmix_epoch=-1, + class_aware_sampling=False, + worker_num=-1, + use_process=False, + use_fine_grained_loss=False, + num_classes=80, + bufsize=-1, + memsize='3G', + inputs_def=None, + devices_num=1, + num_trainers=1): + self._dataset = dataset + self._roidbs = self._dataset.get_roidb() + self._fields = copy.deepcopy(inputs_def[ + 'fields']) if inputs_def else None + + # transform + self._sample_transforms = Compose(sample_transforms, + {'fields': self._fields}) + self._batch_transforms = None + + if use_fine_grained_loss: + for bt in batch_transforms: + if isinstance(bt, Gt2YoloTarget): + bt.num_classes = num_classes + elif batch_transforms: + batch_transforms = [ + bt for bt in batch_transforms + if not isinstance(bt, Gt2YoloTarget) + ] + + if batch_transforms: + self._batch_transforms = Compose(batch_transforms, + {'fields': self._fields}) + + # data + if inputs_def and inputs_def.get('multi_scale', False): + from ppdet.modeling.architectures.input_helper import multiscale_def + im_shape = inputs_def[ + 'image_shape'] if 'image_shape' 
in inputs_def else [ + 3, None, None + ] + _, ms_fields = multiscale_def(im_shape, inputs_def['num_scales'], + inputs_def['use_flip']) + self._fields += ms_fields + self._batch_size = batch_size + self._shuffle = shuffle + self._drop_last = drop_last + self._drop_empty = drop_empty + + # sampling + self._mixup_epoch = mixup_epoch // num_trainers + self._cutmix_epoch = cutmix_epoch // num_trainers + self._class_aware_sampling = class_aware_sampling + + self._load_img = False + self._sample_num = len(self._roidbs) + + if self._class_aware_sampling: + self.img_weights = _calc_img_weights(self._roidbs) + self._indexes = None + + self._pos = -1 + self._epoch = -1 + + self._curr_iter = 0 + + # multi-process + self._worker_num = worker_num + self._parallel = None + if self._worker_num > -1: + task = functools.partial(self.worker, self._drop_empty) + bufsize = devices_num * 2 if bufsize == -1 else bufsize + self._parallel = ParallelMap(self, task, worker_num, bufsize, + use_process, memsize) + + def __call__(self): + if self._worker_num > -1: + return self._parallel + else: + return self + + def __iter__(self): + return self + + def reset(self): + """implementation of Dataset.reset + """ + if self._epoch < 0: + self._epoch = 0 + else: + self._epoch += 1 + + self.indexes = [i for i in range(self.size())] + if self._class_aware_sampling: + self.indexes = np.random.choice( + self._sample_num, + self._sample_num, + replace=True, + p=self.img_weights) + + if self._shuffle: + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0)) + np.random.seed(self._epoch + trainer_id) + np.random.shuffle(self.indexes) + + if self._mixup_epoch > 0 and len(self.indexes) < 2: + logger.debug("Disable mixup for dataset samples " + "less than 2 samples") + self._mixup_epoch = -1 + if self._cutmix_epoch > 0 and len(self.indexes) < 2: + logger.info("Disable cutmix for dataset samples " + "less than 2 samples") + self._cutmix_epoch = -1 + + self._pos = 0 + + def __next__(self): + return self.next() + + 
def next(self): + if self._epoch < 0: + self.reset() + if self.drained(): + raise StopIteration + batch = self._load_batch() + self._curr_iter += 1 + if self._drop_last and len(batch) < self._batch_size: + raise StopIteration + if self._worker_num > -1: + return batch + else: + return self.worker(self._drop_empty, batch) + + def _load_batch(self): + batch = [] + bs = 0 + while bs != self._batch_size: + if self._pos >= self.size(): + break + pos = self.indexes[self._pos] + sample = copy.deepcopy(self._roidbs[pos]) + sample["curr_iter"] = self._curr_iter + self._pos += 1 + + if self._drop_empty and self._fields and 'gt_bbox' in sample: + if _has_empty(sample['gt_bbox']): + #logger.warn('gt_bbox {} is empty or not valid in {}, ' + # 'drop this sample'.format( + # sample['im_file'], sample['gt_bbox'])) + continue + has_mask = 'gt_mask' in self._fields or 'gt_segm' in self._fields + if self._drop_empty and self._fields and has_mask: + if _has_empty(_segm(sample)): + #logger.warn('gt_mask is empty or not valid in {}'.format( + # sample['im_file'])) + continue + + if self._load_img: + sample['image'] = self._load_image(sample['im_file']) + + if self._epoch < self._mixup_epoch: + num = len(self.indexes) + mix_idx = np.random.randint(1, num) + mix_idx = self.indexes[(mix_idx + self._pos - 1) % num] + sample['mixup'] = copy.deepcopy(self._roidbs[mix_idx]) + sample['mixup']["curr_iter"] = self._curr_iter + if self._load_img: + sample['mixup']['image'] = self._load_image(sample['mixup'][ + 'im_file']) + if self._epoch < self._cutmix_epoch: + num = len(self.indexes) + mix_idx = np.random.randint(1, num) + sample['cutmix'] = copy.deepcopy(self._roidbs[mix_idx]) + sample['cutmix']["curr_iter"] = self._curr_iter + if self._load_img: + sample['cutmix']['image'] = self._load_image(sample[ + 'cutmix']['im_file']) + + batch.append(sample) + bs += 1 + return batch + + def worker(self, drop_empty=True, batch_samples=None): + """ + sample transform and batch transform. 
+ """ + batch = [] + for sample in batch_samples: + sample = self._sample_transforms(sample) + if drop_empty and 'gt_bbox' in sample: + if _has_empty(sample['gt_bbox']): + #logger.warn('gt_bbox {} is empty or not valid in {}, ' + # 'drop this sample'.format( + # sample['im_file'], sample['gt_bbox'])) + continue + batch.append(sample) + if len(batch) > 0 and self._batch_transforms: + batch = self._batch_transforms(batch) + if len(batch) > 0 and self._fields: + batch = batch_arrange(batch, self._fields) + return batch + + def _load_image(self, filename): + with open(filename, 'rb') as f: + return f.read() + + def size(self): + """ implementation of Dataset.size + """ + return self._sample_num + + def drained(self): + """ implementation of Dataset.drained + """ + assert self._epoch >= 0, 'The first epoch has not begin!' + return self._pos >= self.size() + + def stop(self): + if self._parallel: + self._parallel.stop() + + +def create_reader(cfg, + max_iter=0, + global_cfg=None, + devices_num=1, + num_trainers=1): + """ + Return iterable data reader. + + Args: + max_iter (int): number of iterations. 
+ """ + if not isinstance(cfg, dict): + raise TypeError("The config should be a dict when creating reader.") + + # synchornize use_fine_grained_loss/num_classes from global_cfg to reader cfg + if global_cfg: + cfg['use_fine_grained_loss'] = getattr(global_cfg, + 'use_fine_grained_loss', False) + cfg['num_classes'] = getattr(global_cfg, 'num_classes', 80) + cfg['devices_num'] = devices_num + cfg['num_trainers'] = num_trainers + reader = Reader(**cfg)() + + def _reader(): + n = 0 + while True: + for _batch in reader: + if len(_batch) > 0: + yield _batch + n += 1 + if max_iter > 0 and n == max_iter: + return + reader.reset() + if max_iter <= 0: + return + + return _reader diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/shared_queue/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/data/shared_queue/__init__.py new file mode 100644 index 000000000..f118eb76a --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/shared_queue/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +__all__ = ['SharedBuffer', 'SharedMemoryMgr', 'SharedQueue'] + +from .sharedmemory import SharedBuffer +from .sharedmemory import SharedMemoryMgr +from .sharedmemory import SharedMemoryError +from .queue import SharedQueue diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/shared_queue/queue.py b/VisualFL/depends/PaddleDetection/ppdet/data/shared_queue/queue.py new file mode 100644 index 000000000..8f0ba8ab4 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/shared_queue/queue.py @@ -0,0 +1,106 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import sys +import six +if six.PY3: + import pickle + from io import BytesIO as StringIO + from queue import Empty +else: + import cPickle as pickle + from cStringIO import StringIO + from Queue import Empty + +import logging +import traceback +import multiprocessing as mp +from multiprocessing.queues import Queue +from .sharedmemory import SharedMemoryMgr + +logger = logging.getLogger(__name__) + + +class SharedQueueError(ValueError): + """ SharedQueueError + """ + pass + + +class SharedQueue(Queue): + """ a Queue based on shared memory to communicate data between Process, + and it's interface is compatible with 'multiprocessing.queues.Queue' + """ + + def __init__(self, maxsize=0, mem_mgr=None, memsize=None, pagesize=None): + """ init + """ + if six.PY3: + super(SharedQueue, self).__init__(maxsize, ctx=mp.get_context()) + else: + super(SharedQueue, self).__init__(maxsize) + + if mem_mgr is not None: + self._shared_mem = mem_mgr + else: + self._shared_mem = SharedMemoryMgr( + capacity=memsize, pagesize=pagesize) + + def put(self, obj, **kwargs): + """ put an object to this queue + """ + obj = pickle.dumps(obj, -1) + buff = None + try: + buff = self._shared_mem.malloc(len(obj)) + buff.put(obj) + super(SharedQueue, self).put(buff, **kwargs) + except Exception as e: + stack_info = traceback.format_exc() + err_msg = 'failed to put a element to SharedQueue '\ + 'with stack info[%s]' % (stack_info) + logger.warn(err_msg) + + if buff is not None: + buff.free() + raise e + + def get(self, **kwargs): + """ get an object from this queue + """ + buff = None + try: + buff = super(SharedQueue, self).get(**kwargs) + data = buff.get() + return pickle.load(StringIO(data)) + except Empty as e: + raise e + except Exception as e: + stack_info = traceback.format_exc() + err_msg = 'failed to get element from SharedQueue '\ + 
'with stack info[%s]' % (stack_info) + logger.warn(err_msg) + raise e + finally: + if buff is not None: + buff.free() + + def release(self): + self._shared_mem.release() + self._shared_mem = None diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/shared_queue/sharedmemory.py b/VisualFL/depends/PaddleDetection/ppdet/data/shared_queue/sharedmemory.py new file mode 100644 index 000000000..8b1d3ab40 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/shared_queue/sharedmemory.py @@ -0,0 +1,532 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# utils for memory management which is allocated on sharedmemory, +# note that these structures may not be thread-safe + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import time +import math +import struct +import sys +import six + +if six.PY3: + import pickle +else: + import cPickle as pickle + +import json +import uuid +import random +import numpy as np +import weakref +import logging +from multiprocessing import Lock +from multiprocessing import RawArray + +logger = logging.getLogger(__name__) + + +class SharedMemoryError(ValueError): + """ SharedMemoryError + """ + pass + + +class SharedBufferError(SharedMemoryError): + """ SharedBufferError + """ + pass + + +class MemoryFullError(SharedMemoryError): + """ MemoryFullError + """ + + def __init__(self, errmsg=''): + super(MemoryFullError, self).__init__() + self.errmsg = errmsg + + +def memcopy(dst, src, offset=0, length=None): + """ copy data from 'src' to 'dst' in bytes + """ + length = length if length is not None else len(src) + assert type(dst) == np.ndarray, 'invalid type for "dst" in memcopy' + if type(src) is not np.ndarray: + if type(src) is str and six.PY3: + src = src.encode() + src = np.frombuffer(src, dtype='uint8', count=len(src)) + + dst[:] = src[offset:offset + length] + + +class SharedBuffer(object): + """ Buffer allocated from SharedMemoryMgr, and it stores data on shared memory + + note that: + every instance of this should be freed explicitely by calling 'self.free' + """ + + def __init__(self, owner, capacity, pos, size=0, alloc_status=''): + """ Init + + Args: + owner (str): manager to own this buffer + capacity (int): capacity in bytes for this buffer + pos (int): page position in shared memory + size (int): bytes already used + alloc_status (str): debug info about allocator when allocate this + """ + self._owner = owner + self._cap = capacity + self._pos = pos + 
self._size = size + self._alloc_status = alloc_status + assert self._pos >= 0 and self._cap > 0, \ + "invalid params[%d:%d] to construct SharedBuffer" \ + % (self._pos, self._cap) + + def owner(self): + """ get owner + """ + return SharedMemoryMgr.get_mgr(self._owner) + + def put(self, data, override=False): + """ put data to this buffer + + Args: + data (str): data to be stored in this buffer + + Returns: + None + + Raises: + SharedMemoryError when not enough space in this buffer + """ + assert type(data) in [str, bytes], \ + 'invalid type[%s] for SharedBuffer::put' % (str(type(data))) + if self._size > 0 and not override: + raise SharedBufferError('already has already been setted before') + + if self.capacity() < len(data): + raise SharedBufferError('data[%d] is larger than size of buffer[%s]'\ + % (len(data), str(self))) + + self.owner().put_data(self, data) + self._size = len(data) + + def get(self, offset=0, size=None, no_copy=True): + """ get the data stored this buffer + + Args: + offset (int): position for the start point to 'get' + size (int): size to get + + Returns: + data (np.ndarray('uint8')): user's data in numpy + which is passed in by 'put' + None: if no data stored in + """ + offset = offset if offset >= 0 else self._size + offset + if self._size <= 0: + return None + + size = self._size if size is None else size + assert offset + size <= self._cap, 'invalid offset[%d] '\ + 'or size[%d] for capacity[%d]' % (offset, size, self._cap) + return self.owner().get_data(self, offset, size, no_copy=no_copy) + + def size(self): + """ bytes of used memory + """ + return self._size + + def resize(self, size): + """ resize the used memory to 'size', should not be greater than capacity + """ + assert size >= 0 and size <= self._cap, \ + "invalid size[%d] for resize" % (size) + + self._size = size + + def capacity(self): + """ size of allocated memory + """ + return self._cap + + def __str__(self): + """ human readable format + """ + return 
"SharedBuffer(owner:%s, pos:%d, size:%d, "\ + "capacity:%d, alloc_status:[%s], pid:%d)" \ + % (str(self._owner), self._pos, self._size, \ + self._cap, self._alloc_status, os.getpid()) + + def free(self): + """ free this buffer to it's owner + """ + if self._owner is not None: + self.owner().free(self) + self._owner = None + self._cap = 0 + self._pos = -1 + self._size = 0 + return True + else: + return False + + +class PageAllocator(object): + """ allocator used to malloc and free shared memory which + is split into pages + """ + s_allocator_header = 12 + + def __init__(self, base, total_pages, page_size): + """ init + """ + self._magic_num = 1234321000 + random.randint(100, 999) + self._base = base + self._total_pages = total_pages + self._page_size = page_size + + header_pages = int( + math.ceil((total_pages + self.s_allocator_header) / page_size)) + + self._header_pages = header_pages + self._free_pages = total_pages - header_pages + self._header_size = self._header_pages * page_size + self._reset() + + def _dump_alloc_info(self, fname): + hpages, tpages, pos, used = self.header() + + start = self.s_allocator_header + end = start + self._page_size * hpages + alloc_flags = self._base[start:end].tostring() + info = { + 'magic_num': self._magic_num, + 'header_pages': hpages, + 'total_pages': tpages, + 'pos': pos, + 'used': used + } + info['alloc_flags'] = alloc_flags + fname = fname + '.' 
+ str(uuid.uuid4())[:6] + with open(fname, 'wb') as f: + f.write(pickle.dumps(info, -1)) + logger.warn('dump alloc info to file[%s]' % (fname)) + + def _reset(self): + alloc_page_pos = self._header_pages + used_pages = self._header_pages + header_info = struct.pack( + str('III'), self._magic_num, alloc_page_pos, used_pages) + assert len(header_info) == self.s_allocator_header, \ + 'invalid size of header_info' + + memcopy(self._base[0:self.s_allocator_header], header_info) + self.set_page_status(0, self._header_pages, '1') + self.set_page_status(self._header_pages, self._free_pages, '0') + + def header(self): + """ get header info of this allocator + """ + header_str = self._base[0:self.s_allocator_header].tostring() + magic, pos, used = struct.unpack(str('III'), header_str) + + assert magic == self._magic_num, \ + 'invalid header magic[%d] in shared memory' % (magic) + return self._header_pages, self._total_pages, pos, used + + def empty(self): + """ are all allocatable pages available + """ + header_pages, pages, pos, used = self.header() + return header_pages == used + + def full(self): + """ are all allocatable pages used + """ + header_pages, pages, pos, used = self.header() + return header_pages + used == pages + + def __str__(self): + header_pages, pages, pos, used = self.header() + desc = '{page_info[magic:%d,total:%d,used:%d,header:%d,alloc_pos:%d,pagesize:%d]}' \ + % (self._magic_num, pages, used, header_pages, pos, self._page_size) + return 'PageAllocator:%s' % (desc) + + def set_alloc_info(self, alloc_pos, used_pages): + """ set allocating position to new value + """ + memcopy(self._base[4:12], struct.pack(str('II'), alloc_pos, used_pages)) + + def set_page_status(self, start, page_num, status): + """ set pages from 'start' to 'end' with new same status 'status' + """ + assert status in ['0', '1'], 'invalid status[%s] for page status '\ + 'in allocator[%s]' % (status, str(self)) + start += self.s_allocator_header + end = start + page_num + assert start 
>= 0 and end <= self._header_size, 'invalid end[%d] of pages '\ + 'in allocator[%s]' % (end, str(self)) + memcopy(self._base[start:end], str(status * page_num)) + + def get_page_status(self, start, page_num, ret_flag=False): + start += self.s_allocator_header + end = start + page_num + assert start >= 0 and end <= self._header_size, 'invalid end[%d] of pages '\ + 'in allocator[%s]' % (end, str(self)) + status = self._base[start:end].tostring().decode() + if ret_flag: + return status + + zero_num = status.count('0') + if zero_num == 0: + return (page_num, 1) + else: + return (zero_num, 0) + + def malloc_page(self, page_num): + header_pages, pages, pos, used = self.header() + end = pos + page_num + if end > pages: + pos = self._header_pages + end = pos + page_num + + start_pos = pos + flags = '' + while True: + flags = self.get_page_status(pos, page_num, ret_flag=True) + + if flags.count('0') == page_num: + break + + # not found enough pages, so shift to next few pages + free_pos = flags.rfind('1') + 1 + pos += free_pos + end = pos + page_num + if end > pages: + pos = self._header_pages + end = pos + page_num + flags = '' + + # not found available pages after scan all pages + if pos <= start_pos and end >= start_pos: + logger.debug('not found available pages after scan all pages') + break + + page_status = (flags.count('0'), 0) + if page_status != (page_num, 0): + free_pages = self._total_pages - used + if free_pages == 0: + err_msg = 'all pages have been used:%s' % (str(self)) + else: + err_msg = 'not found enough pages[avail:%d, expect:%d] '\ + 'with total free pages[%d]' % (page_status[0], page_num, free_pages) + err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] '\ + 'and allocator status[%s]' % (page_num, pos, err_msg, str(self)) + raise MemoryFullError(err_msg) + + self.set_page_status(pos, page_num, '1') + used += page_num + self.set_alloc_info(end, used) + return pos + + def free_page(self, start, page_num): + """ free 'page_num' pages start from 
'start' + """ + page_status = self.get_page_status(start, page_num) + assert page_status == (page_num, 1), \ + 'invalid status[%s] when free [%d, %d]' \ + % (str(page_status), start, page_num) + self.set_page_status(start, page_num, '0') + _, _, pos, used = self.header() + used -= page_num + self.set_alloc_info(pos, used) + + +DEFAULT_SHARED_MEMORY_SIZE = 1024 * 1024 * 1024 + + +class SharedMemoryMgr(object): + """ manage a continouse block of memory, provide + 'malloc' to allocate new buffer, and 'free' to free buffer + """ + s_memory_mgrs = weakref.WeakValueDictionary() + s_mgr_num = 0 + s_log_statis = False + + @classmethod + def get_mgr(cls, id): + """ get a SharedMemoryMgr with size of 'capacity' + """ + assert id in cls.s_memory_mgrs, 'invalid id[%s] for memory managers' % ( + id) + return cls.s_memory_mgrs[id] + + def __init__(self, capacity=None, pagesize=None): + """ init + """ + logger.debug('create SharedMemoryMgr') + + pagesize = 64 * 1024 if pagesize is None else pagesize + assert type(pagesize) is int, "invalid type of pagesize[%s]" \ + % (str(pagesize)) + + capacity = DEFAULT_SHARED_MEMORY_SIZE if capacity is None else capacity + assert type(capacity) is int, "invalid type of capacity[%s]" \ + % (str(capacity)) + + assert capacity > 0, '"size of shared memory should be greater than 0' + self._released = False + self._cap = capacity + self._page_size = pagesize + + assert self._cap % self._page_size == 0, \ + "capacity[%d] and pagesize[%d] are not consistent" \ + % (self._cap, self._page_size) + self._total_pages = self._cap // self._page_size + + self._pid = os.getpid() + SharedMemoryMgr.s_mgr_num += 1 + self._id = self._pid * 100 + SharedMemoryMgr.s_mgr_num + SharedMemoryMgr.s_memory_mgrs[self._id] = self + self._locker = Lock() + self._setup() + + def _setup(self): + self._shared_mem = RawArray('c', self._cap) + self._base = np.frombuffer( + self._shared_mem, dtype='uint8', count=self._cap) + self._locker.acquire() + try: + self._allocator = 
PageAllocator(self._base, self._total_pages, + self._page_size) + finally: + self._locker.release() + + def malloc(self, size, wait=True): + """ malloc a new SharedBuffer + + Args: + size (int): buffer size to be malloc + wait (bool): whether to wait when no enough memory + + Returns: + SharedBuffer + + Raises: + SharedMemoryError when not found available memory + """ + page_num = int(math.ceil(size / self._page_size)) + size = page_num * self._page_size + + start = None + ct = 0 + errmsg = '' + while True: + self._locker.acquire() + try: + start = self._allocator.malloc_page(page_num) + alloc_status = str(self._allocator) + except MemoryFullError as e: + start = None + errmsg = e.errmsg + if not wait: + raise e + finally: + self._locker.release() + + if start is None: + time.sleep(0.1) + if ct % 100 == 0: + logger.warn('not enough space for reason[%s]' % (errmsg)) + + ct += 1 + else: + break + + return SharedBuffer(self._id, size, start, alloc_status=alloc_status) + + def free(self, shared_buf): + """ free a SharedBuffer + + Args: + shared_buf (SharedBuffer): buffer to be freed + + Returns: + None + + Raises: + SharedMemoryError when failed to release this buffer + """ + assert shared_buf._owner == self._id, "invalid shared_buf[%s] "\ + "for it's not allocated from me[%s]" % (str(shared_buf), str(self)) + cap = shared_buf.capacity() + start_page = shared_buf._pos + page_num = cap // self._page_size + + #maybe we don't need this lock here + self._locker.acquire() + try: + self._allocator.free_page(start_page, page_num) + finally: + self._locker.release() + + def put_data(self, shared_buf, data): + """ fill 'data' into 'shared_buf' + """ + assert len(data) <= shared_buf.capacity(), 'too large data[%d] '\ + 'for this buffer[%s]' % (len(data), str(shared_buf)) + start = shared_buf._pos * self._page_size + end = start + len(data) + assert start >= 0 and end <= self._cap, "invalid start "\ + "position[%d] when put data to buff:%s" % (start, str(shared_buf)) + 
self._base[start:end] = np.frombuffer(data, 'uint8', len(data)) + + def get_data(self, shared_buf, offset, size, no_copy=True): + """ extract 'data' from 'shared_buf' in range [offset, offset + size) + """ + start = shared_buf._pos * self._page_size + start += offset + if no_copy: + return self._base[start:start + size] + else: + return self._base[start:start + size].tostring() + + def __str__(self): + return 'SharedMemoryMgr:{id:%d, %s}' % (self._id, str(self._allocator)) + + def __del__(self): + if SharedMemoryMgr.s_log_statis: + logger.info('destroy [%s]' % (self)) + + if not self._released and not self._allocator.empty(): + logger.debug('not empty when delete this SharedMemoryMgr[%s]' % + (self)) + else: + self._released = True + + if self._id in SharedMemoryMgr.s_memory_mgrs: + del SharedMemoryMgr.s_memory_mgrs[self._id] + SharedMemoryMgr.s_mgr_num -= 1 diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/source/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/data/source/__init__.py new file mode 100644 index 000000000..c5c26a16f --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/source/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import coco +from . import voc +from . 
import widerface + +from .coco import * +from .voc import * +from .widerface import * diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/source/coco.py b/VisualFL/depends/PaddleDetection/ppdet/data/source/coco.py new file mode 100644 index 000000000..67c561786 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/source/coco.py @@ -0,0 +1,186 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np + +from .dataset import DataSet +from ppdet.core.workspace import register, serializable + +import logging +logger = logging.getLogger(__name__) + + +@register +@serializable +class COCODataSet(DataSet): + """ + Load COCO records with annotations in json file 'anno_path' + + Args: + dataset_dir (str): root directory for dataset. + image_dir (str): directory for images. + anno_path (str): json file path. + sample_num (int): number of samples to load, -1 means all. + with_background (bool): whether load background as a class. + if True, total class number will be 81. default True. 
+ """ + + def __init__(self, + image_dir=None, + anno_path=None, + dataset_dir=None, + sample_num=-1, + with_background=True, + load_semantic=False): + super(COCODataSet, self).__init__( + image_dir=image_dir, + anno_path=anno_path, + dataset_dir=dataset_dir, + sample_num=sample_num, + with_background=with_background) + self.anno_path = anno_path + self.sample_num = sample_num + self.with_background = with_background + # `roidbs` is list of dict whose structure is: + # { + # 'im_file': im_fname, # image file name + # 'im_id': img_id, # image id + # 'h': im_h, # height of image + # 'w': im_w, # width + # 'is_crowd': is_crowd, + # 'gt_score': gt_score, + # 'gt_class': gt_class, + # 'gt_bbox': gt_bbox, + # 'gt_poly': gt_poly, + # } + self.roidbs = None + # a dict used to map category name to class id + self.cname2cid = None + self.load_image_only = False + self.load_semantic = load_semantic + + def load_roidb_and_cname2cid(self): + anno_path = os.path.join(self.dataset_dir, self.anno_path) + image_dir = os.path.join(self.dataset_dir, self.image_dir) + + assert anno_path.endswith('.json'), \ + 'invalid coco annotation file: ' + anno_path + from pycocotools.coco import COCO + coco = COCO(anno_path) + img_ids = coco.getImgIds() + cat_ids = coco.getCatIds() + records = [] + ct = 0 + + # when with_background = True, mapping category to classid, like: + # background:0, first_class:1, second_class:2, ... 
+ catid2clsid = dict({ + catid: i + int(self.with_background) + for i, catid in enumerate(cat_ids) + }) + cname2cid = dict({ + coco.loadCats(catid)[0]['name']: clsid + for catid, clsid in catid2clsid.items() + }) + + if 'annotations' not in coco.dataset: + self.load_image_only = True + logger.warn('Annotation file: {} does not contains ground truth ' + 'and load image information only.'.format(anno_path)) + + for img_id in img_ids: + img_anno = coco.loadImgs(img_id)[0] + im_fname = img_anno['file_name'] + im_w = float(img_anno['width']) + im_h = float(img_anno['height']) + + im_path = os.path.join(image_dir, + im_fname) if image_dir else im_fname + if not os.path.exists(im_path): + logger.warn('Illegal image file: {}, and it will be ' + 'ignored'.format(im_path)) + continue + + if im_w < 0 or im_h < 0: + logger.warn('Illegal width: {} or height: {} in annotation, ' + 'and im_id: {} will be ignored'.format(im_w, im_h, + img_id)) + continue + + coco_rec = { + 'im_file': im_path, + 'im_id': np.array([img_id]), + 'h': im_h, + 'w': im_w, + } + + if not self.load_image_only: + ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False) + instances = coco.loadAnns(ins_anno_ids) + + bboxes = [] + for inst in instances: + x, y, box_w, box_h = inst['bbox'] + x1 = max(0, x) + y1 = max(0, y) + x2 = min(im_w - 1, x1 + max(0, box_w - 1)) + y2 = min(im_h - 1, y1 + max(0, box_h - 1)) + if x2 >= x1 and y2 >= y1: + inst['clean_bbox'] = [x1, y1, x2, y2] + bboxes.append(inst) + else: + logger.warn( + 'Found an invalid bbox in annotations: im_id: {}, ' + 'x1: {}, y1: {}, x2: {}, y2: {}.'.format( + img_id, x1, y1, x2, y2)) + num_bbox = len(bboxes) + + gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32) + gt_class = np.zeros((num_bbox, 1), dtype=np.int32) + gt_score = np.ones((num_bbox, 1), dtype=np.float32) + is_crowd = np.zeros((num_bbox, 1), dtype=np.int32) + difficult = np.zeros((num_bbox, 1), dtype=np.int32) + gt_poly = [None] * num_bbox + + for i, box in enumerate(bboxes): + catid 
= box['category_id'] + gt_class[i][0] = catid2clsid[catid] + gt_bbox[i, :] = box['clean_bbox'] + is_crowd[i][0] = box['iscrowd'] + if 'segmentation' in box: + gt_poly[i] = box['segmentation'] + + coco_rec.update({ + 'is_crowd': is_crowd, + 'gt_class': gt_class, + 'gt_bbox': gt_bbox, + 'gt_score': gt_score, + 'gt_poly': gt_poly, + }) + + if self.load_semantic: + seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps', + 'train2017', im_fname[:-3] + 'png') + coco_rec.update({'semantic': seg_path}) + + logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format( + im_path, img_id, im_h, im_w)) + records.append(coco_rec) + ct += 1 + if self.sample_num > 0 and ct >= self.sample_num: + break + assert len(records) > 0, 'not found any coco record in %s' % (anno_path) + logger.debug('{} samples in file {}'.format(ct, anno_path)) + self.roidbs, self.cname2cid = records, cname2cid diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/source/dataset.py b/VisualFL/depends/PaddleDetection/ppdet/data/source/dataset.py new file mode 100644 index 000000000..7cddaa93a --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/source/dataset.py @@ -0,0 +1,164 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# --- reconstructed from mangled diff hunk:
#     new file b/VisualFL/depends/PaddleDetection/ppdet/data/source/dataset.py ---

import os
import numpy as np

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

from ppdet.core.workspace import register, serializable
from ppdet.utils.download import get_dataset_path


@serializable
class DataSet(object):
    """
    Dataset, e.g., coco, pascal voc

    Args:
        annotation (str): annotation file path
        image_dir (str): directory where image files are stored
        shuffle (bool): shuffle samples
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 sample_num=-1,
                 with_background=True,
                 use_default_label=False,
                 **kwargs):
        super(DataSet, self).__init__()
        self.anno_path = anno_path
        self.image_dir = image_dir if image_dir is not None else ''
        self.dataset_dir = dataset_dir if dataset_dir is not None else ''
        self.sample_num = sample_num
        self.with_background = with_background
        self.use_default_label = use_default_label

        self.cname2cid = None
        self._imid2path = None
        # FIX: get_roidb() tests `self.roidbs` before any loader has run;
        # without a default the attribute access raises AttributeError for
        # subclasses that forget to set it.  Subclasses overwrite it anyway,
        # so a None default is backward compatible.
        self.roidbs = None

    def load_roidb_and_cname2cid(self):
        """Load the dataset; subclasses must populate roidbs and cname2cid."""
        raise NotImplementedError('%s.load_roidb_and_cname2cid not available' %
                                  (self.__class__.__name__))

    def get_roidb(self):
        # Lazily resolve the dataset directory and load annotations on
        # first access.
        if not self.roidbs:
            data_dir = get_dataset_path(self.dataset_dir, self.anno_path,
                                        self.image_dir)
            if data_dir:
                self.dataset_dir = data_dir
            self.load_roidb_and_cname2cid()

        return self.roidbs

    def get_cname2cid(self):
        # Category-name -> class-id mapping, loaded on demand.
        if not self.cname2cid:
            self.load_roidb_and_cname2cid()
        return self.cname2cid

    def get_anno(self):
        # Full path of the annotation file, or None when not configured.
        if self.anno_path is None:
            return
        return os.path.join(self.dataset_dir, self.anno_path)

    def get_imid2path(self):
        return self._imid2path


def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
    # Case-insensitive image-extension check.
    return f.lower().endswith(extensions)


def _make_dataset(data_dir):
    """Collect image file paths under *data_dir* via a sorted recursive walk."""
    data_dir = os.path.expanduser(data_dir)
    if not os.path.isdir(data_dir):
        # FIX: the original `raise ('{} should be a dir'...)` raised a plain
        # string, which is itself a TypeError at runtime; raise a real
        # exception type with the same message instead.
        raise ValueError('{} should be a dir'.format(data_dir))
    images = []
    for root, _, fnames in sorted(os.walk(data_dir, followlinks=True)):
        for fname in sorted(fnames):
            file_path = os.path.join(root, fname)
            if _is_valid_file(file_path):
                images.append(file_path)
    return images


@register
@serializable
class ImageFolder(DataSet):
    """
    Args:
        dataset_dir (str): root directory for dataset.
        image_dir(list|str): list of image folders or list of image files
        anno_path (str): annotation file path.
        samples (int): number of samples to load, -1 means all
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 sample_num=-1,
                 with_background=True,
                 use_default_label=False,
                 **kwargs):
        super(ImageFolder, self).__init__(dataset_dir, image_dir, anno_path,
                                          sample_num, with_background,
                                          use_default_label)
        self.roidbs = None
        self._imid2path = {}

    def get_roidb(self):
        if not self.roidbs:
            self.roidbs = self._load_images()
        return self.roidbs

    def set_images(self, images):
        # Replace the image list and eagerly rebuild the records.
        self.image_dir = images
        self.roidbs = self._load_images()

    def _parse(self):
        """Expand self.image_dir (dir(s) or file(s)) into a flat file list."""
        image_dir = self.image_dir
        if not isinstance(image_dir, Sequence):
            image_dir = [image_dir]
        images = []
        for im_dir in image_dir:
            if os.path.isdir(im_dir):
                # NOTE(review): dataset_dir is joined only *after* the isdir
                # check on the raw path — presumably relative dirs are
                # expected to already resolve from the CWD; confirm intent.
                im_dir = os.path.join(self.dataset_dir, im_dir)
                images.extend(_make_dataset(im_dir))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
        """Build roidb records ({'im_id', 'im_file'}) capped at sample_num."""
        images = self._parse()
        ct = 0
        records = []
        for image in images:
            assert image != '' and os.path.isfile(image), \
                "Image {} not found".format(image)
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            rec = {'im_id': np.array([ct]), 'im_file': image}
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

# --- next diff hunk header (reconstructed):
# diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/source/voc.py
#          b/VisualFL/depends/PaddleDetection/ppdet/data/source/voc.py
# new file mode 100644  index 000000000..84c5990c3  ---
# /dev/null  +++ b/VisualFL/depends/PaddleDetection/ppdet/data/source/voc.py
# (new file, reconstructed from mangled diff hunk)

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np

import xml.etree.ElementTree as ET

from ppdet.core.workspace import register, serializable

from .dataset import DataSet
import logging
logger = logging.getLogger(__name__)


@register
@serializable
class VOCDataSet(DataSet):
    """
    Load dataset with PascalVOC format.

    Notes:
        `anno_path` must contains xml file and image file path for annotations.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): voc annotation file path.
        sample_num (int): number of samples to load, -1 means all.
        use_default_label (bool): whether use the default mapping of
            label to integer index. Default True.
        with_background (bool): whether load background as a class,
            default True.
        label_list (str): if use_default_label is False, will load
            mapping between category and class index.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 sample_num=-1,
                 use_default_label=False,
                 with_background=True,
                 label_list='label_list.txt'):
        super(VOCDataSet, self).__init__(
            image_dir=image_dir,
            anno_path=anno_path,
            sample_num=sample_num,
            dataset_dir=dataset_dir,
            with_background=with_background)
        # roidbs is list of dict whose structure is:
        # {
        #     'im_file': im_fname, # image file name
        #     'im_id': im_id, # image id
        #     'h': im_h, # height of image
        #     'w': im_w, # width
        #     'is_crowd': is_crowd,
        #     'gt_class': gt_class,
        #     'gt_score': gt_score,
        #     'gt_bbox': gt_bbox,
        #     'difficult': difficult
        # }
        self.roidbs = None
        # 'cname2id' is a dict to map category name to class id
        self.cname2cid = None
        self.use_default_label = use_default_label
        self.label_list = label_list

    def load_roidb_and_cname2cid(self):
        """Parse the VOC index file + per-image XML into roidbs/cname2cid."""
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        # mapping category name to class id
        # if with_background is True:
        #     background:0, first_class:1, second_class:2, ...
        # if with_background is False:
        #     first_class:0, second_class:1, ...
        records = []
        ct = 0
        cname2cid = {}
        if not self.use_default_label:
            label_path = os.path.join(self.dataset_dir, self.label_list)
            if not os.path.exists(label_path):
                raise ValueError("label_list {} does not exists".format(
                    label_path))
            with open(label_path, 'r') as fr:
                # class ids start at 1 when a background class is reserved
                label_id = int(self.with_background)
                for line in fr.readlines():
                    cname2cid[line.strip()] = label_id
                    label_id += 1
        else:
            cname2cid = pascalvoc_label(self.with_background)

        with open(anno_path, 'r') as fr:
            while True:
                line = fr.readline()
                if not line:
                    break
                # each index line: "<image path> <xml path>"
                img_file, xml_file = [os.path.join(image_dir, x) \
                        for x in line.strip().split()[:2]]
                if not os.path.exists(img_file):
                    # FIX: Logger.warn is a deprecated alias of warning
                    logger.warning(
                        'Illegal image file: {}, and it will be ignored'.format(
                            img_file))
                    continue
                if not os.path.isfile(xml_file):
                    logger.warning('Illegal xml file: {}, and it will be ignored'.
                                   format(xml_file))
                    continue
                tree = ET.parse(xml_file)
                if tree.find('id') is None:
                    im_id = np.array([ct])
                else:
                    im_id = np.array([int(tree.find('id').text)])

                objs = tree.findall('object')
                im_w = float(tree.find('size').find('width').text)
                im_h = float(tree.find('size').find('height').text)
                if im_w < 0 or im_h < 0:
                    logger.warning(
                        'Illegal width: {} or height: {} in annotation, '
                        'and {} will be ignored'.format(im_w, im_h, xml_file))
                    continue
                gt_bbox = []
                gt_class = []
                gt_score = []
                is_crowd = []
                difficult = []
                for i, obj in enumerate(objs):
                    cname = obj.find('name').text
                    _difficult = int(obj.find('difficult').text)
                    x1 = float(obj.find('bndbox').find('xmin').text)
                    y1 = float(obj.find('bndbox').find('ymin').text)
                    x2 = float(obj.find('bndbox').find('xmax').text)
                    y2 = float(obj.find('bndbox').find('ymax').text)
                    # clamp boxes to the image extent
                    x1 = max(0, x1)
                    y1 = max(0, y1)
                    x2 = min(im_w - 1, x2)
                    y2 = min(im_h - 1, y2)
                    if x2 > x1 and y2 > y1:
                        gt_bbox.append([x1, y1, x2, y2])
                        gt_class.append([cname2cid[cname]])
                        gt_score.append([1.])
                        is_crowd.append([0])
                        difficult.append([_difficult])
                    else:
                        logger.warning(
                            'Found an invalid bbox in annotations: xml_file: {}'
                            ', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                                xml_file, x1, y1, x2, y2))
                gt_bbox = np.array(gt_bbox).astype('float32')
                gt_class = np.array(gt_class).astype('int32')
                gt_score = np.array(gt_score).astype('float32')
                is_crowd = np.array(is_crowd).astype('int32')
                difficult = np.array(difficult).astype('int32')
                voc_rec = {
                    'im_file': img_file,
                    'im_id': im_id,
                    'h': im_h,
                    'w': im_w,
                    'is_crowd': is_crowd,
                    'gt_class': gt_class,
                    'gt_score': gt_score,
                    'gt_bbox': gt_bbox,
                    'difficult': difficult
                }
                # images with no objects are skipped, but still count toward
                # sample_num via ct below
                if len(objs) != 0:
                    records.append(voc_rec)

                ct += 1
                if self.sample_num > 0 and ct >= self.sample_num:
                    break
        assert len(records) > 0, 'not found any voc record in %s' % (
            self.anno_path)
        logger.debug('{} samples in file {}'.format(ct, anno_path))
        self.roidbs, self.cname2cid = records, cname2cid


def pascalvoc_label(with_background=True):
    """Return the fixed 20-class PascalVOC name->id map; ids start at 1
    when a background class is reserved, at 0 otherwise."""
    labels_map = {
        'aeroplane': 1,
        'bicycle': 2,
        'bird': 3,
        'boat': 4,
        'bottle': 5,
        'bus': 6,
        'car': 7,
        'cat': 8,
        'chair': 9,
        'cow': 10,
        'diningtable': 11,
        'dog': 12,
        'horse': 13,
        'motorbike': 14,
        'person': 15,
        'pottedplant': 16,
        'sheep': 17,
        'sofa': 18,
        'train': 19,
        'tvmonitor': 20
    }
    if not with_background:
        labels_map = {k: v - 1 for k, v in labels_map.items()}
    return labels_map

# --- next diff hunk header (reconstructed):
# diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/source/widerface.py
#          b/VisualFL/depends/PaddleDetection/ppdet/data/source/widerface.py
# new file mode 100644  index 000000000..75da05234  --- /dev/null
# +++ b/VisualFL/depends/PaddleDetection/ppdet/data/source/widerface.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +import logging +logger = logging.getLogger(__name__) + +from ppdet.core.workspace import register, serializable +from .dataset import DataSet + + +@register +@serializable +class WIDERFaceDataSet(DataSet): + """ + Load WiderFace records with 'anno_path' + + Args: + dataset_dir (str): root directory for dataset. + image_dir (str): directory for images. + anno_path (str): root directory for voc annotation data + sample_num (int): number of samples to load, -1 means all + with_background (bool): whether load background as a class. + if True, total class number will be 2. default True. 
+ """ + + def __init__(self, + dataset_dir=None, + image_dir=None, + anno_path=None, + sample_num=-1, + with_background=True, + with_lmk=False): + super(WIDERFaceDataSet, self).__init__( + image_dir=image_dir, + anno_path=anno_path, + sample_num=sample_num, + dataset_dir=dataset_dir, + with_background=with_background) + self.anno_path = anno_path + self.sample_num = sample_num + self.with_background = with_background + self.roidbs = None + self.cname2cid = None + self.with_lmk = with_lmk + + def load_roidb_and_cname2cid(self): + anno_path = os.path.join(self.dataset_dir, self.anno_path) + image_dir = os.path.join(self.dataset_dir, self.image_dir) + + txt_file = anno_path + + records = [] + ct = 0 + file_lists = self._load_file_list(txt_file) + cname2cid = widerface_label(self.with_background) + + for item in file_lists: + im_fname = item[0] + im_id = np.array([ct]) + gt_bbox = np.zeros((len(item) - 1, 4), dtype=np.float32) + gt_class = np.ones((len(item) - 1, 1), dtype=np.int32) + gt_lmk_labels = np.zeros((len(item) - 1, 10), dtype=np.float32) + lmk_ignore_flag = np.zeros((len(item) - 1, 1), dtype=np.int32) + for index_box in range(len(item)): + if index_box < 1: + continue + gt_bbox[index_box - 1] = item[index_box][0] + if self.with_lmk: + gt_lmk_labels[index_box - 1] = item[index_box][1] + lmk_ignore_flag[index_box - 1] = item[index_box][2] + im_fname = os.path.join(image_dir, + im_fname) if image_dir else im_fname + widerface_rec = { + 'im_file': im_fname, + 'im_id': im_id, + 'gt_bbox': gt_bbox, + 'gt_class': gt_class, + } + if self.with_lmk: + widerface_rec['gt_keypoint'] = gt_lmk_labels + widerface_rec['keypoint_ignore'] = lmk_ignore_flag + + if len(item) != 0: + records.append(widerface_rec) + + ct += 1 + if self.sample_num > 0 and ct >= self.sample_num: + break + assert len(records) > 0, 'not found any widerface in %s' % (anno_path) + logger.debug('{} samples in file {}'.format(ct, anno_path)) + self.roidbs, self.cname2cid = records, cname2cid + + def 
_load_file_list(self, input_txt): + with open(input_txt, 'r') as f_dir: + lines_input_txt = f_dir.readlines() + + file_dict = {} + num_class = 0 + exts = ['jpg', 'jpeg', 'png', 'bmp'] + exts += [ext.upper() for ext in exts] + for i in range(len(lines_input_txt)): + line_txt = lines_input_txt[i].strip('\n\t\r') + split_str = line_txt.split(' ') + if len(split_str) == 1: + img_file_name = os.path.split(split_str[0])[1] + split_txt = img_file_name.split('.') + if len(split_txt) < 2: + continue + elif split_txt[-1] in exts: + if i != 0: + num_class += 1 + file_dict[num_class] = [line_txt] + else: + if len(line_txt) <= 6: + continue + result_boxs = [] + xmin = float(split_str[0]) + ymin = float(split_str[1]) + w = float(split_str[2]) + h = float(split_str[3]) + # Filter out wrong labels + if w < 0 or h < 0: + logger.warn('Illegal box with w: {}, h: {} in ' + 'img: {}, and it will be ignored'.format( + w, h, file_dict[num_class][0])) + continue + xmin = max(0, xmin) + ymin = max(0, ymin) + xmax = xmin + w + ymax = ymin + h + gt_bbox = [xmin, ymin, xmax, ymax] + result_boxs.append(gt_bbox) + if self.with_lmk: + assert len(split_str) > 18, 'When `with_lmk=True`, the number' \ + 'of characters per line in the annotation file should' \ + 'exceed 18.' 
+ lmk0_x = float(split_str[5]) + lmk0_y = float(split_str[6]) + lmk1_x = float(split_str[8]) + lmk1_y = float(split_str[9]) + lmk2_x = float(split_str[11]) + lmk2_y = float(split_str[12]) + lmk3_x = float(split_str[14]) + lmk3_y = float(split_str[15]) + lmk4_x = float(split_str[17]) + lmk4_y = float(split_str[18]) + lmk_ignore_flag = 0 if lmk0_x == -1 else 1 + gt_lmk_label = [ + lmk0_x, lmk0_y, lmk1_x, lmk1_y, lmk2_x, lmk2_y, lmk3_x, + lmk3_y, lmk4_x, lmk4_y + ] + result_boxs.append(gt_lmk_label) + result_boxs.append(lmk_ignore_flag) + file_dict[num_class].append(result_boxs) + + return list(file_dict.values()) + + +def widerface_label(with_background=True): + labels_map = {'face': 1} + if not with_background: + labels_map = {k: v - 1 for k, v in labels_map.items()} + return labels_map diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/tests/test.yml b/VisualFL/depends/PaddleDetection/ppdet/data/tests/test.yml new file mode 100644 index 000000000..885d5ab1e --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/tests/test.yml @@ -0,0 +1,73 @@ +TrainReader: + inputs_def: + fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_mask'] + dataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + sample_num: 10 + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !RandomFlipImage + is_mask_flip: true + is_normalized: false + prob: 0.5 + - !NormalizeImage + is_channel_first: false + is_scale: true + mean: [0.485,0.456,0.406] + std: [0.229, 0.224,0.225] + - !ResizeImage + interp: 1 + max_size: 1333 + target_size: 800 + use_cv2: true + - !Permute + channel_first: true + to_bgr: false + batch_transforms: + - !PadBatch + pad_to_stride: 32 + use_padded_im_info: false + batch_size: 1 + shuffle: true + worker_num: 2 + drop_last: false + use_process: false + +EvalReader: + inputs_def: + fields: ['image', 'im_info', 'im_id'] + dataset: + !COCODataSet + 
image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + sample_num: 10 + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeImage + is_channel_first: false + is_scale: true + mean: [0.485,0.456,0.406] + std: [0.229, 0.224,0.225] + - !ResizeImage + interp: 1 + max_size: 1333 + target_size: 800 + use_cv2: true + - !Permute + channel_first: true + to_bgr: false + batch_transforms: + - !PadBatch + pad_to_stride: 32 + use_padded_im_info: true + batch_size: 1 + shuffle: false + drop_last: false diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/tests/test_dataset.py b/VisualFL/depends/PaddleDetection/ppdet/data/tests/test_dataset.py new file mode 100644 index 000000000..0ab84314f --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/tests/test_dataset.py @@ -0,0 +1,152 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import time +import unittest +import sys +import logging +import random +import copy +# add python path of PadleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +from ppdet.data.parallel_map import ParallelMap +from ppdet.utils.check import enable_static_mode + + +class MemorySource(object): + """ memory data source for testing + """ + + def __init__(self, samples): + self._epoch = -1 + + self._pos = -1 + self._drained = False + self._samples = samples + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + if self._epoch < 0: + self.reset() + + if self._pos >= self.size(): + self._drained = True + raise StopIteration("no more data in " + str(self)) + else: + sample = copy.deepcopy(self._samples[self._pos]) + self._pos += 1 + return sample + + def reset(self): + if self._epoch < 0: + self._epoch = 0 + else: + self._epoch += 1 + + self._pos = 0 + self._drained = False + random.shuffle(self._samples) + + def size(self): + return len(self._samples) + + def drained(self): + assert self._epoch >= 0, "the first epoch has not started yet" + return self._pos >= self.size() + + def epoch_id(self): + return self._epoch + + +class TestDataset(unittest.TestCase): + """Test cases for ppdet.data.dataset + """ + + @classmethod + def setUpClass(cls): + """ setup + """ + pass + + @classmethod + def tearDownClass(cls): + """ tearDownClass """ + pass + + def test_next(self): + """ test next + """ + samples = list(range(10)) + mem_sc = MemorySource(samples) + + for i, d in enumerate(mem_sc): + self.assertTrue(d in samples) + + def test_transform_with_abnormal_worker(self): + """ test dataset transform with abnormally exit process + """ + samples = list(range(20)) + mem_sc = MemorySource(samples) + + def _worker(sample): + if sample == 3: + sys.exit(1) + + return 2 * sample + + test_worker = ParallelMap( + mem_sc, 
_worker, worker_num=2, use_process=True, memsize='2M') + + ct = 0 + for i, d in enumerate(test_worker): + ct += 1 + self.assertTrue(d / 2 in samples) + + self.assertEqual(len(samples) - 1, ct) + + def test_transform_with_delay_worker(self): + """ test dataset transform with delayed process + """ + samples = list(range(20)) + mem_sc = MemorySource(samples) + + def _worker(sample): + if sample == 3: + time.sleep(30) + + return 2 * sample + + test_worker = ParallelMap( + mem_sc, _worker, worker_num=2, use_process=True, memsize='2M') + + ct = 0 + for i, d in enumerate(test_worker): + ct += 1 + self.assertTrue(d / 2 in samples) + + self.assertEqual(len(samples), ct) + + +if __name__ == '__main__': + enable_static_mode() + logging.basicConfig() + unittest.main() diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/tests/test_loader.py b/VisualFL/depends/PaddleDetection/ppdet/data/tests/test_loader.py new file mode 100644 index 000000000..020c5d0e2 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/tests/test_loader.py @@ -0,0 +1,173 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import unittest
import os
import sys
# add python path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

from ppdet.data.source.coco import COCODataSet
from ppdet.data.reader import Reader
from ppdet.utils.download import get_path
from ppdet.utils.download import DATASET_HOME

from ppdet.data.transform.operators import DecodeImage, ResizeImage, Permute
from ppdet.data.transform.batch_operators import PadBatch
from ppdet.utils.check import enable_static_mode

COCO_VAL_URL = 'http://images.cocodataset.org/zips/val2017.zip'
COCO_VAL_MD5SUM = '442b8da7639aecaf257c1dceb8ba8c80'
COCO_ANNO_URL = 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip'
COCO_ANNO_MD5SUM = 'f4bbac642086de4f52a3fdda2de5fa2c'


class TestReader(unittest.TestCase):
    """Integration tests: COCO val2017 through the Reader pipeline,
    checking per-sample tensor shapes for single- and multi-worker runs."""

    @classmethod
    def setUpClass(cls):
        """ setup
        """
        # download (or locate) COCO val images + annotations once per class
        root_path = os.path.join(DATASET_HOME, 'coco')
        _, _ = get_path(COCO_VAL_URL, root_path, COCO_VAL_MD5SUM)
        _, _ = get_path(COCO_ANNO_URL, root_path, COCO_ANNO_MD5SUM)
        cls.anno_path = 'annotations/instances_val2017.json'
        cls.image_dir = 'val2017'
        cls.root_path = root_path

    @classmethod
    def tearDownClass(cls):
        """ tearDownClass """
        pass

    def test_loader(self):
        coco_loader = COCODataSet(
            dataset_dir=self.root_path,
            image_dir=self.image_dir,
            anno_path=self.anno_path,
            sample_num=10)
        sample_trans = [
            DecodeImage(to_rgb=True), ResizeImage(
                target_size=800, max_size=1333, interp=1), Permute(to_bgr=False)
        ]
        batch_trans = [PadBatch(pad_to_stride=32, use_padded_im_info=True), ]

        inputs_def = {
            'fields': [
                'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd',
                'gt_mask'
            ],
        }
        data_loader = Reader(
            coco_loader,
            sample_transforms=sample_trans,
            batch_transforms=batch_trans,
            batch_size=2,
            shuffle=True,
            drop_empty=True,
            inputs_def=inputs_def)()
        for i in range(2):
            for samples in data_loader:
                for sample in samples:
                    # image: CHW, H/W padded to a multiple of 32
                    im_shape = sample[0].shape
                    self.assertEqual(im_shape[0], 3)
                    self.assertEqual(im_shape[1] % 32, 0)
                    self.assertEqual(im_shape[2] % 32, 0)

                    im_info_shape = sample[1].shape
                    self.assertEqual(im_info_shape[-1], 3)

                    im_id_shape = sample[2].shape
                    self.assertEqual(im_id_shape[-1], 1)

                    gt_bbox_shape = sample[3].shape
                    self.assertEqual(gt_bbox_shape[-1], 4)

                    # class / crowd arrays must align with the boxes
                    gt_class_shape = sample[4].shape
                    self.assertEqual(gt_class_shape[-1], 1)
                    self.assertEqual(gt_class_shape[0], gt_bbox_shape[0])

                    is_crowd_shape = sample[5].shape
                    self.assertEqual(is_crowd_shape[-1], 1)
                    self.assertEqual(is_crowd_shape[0], gt_bbox_shape[0])

                    mask = sample[6]
                    self.assertEqual(len(mask), gt_bbox_shape[0])
                    self.assertEqual(mask[0][0].shape[-1], 2)
            data_loader.reset()

    def test_loader_multi_threads(self):
        coco_loader = COCODataSet(
            dataset_dir=self.root_path,
            image_dir=self.image_dir,
            anno_path=self.anno_path,
            sample_num=10)
        sample_trans = [
            DecodeImage(to_rgb=True), ResizeImage(
                target_size=800, max_size=1333, interp=1), Permute(to_bgr=False)
        ]
        batch_trans = [PadBatch(pad_to_stride=32, use_padded_im_info=True), ]

        inputs_def = {
            'fields': [
                'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd',
                'gt_mask'
            ],
        }
        # same pipeline as test_loader, but exercised with two worker threads
        data_loader = Reader(
            coco_loader,
            sample_transforms=sample_trans,
            batch_transforms=batch_trans,
            batch_size=2,
            shuffle=True,
            drop_empty=True,
            worker_num=2,
            use_process=False,
            bufsize=8,
            inputs_def=inputs_def)()
        for i in range(2):
            for samples in data_loader:
                for sample in samples:
                    im_shape = sample[0].shape
                    self.assertEqual(im_shape[0], 3)
                    self.assertEqual(im_shape[1] % 32, 0)
                    self.assertEqual(im_shape[2] % 32, 0)

                    im_info_shape = sample[1].shape
                    self.assertEqual(im_info_shape[-1], 3)

                    im_id_shape = sample[2].shape
                    self.assertEqual(im_id_shape[-1], 1)

                    gt_bbox_shape = sample[3].shape
                    self.assertEqual(gt_bbox_shape[-1], 4)

                    gt_class_shape = sample[4].shape
                    self.assertEqual(gt_class_shape[-1], 1)
                    self.assertEqual(gt_class_shape[0], gt_bbox_shape[0])

                    is_crowd_shape = sample[5].shape
                    self.assertEqual(is_crowd_shape[-1], 1)
                    self.assertEqual(is_crowd_shape[0], gt_bbox_shape[0])

                    mask = sample[6]
                    self.assertEqual(len(mask), gt_bbox_shape[0])
                    self.assertEqual(mask[0][0].shape[-1], 2)
            data_loader.reset()


if __name__ == '__main__':
    enable_static_mode()
    unittest.main()

# --- next diff hunk header (reconstructed):
# diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/tests/test_loader_yaml.py
#          b/VisualFL/depends/PaddleDetection/ppdet/data/tests/test_loader_yaml.py
# new file mode 100644  index 000000000..a7c38dcd2  --- /dev/null
# +++ b/VisualFL/depends/PaddleDetection/ppdet/data/tests/test_loader_yaml.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import yaml
import logging
import sys
# add python path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

from ppdet.utils.download import get_path
from ppdet.utils.download import DATASET_HOME
from ppdet.core.workspace import load_config, merge_config

from ppdet.data.reader import create_reader
from ppdet.utils.check import enable_static_mode

COCO_VAL_URL = 'http://images.cocodataset.org/zips/val2017.zip'
COCO_VAL_MD5SUM = '442b8da7639aecaf257c1dceb8ba8c80'
COCO_ANNO_URL = 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip'
COCO_ANNO_MD5SUM = 'f4bbac642086de4f52a3fdda2de5fa2c'

FORMAT = '[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


class TestReaderYAML(unittest.TestCase):
    """Integration test: build Train/Eval readers from the YAML config
    (test.yml) and verify per-sample tensor shapes."""

    @classmethod
    def setUpClass(cls):
        """ setup
        """
        # download (or locate) COCO val images + annotations once per class
        root_path = os.path.join(DATASET_HOME, 'coco')
        _, _ = get_path(COCO_VAL_URL, root_path, COCO_VAL_MD5SUM)
        _, _ = get_path(COCO_ANNO_URL, root_path, COCO_ANNO_MD5SUM)
        cls.anno_path = 'annotations/instances_val2017.json'
        cls.image_dir = 'val2017'
        cls.root_path = root_path

    @classmethod
    def tearDownClass(cls):
        """ tearDownClass """
        pass

    def test_loader_yaml(self):
        cfg_file = 'ppdet/data/tests/test.yml'
        cfg = load_config(cfg_file)
        # point the config's dataset at the locally downloaded COCO data
        data_cfg = '[!COCODataSet {{image_dir: {0}, dataset_dir: {1}, ' \
                   'anno_path: {2}, sample_num: 10}}]'.format(
                       self.image_dir, self.root_path, self.anno_path)
        dataset_ins = yaml.load(data_cfg, Loader=yaml.Loader)
        update_train_cfg = {'TrainReader': {'dataset': dataset_ins[0]}}
        update_test_cfg = {'EvalReader': {'dataset': dataset_ins[0]}}
        merge_config(update_train_cfg)
        merge_config(update_test_cfg)

        reader = create_reader(cfg['TrainReader'], 10)()
        for samples in reader:
            for sample in samples:
                # image: CHW, H/W padded to a multiple of 32
                im_shape = sample[0].shape
                self.assertEqual(im_shape[0], 3)
                self.assertEqual(im_shape[1] % 32, 0)
                self.assertEqual(im_shape[2] % 32, 0)

                im_info_shape = sample[1].shape
                self.assertEqual(im_info_shape[-1], 3)

                im_id_shape = sample[2].shape
                self.assertEqual(im_id_shape[-1], 1)

                gt_bbox_shape = sample[3].shape
                self.assertEqual(gt_bbox_shape[-1], 4)

                # class / crowd arrays must align with the boxes
                gt_class_shape = sample[4].shape
                self.assertEqual(gt_class_shape[-1], 1)
                self.assertEqual(gt_class_shape[0], gt_bbox_shape[0])

                is_crowd_shape = sample[5].shape
                self.assertEqual(is_crowd_shape[-1], 1)
                self.assertEqual(is_crowd_shape[0], gt_bbox_shape[0])

                mask = sample[6]
                self.assertEqual(len(mask), gt_bbox_shape[0])
                self.assertEqual(mask[0][0].shape[-1], 2)

        reader = create_reader(cfg['EvalReader'], 10)()
        for samples in reader:
            for sample in samples:
                im_shape = sample[0].shape
                self.assertEqual(im_shape[0], 3)
                self.assertEqual(im_shape[1] % 32, 0)
                self.assertEqual(im_shape[2] % 32, 0)

                im_info_shape = sample[1].shape
                self.assertEqual(im_info_shape[-1], 3)

                im_id_shape = sample[2].shape
                self.assertEqual(im_id_shape[-1], 1)


if __name__ == '__main__':
    enable_static_mode()
    unittest.main()

# --- next diff hunk header (reconstructed):
# diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/transform/__init__.py
#          b/VisualFL/depends/PaddleDetection/ppdet/data/transform/__init__.py
# new file mode 100644  index 000000000..c5deb535a  --- /dev/null
# +++ b/VisualFL/depends/PaddleDetection/ppdet/data/transform/__init__.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Package init: re-export every registered transform operator so callers
# can do `from ppdet.data.transform import *`.
from . import operators
from . import batch_operators

from .operators import *
from .batch_operators import *

__all__ = []
__all__ += registered_ops

# --- next diff hunk header (reconstructed):
# diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/transform/autoaugment_utils.py
#          b/VisualFL/depends/PaddleDetection/ppdet/data/transform/autoaugment_utils.py
# new file mode 100644  index 000000000..0cd8a04ee  --- /dev/null
# +++ b/VisualFL/depends/PaddleDetection/ppdet/data/transform/autoaugment_utils.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# Reference: +# https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/autoaugment_utils.py +"""AutoAugment util file.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect +import math +from PIL import Image, ImageEnhance +import numpy as np +import os +import sys +import cv2 +from copy import deepcopy + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. + +# Represents an invalid bounding box that is used for checking for padding +# lists of bounding box coordinates for a few augmentation operations +_INVALID_BOX = [[-1.0, -1.0, -1.0, -1.0]] + + +def policy_v0(): + """Autoaugment policy that was used in AutoAugment Detection Paper.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. + policy = [ + [('TranslateX_BBox', 0.6, 4), ('Equalize', 0.8, 10)], + [('TranslateY_Only_BBoxes', 0.2, 2), ('Cutout', 0.8, 8)], + [('Sharpness', 0.0, 8), ('ShearX_BBox', 0.4, 0)], + [('ShearY_BBox', 1.0, 2), ('TranslateY_Only_BBoxes', 0.6, 6)], + [('Rotate_BBox', 0.6, 10), ('Color', 1.0, 6)], + ] + return policy + + +def policy_v1(): + """Autoaugment policy that was used in AutoAugment Detection Paper.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. 
+ policy = [ + [('TranslateX_BBox', 0.6, 4), ('Equalize', 0.8, 10)], + [('TranslateY_Only_BBoxes', 0.2, 2), ('Cutout', 0.8, 8)], + [('Sharpness', 0.0, 8), ('ShearX_BBox', 0.4, 0)], + [('ShearY_BBox', 1.0, 2), ('TranslateY_Only_BBoxes', 0.6, 6)], + [('Rotate_BBox', 0.6, 10), ('Color', 1.0, 6)], + [('Color', 0.0, 0), ('ShearX_Only_BBoxes', 0.8, 4)], + [('ShearY_Only_BBoxes', 0.8, 2), ('Flip_Only_BBoxes', 0.0, 10)], + [('Equalize', 0.6, 10), ('TranslateX_BBox', 0.2, 2)], + [('Color', 1.0, 10), ('TranslateY_Only_BBoxes', 0.4, 6)], + [('Rotate_BBox', 0.8, 10), ('Contrast', 0.0, 10)], # , + [('Cutout', 0.2, 2), ('Brightness', 0.8, 10)], + [('Color', 1.0, 6), ('Equalize', 1.0, 2)], + [('Cutout_Only_BBoxes', 0.4, 6), ('TranslateY_Only_BBoxes', 0.8, 2)], + [('Color', 0.2, 8), ('Rotate_BBox', 0.8, 10)], + [('Sharpness', 0.4, 4), ('TranslateY_Only_BBoxes', 0.0, 4)], + [('Sharpness', 1.0, 4), ('SolarizeAdd', 0.4, 4)], + [('Rotate_BBox', 1.0, 8), ('Sharpness', 0.2, 8)], + [('ShearY_BBox', 0.6, 10), ('Equalize_Only_BBoxes', 0.6, 8)], + [('ShearX_BBox', 0.2, 6), ('TranslateY_Only_BBoxes', 0.2, 10)], + [('SolarizeAdd', 0.6, 8), ('Brightness', 0.8, 10)], + ] + return policy + + +def policy_vtest(): + """Autoaugment test policy for debugging.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. + policy = [[('TranslateX_BBox', 1.0, 4), ('Equalize', 1.0, 10)], ] + return policy + + +def policy_v2(): + """Additional policy that performs well on object detection.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. 
+ policy = [ + [('Color', 0.0, 6), ('Cutout', 0.6, 8), ('Sharpness', 0.4, 8)], + [('Rotate_BBox', 0.4, 8), ('Sharpness', 0.4, 2), + ('Rotate_BBox', 0.8, 10)], + [('TranslateY_BBox', 1.0, 8), ('AutoContrast', 0.8, 2)], + [('AutoContrast', 0.4, 6), ('ShearX_BBox', 0.8, 8), + ('Brightness', 0.0, 10)], + [('SolarizeAdd', 0.2, 6), ('Contrast', 0.0, 10), + ('AutoContrast', 0.6, 0)], + [('Cutout', 0.2, 0), ('Solarize', 0.8, 8), ('Color', 1.0, 4)], + [('TranslateY_BBox', 0.0, 4), ('Equalize', 0.6, 8), + ('Solarize', 0.0, 10)], + [('TranslateY_BBox', 0.2, 2), ('ShearY_BBox', 0.8, 8), + ('Rotate_BBox', 0.8, 8)], + [('Cutout', 0.8, 8), ('Brightness', 0.8, 8), ('Cutout', 0.2, 2)], + [('Color', 0.8, 4), ('TranslateY_BBox', 1.0, 6), + ('Rotate_BBox', 0.6, 6)], + [('Rotate_BBox', 0.6, 10), ('BBox_Cutout', 1.0, 4), ('Cutout', 0.2, 8)], + [('Rotate_BBox', 0.0, 0), ('Equalize', 0.6, 6), + ('ShearY_BBox', 0.6, 8)], + [('Brightness', 0.8, 8), ('AutoContrast', 0.4, 2), + ('Brightness', 0.2, 2)], + [('TranslateY_BBox', 0.4, 8), ('Solarize', 0.4, 6), + ('SolarizeAdd', 0.2, 10)], + [('Contrast', 1.0, 10), ('SolarizeAdd', 0.2, 8), ('Equalize', 0.2, 4)], + ] + return policy + + +def policy_v3(): + """"Additional policy that performs well on object detection.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. 
+ policy = [ + [('Posterize', 0.8, 2), ('TranslateX_BBox', 1.0, 8)], + [('BBox_Cutout', 0.2, 10), ('Sharpness', 1.0, 8)], + [('Rotate_BBox', 0.6, 8), ('Rotate_BBox', 0.8, 10)], + [('Equalize', 0.8, 10), ('AutoContrast', 0.2, 10)], + [('SolarizeAdd', 0.2, 2), ('TranslateY_BBox', 0.2, 8)], + [('Sharpness', 0.0, 2), ('Color', 0.4, 8)], + [('Equalize', 1.0, 8), ('TranslateY_BBox', 1.0, 8)], + [('Posterize', 0.6, 2), ('Rotate_BBox', 0.0, 10)], + [('AutoContrast', 0.6, 0), ('Rotate_BBox', 1.0, 6)], + [('Equalize', 0.0, 4), ('Cutout', 0.8, 10)], + [('Brightness', 1.0, 2), ('TranslateY_BBox', 1.0, 6)], + [('Contrast', 0.0, 2), ('ShearY_BBox', 0.8, 0)], + [('AutoContrast', 0.8, 10), ('Contrast', 0.2, 10)], + [('Rotate_BBox', 1.0, 10), ('Cutout', 1.0, 10)], + [('SolarizeAdd', 0.8, 6), ('Equalize', 0.8, 8)], + ] + return policy + + +def _equal(val1, val2, eps=1e-8): + return abs(val1 - val2) <= eps + + +def blend(image1, image2, factor): + """Blend image1 and image2 using 'factor'. + + Factor can be above 0.0. A value of 0.0 means only image1 is used. + A value of 1.0 means only image2 is used. A value between 0.0 and + 1.0 means we linearly interpolate the pixel values between the two + images. A value greater than 1.0 "extrapolates" the difference + between the two pixel values, and we clip the results to values + between 0 and 255. + + Args: + image1: An image Tensor of type uint8. + image2: An image Tensor of type uint8. + factor: A floating point value above 0.0. + + Returns: + A blended image Tensor of type uint8. + """ + if factor == 0.0: + return image1 + if factor == 1.0: + return image2 + + image1 = image1.astype(np.float32) + image2 = image2.astype(np.float32) + + difference = image2 - image1 + scaled = factor * difference + + # Do addition in float. + temp = image1 + scaled + + # Interpolate + if factor > 0.0 and factor < 1.0: + # Interpolation means we always stay within 0 and 255. 
+ return temp.astype(np.uint8) + + # Extrapolate: + # + # We need to clip and then cast. + return np.clip(temp, a_min=0, a_max=255).astype(np.uint8) + + +def cutout(image, pad_size, replace=0): + """Apply cutout (https://arxiv.org/abs/1708.04552) to image. + + This operation applies a (2*pad_size x 2*pad_size) mask of zeros to + a random location within `img`. The pixel values filled in will be of the + value `replace`. The located where the mask will be applied is randomly + chosen uniformly over the whole image. + + Args: + image: An image Tensor of type uint8. + pad_size: Specifies how big the zero mask that will be generated is that + is applied to the image. The mask will be of size + (2*pad_size x 2*pad_size). + replace: What pixel value to fill in the image in the area that has + the cutout mask applied to it. + + Returns: + An image Tensor that is of type uint8. + Example: + img = cv2.imread( "/home/vis/gry/train/img_data/test.jpg", cv2.COLOR_BGR2RGB ) + new_img = cutout(img, pad_size=50, replace=0) + """ + image_height, image_width = image.shape[0], image.shape[1] + + cutout_center_height = np.random.randint(low=0, high=image_height) + cutout_center_width = np.random.randint(low=0, high=image_width) + + lower_pad = np.maximum(0, cutout_center_height - pad_size) + upper_pad = np.maximum(0, image_height - cutout_center_height - pad_size) + left_pad = np.maximum(0, cutout_center_width - pad_size) + right_pad = np.maximum(0, image_width - cutout_center_width - pad_size) + + cutout_shape = [ + image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad) + ] + padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] + mask = np.pad(np.zeros( + cutout_shape, dtype=image.dtype), + padding_dims, + 'constant', + constant_values=1) + mask = np.expand_dims(mask, -1) + mask = np.tile(mask, [1, 1, 3]) + image = np.where( + np.equal(mask, 0), + np.ones_like( + image, dtype=image.dtype) * replace, + image) + return image.astype(np.uint8) + + +def 
solarize(image, threshold=128): + # For each pixel in the image, select the pixel + # if the value is less than the threshold. + # Otherwise, subtract 255 from the pixel. + return np.where(image < threshold, image, 255 - image) + + +def solarize_add(image, addition=0, threshold=128): + # For each pixel in the image less than threshold + # we add 'addition' amount to it and then clip the + # pixel value to be between 0 and 255. The value + # of 'addition' is between -128 and 128. + added_image = image.astype(np.int64) + addition + added_image = np.clip(added_image, a_min=0, a_max=255).astype(np.uint8) + return np.where(image < threshold, added_image, image) + + +def color(image, factor): + """use cv2 to deal""" + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + degenerate = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR) + return blend(degenerate, image, factor) + + +# refer to https://github.com/4uiiurz1/pytorch-auto-augment/blob/024b2eac4140c38df8342f09998e307234cafc80/auto_augment.py#L197 +def contrast(img, factor): + img = ImageEnhance.Contrast(Image.fromarray(img)).enhance(factor) + return np.array(img) + + +def brightness(image, factor): + """Equivalent of PIL Brightness.""" + degenerate = np.zeros_like(image) + return blend(degenerate, image, factor) + + +def posterize(image, bits): + """Equivalent of PIL Posterize.""" + shift = 8 - bits + return np.left_shift(np.right_shift(image, shift), shift) + + +def rotate(image, degrees, replace): + """Rotates the image by degrees either clockwise or counterclockwise. + + Args: + image: An image Tensor of type uint8. + degrees: Float, a scalar angle in degrees to rotate all images by. If + degrees is positive the image will be rotated clockwise otherwise it will + be rotated counterclockwise. + replace: A one or three value 1D tensor to fill empty pixels caused by + the rotate operation. + + Returns: + The rotated version of image. 
+ """ + image = wrap(image) + image = Image.fromarray(image) + image = image.rotate(degrees) + image = np.array(image, dtype=np.uint8) + return unwrap(image, replace) + + +def random_shift_bbox(image, + bbox, + pixel_scaling, + replace, + new_min_bbox_coords=None): + """Move the bbox and the image content to a slightly new random location. + + Args: + image: 3D uint8 Tensor. + bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x) + of type float that represents the normalized coordinates between 0 and 1. + The potential values for the new min corner of the bbox will be between + [old_min - pixel_scaling * bbox_height/2, + old_min - pixel_scaling * bbox_height/2]. + pixel_scaling: A float between 0 and 1 that specifies the pixel range + that the new bbox location will be sampled from. + replace: A one or three value 1D tensor to fill empty pixels. + new_min_bbox_coords: If not None, then this is a tuple that specifies the + (min_y, min_x) coordinates of the new bbox. Normally this is randomly + specified, but this allows it to be manually set. The coordinates are + the absolute coordinates between 0 and image height/width and are int32. + + Returns: + The new image that will have the shifted bbox location in it along with + the new bbox that contains the new coordinates. + """ + # Obtains image height and width and create helper clip functions. + image_height, image_width = image.shape[0], image.shape[1] + image_height = float(image_height) + image_width = float(image_width) + + def clip_y(val): + return np.clip(val, a_min=0, a_max=image_height - 1).astype(np.int32) + + def clip_x(val): + return np.clip(val, a_min=0, a_max=image_width - 1).astype(np.int32) + + # Convert bbox to pixel coordinates. 
+ min_y = int(image_height * bbox[0]) + min_x = int(image_width * bbox[1]) + max_y = clip_y(image_height * bbox[2]) + max_x = clip_x(image_width * bbox[3]) + + bbox_height, bbox_width = (max_y - min_y + 1, max_x - min_x + 1) + image_height = int(image_height) + image_width = int(image_width) + + # Select the new min/max bbox ranges that are used for sampling the + # new min x/y coordinates of the shifted bbox. + minval_y = clip_y(min_y - np.int32(pixel_scaling * float(bbox_height) / + 2.0)) + maxval_y = clip_y(min_y + np.int32(pixel_scaling * float(bbox_height) / + 2.0)) + minval_x = clip_x(min_x - np.int32(pixel_scaling * float(bbox_width) / 2.0)) + maxval_x = clip_x(min_x + np.int32(pixel_scaling * float(bbox_width) / 2.0)) + + # Sample and calculate the new unclipped min/max coordinates of the new bbox. + if new_min_bbox_coords is None: + unclipped_new_min_y = np.random.randint( + low=minval_y, high=maxval_y, dtype=np.int32) + unclipped_new_min_x = np.random.randint( + low=minval_x, high=maxval_x, dtype=np.int32) + else: + unclipped_new_min_y, unclipped_new_min_x = ( + clip_y(new_min_bbox_coords[0]), clip_x(new_min_bbox_coords[1])) + unclipped_new_max_y = unclipped_new_min_y + bbox_height - 1 + unclipped_new_max_x = unclipped_new_min_x + bbox_width - 1 + + # Determine if any of the new bbox was shifted outside the current image. + # This is used for determining if any of the original bbox content should be + # discarded. + new_min_y, new_min_x, new_max_y, new_max_x = ( + clip_y(unclipped_new_min_y), clip_x(unclipped_new_min_x), + clip_y(unclipped_new_max_y), clip_x(unclipped_new_max_x)) + shifted_min_y = (new_min_y - unclipped_new_min_y) + min_y + shifted_max_y = max_y - (unclipped_new_max_y - new_max_y) + shifted_min_x = (new_min_x - unclipped_new_min_x) + min_x + shifted_max_x = max_x - (unclipped_new_max_x - new_max_x) + + # Create the new bbox tensor by converting pixel integer values to floats. 
+ new_bbox = np.stack([ + float(new_min_y) / float(image_height), float(new_min_x) / + float(image_width), float(new_max_y) / float(image_height), + float(new_max_x) / float(image_width) + ]) + + # Copy the contents in the bbox and fill the old bbox location + # with gray (128). + bbox_content = image[shifted_min_y:shifted_max_y + 1, shifted_min_x: + shifted_max_x + 1, :] + + def mask_and_add_image(min_y_, min_x_, max_y_, max_x_, mask, content_tensor, + image_): + """Applies mask to bbox region in image then adds content_tensor to it.""" + mask = np.pad(mask, [[min_y_, (image_height - 1) - max_y_], + [min_x_, (image_width - 1) - max_x_], [0, 0]], + 'constant', + constant_values=1) + + content_tensor = np.pad(content_tensor, + [[min_y_, (image_height - 1) - max_y_], + [min_x_, (image_width - 1) - max_x_], [0, 0]], + 'constant', + constant_values=0) + return image_ * mask + content_tensor + + # Zero out original bbox location. + mask = np.zeros_like(image)[min_y:max_y + 1, min_x:max_x + 1, :] + grey_tensor = np.zeros_like(mask) + replace[0] + image = mask_and_add_image(min_y, min_x, max_y, max_x, mask, grey_tensor, + image) + + # Fill in bbox content to new bbox location. + mask = np.zeros_like(bbox_content) + image = mask_and_add_image(new_min_y, new_min_x, new_max_y, new_max_x, mask, + bbox_content, image) + + return image.astype(np.uint8), new_bbox + + +def _clip_bbox(min_y, min_x, max_y, max_x): + """Clip bounding box coordinates between 0 and 1. + + Args: + min_y: Normalized bbox coordinate of type float between 0 and 1. + min_x: Normalized bbox coordinate of type float between 0 and 1. + max_y: Normalized bbox coordinate of type float between 0 and 1. + max_x: Normalized bbox coordinate of type float between 0 and 1. + + Returns: + Clipped coordinate values between 0 and 1. 
+ """ + min_y = np.clip(min_y, a_min=0, a_max=1.0) + min_x = np.clip(min_x, a_min=0, a_max=1.0) + max_y = np.clip(max_y, a_min=0, a_max=1.0) + max_x = np.clip(max_x, a_min=0, a_max=1.0) + return min_y, min_x, max_y, max_x + + +def _check_bbox_area(min_y, min_x, max_y, max_x, delta=0.05): + """Adjusts bbox coordinates to make sure the area is > 0. + + Args: + min_y: Normalized bbox coordinate of type float between 0 and 1. + min_x: Normalized bbox coordinate of type float between 0 and 1. + max_y: Normalized bbox coordinate of type float between 0 and 1. + max_x: Normalized bbox coordinate of type float between 0 and 1. + delta: Float, this is used to create a gap of size 2 * delta between + bbox min/max coordinates that are the same on the boundary. + This prevents the bbox from having an area of zero. + + Returns: + Tuple of new bbox coordinates between 0 and 1 that will now have a + guaranteed area > 0. + """ + height = max_y - min_y + width = max_x - min_x + + def _adjust_bbox_boundaries(min_coord, max_coord): + # Make sure max is never 0 and min is never 1. + max_coord = np.maximum(max_coord, 0.0 + delta) + min_coord = np.minimum(min_coord, 1.0 - delta) + return min_coord, max_coord + + if _equal(height, 0): + min_y, max_y = _adjust_bbox_boundaries(min_y, max_y) + + if _equal(width, 0): + min_x, max_x = _adjust_bbox_boundaries(min_x, max_x) + + return min_y, min_x, max_y, max_x + + +def _scale_bbox_only_op_probability(prob): + """Reduce the probability of the bbox-only operation. + + Probability is reduced so that we do not distort the content of too many + bounding boxes that are close to each other. The value of 3.0 was a chosen + hyper parameter when designing the autoaugment algorithm that we found + empirically to work well. + + Args: + prob: Float that is the probability of applying the bbox-only operation. + + Returns: + Reduced probability. 
+ """ + return prob / 3.0 + + +def _apply_bbox_augmentation(image, bbox, augmentation_func, *args): + """Applies augmentation_func to the subsection of image indicated by bbox. + + Args: + image: 3D uint8 Tensor. + bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x) + of type float that represents the normalized coordinates between 0 and 1. + augmentation_func: Augmentation function that will be applied to the + subsection of image. + *args: Additional parameters that will be passed into augmentation_func + when it is called. + + Returns: + A modified version of image, where the bbox location in the image will + have `ugmentation_func applied to it. + """ + image_height = image.shape[0] + image_width = image.shape[1] + + min_y = int(image_height * bbox[0]) + min_x = int(image_width * bbox[1]) + max_y = int(image_height * bbox[2]) + max_x = int(image_width * bbox[3]) + + # Clip to be sure the max values do not fall out of range. + max_y = np.minimum(max_y, image_height - 1) + max_x = np.minimum(max_x, image_width - 1) + + # Get the sub-tensor that is the image within the bounding box region. + bbox_content = image[min_y:max_y + 1, min_x:max_x + 1, :] + + # Apply the augmentation function to the bbox portion of the image. + augmented_bbox_content = augmentation_func(bbox_content, *args) + + # Pad the augmented_bbox_content and the mask to match the shape of original + # image. + augmented_bbox_content = np.pad( + augmented_bbox_content, [[min_y, (image_height - 1) - max_y], + [min_x, (image_width - 1) - max_x], [0, 0]], + 'constant', + constant_values=1) + + # Create a mask that will be used to zero out a part of the original image. + mask_tensor = np.zeros_like(bbox_content) + + mask_tensor = np.pad(mask_tensor, + [[min_y, (image_height - 1) - max_y], + [min_x, (image_width - 1) - max_x], [0, 0]], + 'constant', + constant_values=1) + # Replace the old bbox content with the new augmented content. 
+ image = image * mask_tensor + augmented_bbox_content + return image.astype(np.uint8) + + +def _concat_bbox(bbox, bboxes): + """Helper function that concates bbox to bboxes along the first dimension.""" + + # Note if all elements in bboxes are -1 (_INVALID_BOX), then this means + # we discard bboxes and start the bboxes Tensor with the current bbox. + bboxes_sum_check = np.sum(bboxes) + bbox = np.expand_dims(bbox, 0) + # This check will be true when it is an _INVALID_BOX + if _equal(bboxes_sum_check, -4): + bboxes = bbox + else: + bboxes = np.concatenate([bboxes, bbox], 0) + return bboxes + + +def _apply_bbox_augmentation_wrapper(image, bbox, new_bboxes, prob, + augmentation_func, func_changes_bbox, + *args): + """Applies _apply_bbox_augmentation with probability prob. + + Args: + image: 3D uint8 Tensor. + bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x) + of type float that represents the normalized coordinates between 0 and 1. + new_bboxes: 2D Tensor that is a list of the bboxes in the image after they + have been altered by aug_func. These will only be changed when + func_changes_bbox is set to true. Each bbox has 4 elements + (min_y, min_x, max_y, max_x) of type float that are the normalized + bbox coordinates between 0 and 1. + prob: Float that is the probability of applying _apply_bbox_augmentation. + augmentation_func: Augmentation function that will be applied to the + subsection of image. + func_changes_bbox: Boolean. Does augmentation_func return bbox in addition + to image. + *args: Additional parameters that will be passed into augmentation_func + when it is called. + + Returns: + A tuple. Fist element is a modified version of image, where the bbox + location in the image will have augmentation_func applied to it if it is + chosen to be called with probability `prob`. The second element is a + Tensor of Tensors of length 4 that will contain the altered bbox after + applying augmentation_func. 
+ """ + should_apply_op = (np.random.rand() + prob >= 1) + if func_changes_bbox: + if should_apply_op: + augmented_image, bbox = augmentation_func(image, bbox, *args) + else: + augmented_image, bbox = (image, bbox) + else: + if should_apply_op: + augmented_image = _apply_bbox_augmentation(image, bbox, + augmentation_func, *args) + else: + augmented_image = image + new_bboxes = _concat_bbox(bbox, new_bboxes) + return augmented_image.astype(np.uint8), new_bboxes + + +def _apply_multi_bbox_augmentation(image, bboxes, prob, aug_func, + func_changes_bbox, *args): + """Applies aug_func to the image for each bbox in bboxes. + + Args: + image: 3D uint8 Tensor. + bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox + has 4 elements (min_y, min_x, max_y, max_x) of type float. + prob: Float that is the probability of applying aug_func to a specific + bounding box within the image. + aug_func: Augmentation function that will be applied to the + subsections of image indicated by the bbox values in bboxes. + func_changes_bbox: Boolean. Does augmentation_func return bbox in addition + to image. + *args: Additional parameters that will be passed into augmentation_func + when it is called. + + Returns: + A modified version of image, where each bbox location in the image will + have augmentation_func applied to it if it is chosen to be called with + probability prob independently across all bboxes. Also the final + bboxes are returned that will be unchanged if func_changes_bbox is set to + false and if true, the new altered ones will be returned. + """ + # Will keep track of the new altered bboxes after aug_func is repeatedly + # applied. The -1 values are a dummy value and this first Tensor will be + # removed upon appending the first real bbox. + new_bboxes = np.array(_INVALID_BOX) + + # If the bboxes are empty, then just give it _INVALID_BOX. The result + # will be thrown away. 
+ bboxes = np.array((_INVALID_BOX)) if bboxes.size == 0 else bboxes + + assert bboxes.shape[1] == 4, "bboxes.shape[1] must be 4!!!!" + + # pylint:disable=g-long-lambda + # pylint:disable=line-too-long + wrapped_aug_func = lambda _image, bbox, _new_bboxes: _apply_bbox_augmentation_wrapper(_image, bbox, _new_bboxes, prob, aug_func, func_changes_bbox, *args) + # pylint:enable=g-long-lambda + # pylint:enable=line-too-long + + # Setup the while_loop. + num_bboxes = bboxes.shape[0] # We loop until we go over all bboxes. + idx = 0 # Counter for the while loop. + + # Conditional function when to end the loop once we go over all bboxes + # images_and_bboxes contain (_image, _new_bboxes) + def cond(_idx, _images_and_bboxes): + return _idx < num_bboxes + + # Shuffle the bboxes so that the augmentation order is not deterministic if + # we are not changing the bboxes with aug_func. + # if not func_changes_bbox: + # print(bboxes) + # loop_bboxes = np.take(bboxes,np.random.permutation(bboxes.shape[0]),axis=0) + # print(loop_bboxes) + # else: + # loop_bboxes = bboxes + # we can not shuffle the bbox because it does not contain class information here + loop_bboxes = deepcopy(bboxes) + + # Main function of while_loop where we repeatedly apply augmentation on the + # bboxes in the image. + # pylint:disable=g-long-lambda + body = lambda _idx, _images_and_bboxes: [ + _idx + 1, wrapped_aug_func(_images_and_bboxes[0], + loop_bboxes[_idx], + _images_and_bboxes[1])] + while (cond(idx, (image, new_bboxes))): + idx, (image, new_bboxes) = body(idx, (image, new_bboxes)) + + # Either return the altered bboxes or the original ones depending on if + # we altered them in anyway. 
+ if func_changes_bbox: + final_bboxes = new_bboxes + else: + final_bboxes = bboxes + return image, final_bboxes + + +def _apply_multi_bbox_augmentation_wrapper(image, bboxes, prob, aug_func, + func_changes_bbox, *args): + """Checks to be sure num bboxes > 0 before calling inner function.""" + num_bboxes = len(bboxes) + new_image = deepcopy(image) + new_bboxes = deepcopy(bboxes) + if num_bboxes != 0: + new_image, new_bboxes = _apply_multi_bbox_augmentation( + new_image, new_bboxes, prob, aug_func, func_changes_bbox, *args) + return new_image, new_bboxes + + +def rotate_only_bboxes(image, bboxes, prob, degrees, replace): + """Apply rotate to each bbox in the image with probability prob.""" + func_changes_bbox = False + prob = _scale_bbox_only_op_probability(prob) + return _apply_multi_bbox_augmentation_wrapper( + image, bboxes, prob, rotate, func_changes_bbox, degrees, replace) + + +def shear_x_only_bboxes(image, bboxes, prob, level, replace): + """Apply shear_x to each bbox in the image with probability prob.""" + func_changes_bbox = False + prob = _scale_bbox_only_op_probability(prob) + return _apply_multi_bbox_augmentation_wrapper( + image, bboxes, prob, shear_x, func_changes_bbox, level, replace) + + +def shear_y_only_bboxes(image, bboxes, prob, level, replace): + """Apply shear_y to each bbox in the image with probability prob.""" + func_changes_bbox = False + prob = _scale_bbox_only_op_probability(prob) + return _apply_multi_bbox_augmentation_wrapper( + image, bboxes, prob, shear_y, func_changes_bbox, level, replace) + + +def translate_x_only_bboxes(image, bboxes, prob, pixels, replace): + """Apply translate_x to each bbox in the image with probability prob.""" + func_changes_bbox = False + prob = _scale_bbox_only_op_probability(prob) + return _apply_multi_bbox_augmentation_wrapper( + image, bboxes, prob, translate_x, func_changes_bbox, pixels, replace) + + +def translate_y_only_bboxes(image, bboxes, prob, pixels, replace): + """Apply translate_y to each bbox 
in the image with probability prob.""" + func_changes_bbox = False + prob = _scale_bbox_only_op_probability(prob) + return _apply_multi_bbox_augmentation_wrapper( + image, bboxes, prob, translate_y, func_changes_bbox, pixels, replace) + + +def flip_only_bboxes(image, bboxes, prob): + """Apply flip_lr to each bbox in the image with probability prob.""" + func_changes_bbox = False + prob = _scale_bbox_only_op_probability(prob) + return _apply_multi_bbox_augmentation_wrapper(image, bboxes, prob, + np.fliplr, func_changes_bbox) + + +def solarize_only_bboxes(image, bboxes, prob, threshold): + """Apply solarize to each bbox in the image with probability prob.""" + func_changes_bbox = False + prob = _scale_bbox_only_op_probability(prob) + return _apply_multi_bbox_augmentation_wrapper(image, bboxes, prob, solarize, + func_changes_bbox, threshold) + + +def equalize_only_bboxes(image, bboxes, prob): + """Apply equalize to each bbox in the image with probability prob.""" + func_changes_bbox = False + prob = _scale_bbox_only_op_probability(prob) + return _apply_multi_bbox_augmentation_wrapper(image, bboxes, prob, equalize, + func_changes_bbox) + + +def cutout_only_bboxes(image, bboxes, prob, pad_size, replace): + """Apply cutout to each bbox in the image with probability prob.""" + func_changes_bbox = False + prob = _scale_bbox_only_op_probability(prob) + return _apply_multi_bbox_augmentation_wrapper( + image, bboxes, prob, cutout, func_changes_bbox, pad_size, replace) + + +def _rotate_bbox(bbox, image_height, image_width, degrees): + """Rotates the bbox coordinated by degrees. + + Args: + bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x) + of type float that represents the normalized coordinates between 0 and 1. + image_height: Int, height of the image. + image_width: Int, height of the image. + degrees: Float, a scalar angle in degrees to rotate all images by. 
If + degrees is positive the image will be rotated clockwise otherwise it will + be rotated counterclockwise. + + Returns: + A tensor of the same shape as bbox, but now with the rotated coordinates. + """ + image_height, image_width = (float(image_height), float(image_width)) + + # Convert from degrees to radians. + degrees_to_radians = math.pi / 180.0 + radians = degrees * degrees_to_radians + + # Translate the bbox to the center of the image and turn the normalized 0-1 + # coordinates to absolute pixel locations. + # Y coordinates are made negative as the y axis of images goes down with + # increasing pixel values, so we negate to make sure x axis and y axis points + # are in the traditionally positive direction. + min_y = -int(image_height * (bbox[0] - 0.5)) + min_x = int(image_width * (bbox[1] - 0.5)) + max_y = -int(image_height * (bbox[2] - 0.5)) + max_x = int(image_width * (bbox[3] - 0.5)) + coordinates = np.stack([[min_y, min_x], [min_y, max_x], [max_y, min_x], + [max_y, max_x]]).astype(np.float32) + # Rotate the coordinates according to the rotation matrix clockwise if + # radians is positive, else negative + rotation_matrix = np.stack([[math.cos(radians), math.sin(radians)], + [-math.sin(radians), math.cos(radians)]]) + new_coords = np.matmul(rotation_matrix, + np.transpose(coordinates)).astype(np.int32) + + # Find min/max values and convert them back to normalized 0-1 floats. + min_y = -(float(np.max(new_coords[0, :])) / image_height - 0.5) + min_x = float(np.min(new_coords[1, :])) / image_width + 0.5 + max_y = -(float(np.min(new_coords[0, :])) / image_height - 0.5) + max_x = float(np.max(new_coords[1, :])) / image_width + 0.5 + + # Clip the bboxes to be sure the fall between [0, 1]. 
+ min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x) + min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x) + return np.stack([min_y, min_x, max_y, max_x]) + + +def rotate_with_bboxes(image, bboxes, degrees, replace): + # Rotate the image. + image = rotate(image, degrees, replace) + + # Convert bbox coordinates to pixel values. + image_height, image_width = image.shape[:2] + # pylint:disable=g-long-lambda + wrapped_rotate_bbox = lambda bbox: _rotate_bbox(bbox, image_height, image_width, degrees) + # pylint:enable=g-long-lambda + new_bboxes = np.zeros_like(bboxes) + for idx in range(len(bboxes)): + new_bboxes[idx] = wrapped_rotate_bbox(bboxes[idx]) + return image, new_bboxes + + +def translate_x(image, pixels, replace): + """Equivalent of PIL Translate in X dimension.""" + image = Image.fromarray(wrap(image)) + image = image.transform(image.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0)) + return unwrap(np.array(image), replace) + + +def translate_y(image, pixels, replace): + """Equivalent of PIL Translate in Y dimension.""" + image = Image.fromarray(wrap(image)) + image = image.transform(image.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels)) + return unwrap(np.array(image), replace) + + +def _shift_bbox(bbox, image_height, image_width, pixels, shift_horizontal): + """Shifts the bbox coordinates by pixels. + + Args: + bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x) + of type float that represents the normalized coordinates between 0 and 1. + image_height: Int, height of the image. + image_width: Int, width of the image. + pixels: An int. How many pixels to shift the bbox. + shift_horizontal: Boolean. If true then shift in X dimension else shift in + Y dimension. + + Returns: + A tensor of the same shape as bbox, but now with the shifted coordinates. + """ + pixels = int(pixels) + # Convert bbox to integer pixel locations. 
+ min_y = int(float(image_height) * bbox[0]) + min_x = int(float(image_width) * bbox[1]) + max_y = int(float(image_height) * bbox[2]) + max_x = int(float(image_width) * bbox[3]) + + if shift_horizontal: + min_x = np.maximum(0, min_x - pixels) + max_x = np.minimum(image_width, max_x - pixels) + else: + min_y = np.maximum(0, min_y - pixels) + max_y = np.minimum(image_height, max_y - pixels) + + # Convert bbox back to floats. + min_y = float(min_y) / float(image_height) + min_x = float(min_x) / float(image_width) + max_y = float(max_y) / float(image_height) + max_x = float(max_x) / float(image_width) + + # Clip the bboxes to be sure the fall between [0, 1]. + min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x) + min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x) + return np.stack([min_y, min_x, max_y, max_x]) + + +def translate_bbox(image, bboxes, pixels, replace, shift_horizontal): + """Equivalent of PIL Translate in X/Y dimension that shifts image and bbox. + + Args: + image: 3D uint8 Tensor. + bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox + has 4 elements (min_y, min_x, max_y, max_x) of type float with values + between [0, 1]. + pixels: An int. How many pixels to shift the image and bboxes + replace: A one or three value 1D tensor to fill empty pixels. + shift_horizontal: Boolean. If true then shift in X dimension else shift in + Y dimension. + + Returns: + A tuple containing a 3D uint8 Tensor that will be the result of translating + image by pixels. The second element of the tuple is bboxes, where now + the coordinates will be shifted to reflect the shifted image. + """ + if shift_horizontal: + image = translate_x(image, pixels, replace) + else: + image = translate_y(image, pixels, replace) + + # Convert bbox coordinates to pixel values. 
+ image_height, image_width = image.shape[0], image.shape[1] + # pylint:disable=g-long-lambda + wrapped_shift_bbox = lambda bbox: _shift_bbox(bbox, image_height, image_width, pixels, shift_horizontal) + # pylint:enable=g-long-lambda + new_bboxes = deepcopy(bboxes) + num_bboxes = len(bboxes) + for idx in range(num_bboxes): + new_bboxes[idx] = wrapped_shift_bbox(bboxes[idx]) + return image.astype(np.uint8), new_bboxes + + +def shear_x(image, level, replace): + """Equivalent of PIL Shearing in X dimension.""" + # Shear parallel to x axis is a projective transform + # with a matrix form of: + # [1 level + # 0 1]. + image = Image.fromarray(wrap(image)) + image = image.transform(image.size, Image.AFFINE, (1, level, 0, 0, 1, 0)) + return unwrap(np.array(image), replace) + + +def shear_y(image, level, replace): + """Equivalent of PIL Shearing in Y dimension.""" + # Shear parallel to y axis is a projective transform + # with a matrix form of: + # [1 0 + # level 1]. + image = Image.fromarray(wrap(image)) + image = image.transform(image.size, Image.AFFINE, (1, 0, 0, level, 1, 0)) + return unwrap(np.array(image), replace) + + +def _shear_bbox(bbox, image_height, image_width, level, shear_horizontal): + """Shifts the bbox according to how the image was sheared. + + Args: + bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x) + of type float that represents the normalized coordinates between 0 and 1. + image_height: Int, height of the image. + image_width: Int, height of the image. + level: Float. How much to shear the image. + shear_horizontal: If true then shear in X dimension else shear in + the Y dimension. + + Returns: + A tensor of the same shape as bbox, but now with the shifted coordinates. + """ + image_height, image_width = (float(image_height), float(image_width)) + + # Change bbox coordinates to be pixels. 
+ min_y = int(image_height * bbox[0]) + min_x = int(image_width * bbox[1]) + max_y = int(image_height * bbox[2]) + max_x = int(image_width * bbox[3]) + coordinates = np.stack( + [[min_y, min_x], [min_y, max_x], [max_y, min_x], [max_y, max_x]]) + coordinates = coordinates.astype(np.float32) + + # Shear the coordinates according to the translation matrix. + if shear_horizontal: + translation_matrix = np.stack([[1, 0], [-level, 1]]) + else: + translation_matrix = np.stack([[1, -level], [0, 1]]) + translation_matrix = translation_matrix.astype(np.float32) + new_coords = np.matmul(translation_matrix, + np.transpose(coordinates)).astype(np.int32) + + # Find min/max values and convert them back to floats. + min_y = float(np.min(new_coords[0, :])) / image_height + min_x = float(np.min(new_coords[1, :])) / image_width + max_y = float(np.max(new_coords[0, :])) / image_height + max_x = float(np.max(new_coords[1, :])) / image_width + + # Clip the bboxes to be sure the fall between [0, 1]. + min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x) + min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x) + return np.stack([min_y, min_x, max_y, max_x]) + + +def shear_with_bboxes(image, bboxes, level, replace, shear_horizontal): + """Applies Shear Transformation to the image and shifts the bboxes. + + Args: + image: 3D uint8 Tensor. + bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox + has 4 elements (min_y, min_x, max_y, max_x) of type float with values + between [0, 1]. + level: Float. How much to shear the image. This value will be between + -0.3 to 0.3. + replace: A one or three value 1D tensor to fill empty pixels. + shear_horizontal: Boolean. If true then shear in X dimension else shear in + the Y dimension. + + Returns: + A tuple containing a 3D uint8 Tensor that will be the result of shearing + image by level. 
The second element of the tuple is bboxes, where now + the coordinates will be shifted to reflect the sheared image. + """ + if shear_horizontal: + image = shear_x(image, level, replace) + else: + image = shear_y(image, level, replace) + + # Convert bbox coordinates to pixel values. + image_height, image_width = image.shape[:2] + # pylint:disable=g-long-lambda + wrapped_shear_bbox = lambda bbox: _shear_bbox(bbox, image_height, image_width, level, shear_horizontal) + # pylint:enable=g-long-lambda + new_bboxes = deepcopy(bboxes) + num_bboxes = len(bboxes) + for idx in range(num_bboxes): + new_bboxes[idx] = wrapped_shear_bbox(bboxes[idx]) + return image.astype(np.uint8), new_bboxes + + +def autocontrast(image): + """Implements Autocontrast function from PIL. + + Args: + image: A 3D uint8 tensor. + + Returns: + The image after it has had autocontrast applied to it and will be of type + uint8. + """ + + def scale_channel(image): + """Scale the 2D image using the autocontrast rule.""" + # A possibly cheaper version can be done using cumsum/unique_with_counts + # over the histogram values, rather than iterating over the entire image. + # to compute mins and maxes. + lo = float(np.min(image)) + hi = float(np.max(image)) + + # Scale the image, making the lowest value 0 and the highest value 255. + def scale_values(im): + scale = 255.0 / (hi - lo) + offset = -lo * scale + im = im.astype(np.float32) * scale + offset + img = np.clip(im, a_min=0, a_max=255.0) + return im.astype(np.uint8) + + result = scale_values(image) if hi > lo else image + return result + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. 
+ s1 = scale_channel(image[:, :, 0]) + s2 = scale_channel(image[:, :, 1]) + s3 = scale_channel(image[:, :, 2]) + image = np.stack([s1, s2, s3], 2) + return image + + +def sharpness(image, factor): + """Implements Sharpness function from PIL.""" + orig_image = image + image = image.astype(np.float32) + # Make image 4D for conv operation. + # SMOOTH PIL Kernel. + kernel = np.array([[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=np.float32) / 13. + result = cv2.filter2D(image, -1, kernel).astype(np.uint8) + + # Blend the final result. + return blend(result, orig_image, factor) + + +def equalize(image): + """Implements Equalize function from PIL using.""" + + def scale_channel(im, c): + """Scale the data in the channel to implement equalize.""" + im = im[:, :, c].astype(np.int32) + # Compute the histogram of the image channel. + histo, _ = np.histogram(im, range=[0, 255], bins=256) + + # For the purposes of computing the step, filter out the nonzeros. + nonzero = np.where(np.not_equal(histo, 0)) + nonzero_histo = np.reshape(np.take(histo, nonzero), [-1]) + step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255 + + def build_lut(histo, step): + # Compute the cumulative sum, shifting by step // 2 + # and then normalization by step. + lut = (np.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. + lut = np.concatenate([[0], lut[:-1]], 0) + # Clip the counts to be in range. This is done + # in the C code for image.point. + return np.clip(lut, a_min=0, a_max=255).astype(np.uint8) + + # If step is zero, return the original image. Otherwise, build + # lut from the full histogram and step and then index from it. + if step == 0: + result = im + else: + result = np.take(build_lut(histo, step), im) + + return result.astype(np.uint8) + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. 
+ s1 = scale_channel(image, 0) + s2 = scale_channel(image, 1) + s3 = scale_channel(image, 2) + image = np.stack([s1, s2, s3], 2) + return image + + +def wrap(image): + """Returns 'image' with an extra channel set to all 1s.""" + shape = image.shape + extended_channel = 255 * np.ones([shape[0], shape[1], 1], image.dtype) + extended = np.concatenate([image, extended_channel], 2).astype(image.dtype) + return extended + + +def unwrap(image, replace): + """Unwraps an image produced by wrap. + + Where there is a 0 in the last channel for every spatial position, + the rest of the three channels in that spatial dimension are grayed + (set to 128). Operations like translate and shear on a wrapped + Tensor will leave 0s in empty locations. Some transformations look + at the intensity of values to do preprocessing, and we want these + empty pixels to assume the 'average' value, rather than pure black. + + + Args: + image: A 3D Image Tensor with 4 channels. + replace: A one or three value 1D tensor to fill empty pixels. + + Returns: + image: A 3D image Tensor with 3 channels. + """ + image_shape = image.shape + # Flatten the spatial dimensions. + flattened_image = np.reshape(image, [-1, image_shape[2]]) + + # Find all pixels where the last channel is zero. + alpha_channel = flattened_image[:, 3] + + replace = np.concatenate([replace, np.ones([1], image.dtype)], 0) + + # Where they are zero, fill them in with 'replace'. + alpha_channel = np.reshape(alpha_channel, (-1, 1)) + alpha_channel = np.tile(alpha_channel, reps=(1, flattened_image.shape[1])) + + flattened_image = np.where( + np.equal(alpha_channel, 0), + np.ones_like( + flattened_image, dtype=image.dtype) * replace, + flattened_image) + + image = np.reshape(flattened_image, image_shape) + image = image[:, :, :3] + return image.astype(np.uint8) + + +def _cutout_inside_bbox(image, bbox, pad_fraction): + """Generates cutout mask and the mean pixel value of the bbox. 
+ + First a location is randomly chosen within the image as the center where the + cutout mask will be applied. Note this can be towards the boundaries of the + image, so the full cutout mask may not be applied. + + Args: + image: 3D uint8 Tensor. + bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x) + of type float that represents the normalized coordinates between 0 and 1. + pad_fraction: Float that specifies how large the cutout mask should be in + in reference to the size of the original bbox. If pad_fraction is 0.25, + then the cutout mask will be of shape + (0.25 * bbox height, 0.25 * bbox width). + + Returns: + A tuple. Fist element is a tensor of the same shape as image where each + element is either a 1 or 0 that is used to determine where the image + will have cutout applied. The second element is the mean of the pixels + in the image where the bbox is located. + mask value: [0,1] + """ + image_height, image_width = image.shape[0], image.shape[1] + # Transform from shape [1, 4] to [4]. + bbox = np.squeeze(bbox) + + min_y = int(float(image_height) * bbox[0]) + min_x = int(float(image_width) * bbox[1]) + max_y = int(float(image_height) * bbox[2]) + max_x = int(float(image_width) * bbox[3]) + + # Calculate the mean pixel values in the bounding box, which will be used + # to fill the cutout region. + mean = np.mean(image[min_y:max_y + 1, min_x:max_x + 1], axis=(0, 1)) + # Cutout mask will be size pad_size_heigh * 2 by pad_size_width * 2 if the + # region lies entirely within the bbox. + box_height = max_y - min_y + 1 + box_width = max_x - min_x + 1 + pad_size_height = int(pad_fraction * (box_height / 2)) + pad_size_width = int(pad_fraction * (box_width / 2)) + + # Sample the center location in the image where the zero mask will be applied. 
+ cutout_center_height = np.random.randint(min_y, max_y + 1, dtype=np.int32) + cutout_center_width = np.random.randint(min_x, max_x + 1, dtype=np.int32) + + lower_pad = np.maximum(0, cutout_center_height - pad_size_height) + upper_pad = np.maximum( + 0, image_height - cutout_center_height - pad_size_height) + left_pad = np.maximum(0, cutout_center_width - pad_size_width) + right_pad = np.maximum(0, + image_width - cutout_center_width - pad_size_width) + + cutout_shape = [ + image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad) + ] + padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] + + mask = np.pad(np.zeros( + cutout_shape, dtype=image.dtype), + padding_dims, + 'constant', + constant_values=1) + + mask = np.expand_dims(mask, 2) + mask = np.tile(mask, [1, 1, 3]) + return mask, mean + + +def bbox_cutout(image, bboxes, pad_fraction, replace_with_mean): + """Applies cutout to the image according to bbox information. + + This is a cutout variant that using bbox information to make more informed + decisions on where to place the cutout mask. + + Args: + image: 3D uint8 Tensor. + bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox + has 4 elements (min_y, min_x, max_y, max_x) of type float with values + between [0, 1]. + pad_fraction: Float that specifies how large the cutout mask should be in + in reference to the size of the original bbox. If pad_fraction is 0.25, + then the cutout mask will be of shape + (0.25 * bbox height, 0.25 * bbox width). + replace_with_mean: Boolean that specified what value should be filled in + where the cutout mask is applied. Since the incoming image will be of + uint8 and will not have had any mean normalization applied, by default + we set the value to be 128. If replace_with_mean is True then we find + the mean pixel values across the channel dimension and use those to fill + in where the cutout mask is applied. + + Returns: + A tuple. 
First element is a tensor of the same shape as image that has + cutout applied to it. Second element is the bboxes that were passed in + that will be unchanged. + """ + + def apply_bbox_cutout(image, bboxes, pad_fraction): + """Applies cutout to a single bounding box within image.""" + # Choose a single bounding box to apply cutout to. + random_index = np.random.randint(0, bboxes.shape[0], dtype=np.int32) + # Select the corresponding bbox and apply cutout. + chosen_bbox = np.take(bboxes, random_index, axis=0) + mask, mean = _cutout_inside_bbox(image, chosen_bbox, pad_fraction) + + # When applying cutout we either set the pixel value to 128 or to the mean + # value inside the bbox. + replace = mean if replace_with_mean else [128] * 3 + + # Apply the cutout mask to the image. Where the mask is 0 we fill it with + # `replace`. + image = np.where( + np.equal(mask, 0), + np.ones_like( + image, dtype=image.dtype) * replace, + image).astype(image.dtype) + return image + + # Check to see if there are boxes, if so then apply boxcutout. 
+ if len(bboxes) != 0: + image = apply_bbox_cutout(image, bboxes, pad_fraction) + + return image, bboxes + + +NAME_TO_FUNC = { + 'AutoContrast': autocontrast, + 'Equalize': equalize, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'Cutout': cutout, + 'BBox_Cutout': bbox_cutout, + 'Rotate_BBox': rotate_with_bboxes, + # pylint:disable=g-long-lambda + 'TranslateX_BBox': lambda image, bboxes, pixels, replace: translate_bbox( + image, bboxes, pixels, replace, shift_horizontal=True), + 'TranslateY_BBox': lambda image, bboxes, pixels, replace: translate_bbox( + image, bboxes, pixels, replace, shift_horizontal=False), + 'ShearX_BBox': lambda image, bboxes, level, replace: shear_with_bboxes( + image, bboxes, level, replace, shear_horizontal=True), + 'ShearY_BBox': lambda image, bboxes, level, replace: shear_with_bboxes( + image, bboxes, level, replace, shear_horizontal=False), + # pylint:enable=g-long-lambda + 'Rotate_Only_BBoxes': rotate_only_bboxes, + 'ShearX_Only_BBoxes': shear_x_only_bboxes, + 'ShearY_Only_BBoxes': shear_y_only_bboxes, + 'TranslateX_Only_BBoxes': translate_x_only_bboxes, + 'TranslateY_Only_BBoxes': translate_y_only_bboxes, + 'Flip_Only_BBoxes': flip_only_bboxes, + 'Solarize_Only_BBoxes': solarize_only_bboxes, + 'Equalize_Only_BBoxes': equalize_only_bboxes, + 'Cutout_Only_BBoxes': cutout_only_bboxes, +} + + +def _randomly_negate_tensor(tensor): + """With 50% prob turn the tensor negative.""" + should_flip = np.floor(np.random.rand() + 0.5) >= 1 + final_tensor = tensor if should_flip else -tensor + return final_tensor + + +def _rotate_level_to_arg(level): + level = (level / _MAX_LEVEL) * 30. 
+ level = _randomly_negate_tensor(level) + return (level, ) + + +def _shrink_level_to_arg(level): + """Converts level to ratio by which we shrink the image content.""" + if level == 0: + return (1.0, ) # if level is zero, do not shrink the image + # Maximum shrinking ratio is 2.9. + level = 2. / (_MAX_LEVEL / level) + 0.9 + return (level, ) + + +def _enhance_level_to_arg(level): + return ((level / _MAX_LEVEL) * 1.8 + 0.1, ) + + +def _shear_level_to_arg(level): + level = (level / _MAX_LEVEL) * 0.3 + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level, ) + + +def _translate_level_to_arg(level, translate_const): + level = (level / _MAX_LEVEL) * float(translate_const) + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level, ) + + +def _bbox_cutout_level_to_arg(level, hparams): + cutout_pad_fraction = (level / + _MAX_LEVEL) * 0.75 # hparams.cutout_max_pad_fraction + return (cutout_pad_fraction, False) # hparams.cutout_bbox_replace_with_mean + + +def level_to_arg(hparams): + return { + 'AutoContrast': lambda level: (), + 'Equalize': lambda level: (), + 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4), ), + 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256), ), + 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110), ), + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'Cutout': + lambda level: (int((level / _MAX_LEVEL) * 100), ), # hparams.cutout_const=100 + # pylint:disable=g-long-lambda + 'BBox_Cutout': lambda level: _bbox_cutout_level_to_arg(level, hparams), + 'TranslateX_BBox': + lambda level: _translate_level_to_arg(level, 250), # hparams.translate_const=250 + 'TranslateY_BBox': + lambda level: _translate_level_to_arg(level, 250), # hparams.translate_cons + # pylint:enable=g-long-lambda + 'ShearX_BBox': _shear_level_to_arg, + 'ShearY_BBox': 
_shear_level_to_arg, + 'Rotate_BBox': _rotate_level_to_arg, + 'Rotate_Only_BBoxes': _rotate_level_to_arg, + 'ShearX_Only_BBoxes': _shear_level_to_arg, + 'ShearY_Only_BBoxes': _shear_level_to_arg, + # pylint:disable=g-long-lambda + 'TranslateX_Only_BBoxes': + lambda level: _translate_level_to_arg(level, 120), # hparams.translate_bbox_const + 'TranslateY_Only_BBoxes': + lambda level: _translate_level_to_arg(level, 120), # hparams.translate_bbox_const + # pylint:enable=g-long-lambda + 'Flip_Only_BBoxes': lambda level: (), + 'Solarize_Only_BBoxes': + lambda level: (int((level / _MAX_LEVEL) * 256), ), + 'Equalize_Only_BBoxes': lambda level: (), + # pylint:disable=g-long-lambda + 'Cutout_Only_BBoxes': + lambda level: (int((level / _MAX_LEVEL) * 50), ), # hparams.cutout_bbox_const + # pylint:enable=g-long-lambda + } + + +def bbox_wrapper(func): + """Adds a bboxes function argument to func and returns unchanged bboxes.""" + + def wrapper(images, bboxes, *args, **kwargs): + return (func(images, *args, **kwargs), bboxes) + + return wrapper + + +def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams): + """Return the function that corresponds to `name` and update `level` param.""" + func = NAME_TO_FUNC[name] + args = level_to_arg(augmentation_hparams)[name](level) + + # Check to see if prob is passed into function. This is used for operations + # where we alter bboxes independently. + # pytype:disable=wrong-arg-types + if 'prob' in inspect.getargspec(func)[0]: + args = tuple([prob] + list(args)) + # pytype:enable=wrong-arg-types + + # Add in replace arg if it is required for the function that is being called. + if 'replace' in inspect.getargspec(func)[0]: + # Make sure replace is the final argument + assert 'replace' == inspect.getargspec(func)[0][-1] + args = tuple(list(args) + [replace_value]) + + # Add bboxes as the second positional argument for the function if it does + # not already exist. 
+ if 'bboxes' not in inspect.getargspec(func)[0]: + func = bbox_wrapper(func) + return (func, prob, args) + + +def _apply_func_with_prob(func, image, args, prob, bboxes): + """Apply `func` to image w/ `args` as input with probability `prob`.""" + assert isinstance(args, tuple) + assert 'bboxes' == inspect.getargspec(func)[0][1] + + # If prob is a function argument, then this randomness is being handled + # inside the function, so make sure it is always called. + if 'prob' in inspect.getargspec(func)[0]: + prob = 1.0 + + # Apply the function with probability `prob`. + should_apply_op = np.floor(np.random.rand() + 0.5) >= 1 + if should_apply_op: + augmented_image, augmented_bboxes = func(image, bboxes, *args) + else: + augmented_image, augmented_bboxes = (image, bboxes) + return augmented_image, augmented_bboxes + + +def select_and_apply_random_policy(policies, image, bboxes): + """Select a random policy from `policies` and apply it to `image`.""" + policy_to_select = np.random.randint(0, len(policies), dtype=np.int32) + # policy_to_select = 6 # for test + for (i, policy) in enumerate(policies): + if i == policy_to_select: + image, bboxes = policy(image, bboxes) + return (image, bboxes) + + +def build_and_apply_nas_policy(policies, image, bboxes, augmentation_hparams): + """Build a policy from the given policies passed in and apply to image. + + Args: + policies: list of lists of tuples in the form `(func, prob, level)`, `func` + is a string name of the augmentation function, `prob` is the probability + of applying the `func` operation, `level` is the input argument for + `func`. + image: numpy array that the resulting policy will be applied to. + bboxes: + augmentation_hparams: Hparams associated with the NAS learned policy. + + Returns: + A version of image that now has data augmentation applied to it based on + the `policies` pass into the function. 
Additionally, returns bboxes if + a value for them is passed in that is not None + """ + replace_value = [128, 128, 128] + + # func is the string name of the augmentation function, prob is the + # probability of applying the operation and level is the parameter associated + + # tf_policies are functions that take in an image and return an augmented + # image. + tf_policies = [] + for policy in policies: + tf_policy = [] + # Link string name to the correct python function and make sure the correct + # argument is passed into that function. + for policy_info in policy: + policy_info = list( + policy_info) + [replace_value, augmentation_hparams] + + tf_policy.append(_parse_policy_info(*policy_info)) + # Now build the tf policy that will apply the augmentation procedue + # on image. + def make_final_policy(tf_policy_): + def final_policy(image_, bboxes_): + for func, prob, args in tf_policy_: + image_, bboxes_ = _apply_func_with_prob(func, image_, args, + prob, bboxes_) + return image_, bboxes_ + + return final_policy + + tf_policies.append(make_final_policy(tf_policy)) + + augmented_images, augmented_bboxes = select_and_apply_random_policy( + tf_policies, image, bboxes) + # If no bounding boxes were specified, then just return the images. + return (augmented_images, augmented_bboxes) + + +# TODO(barretzoph): Add in ArXiv link once paper is out. +def distort_image_with_autoaugment(image, bboxes, augmentation_name): + """Applies the AutoAugment policy to `image` and `bboxes`. + + Args: + image: `Tensor` of shape [height, width, 3] representing an image. + bboxes: `Tensor` of shape [N, 4] representing ground truth boxes that are + normalized between [0, 1]. + augmentation_name: The name of the AutoAugment policy to use. The available + options are `v0`, `v1`, `v2`, `v3` and `test`. `v0` is the policy used for + all of the results in the paper and was found to achieve the best results + on the COCO dataset. 
`v1`, `v2` and `v3` are additional good policies + found on the COCO dataset that have slight variation in what operations + were used during the search procedure along with how many operations are + applied in parallel to a single image (2 vs 3). + + Returns: + A tuple containing the augmented versions of `image` and `bboxes`. + """ + available_policies = { + 'v0': policy_v0, + 'v1': policy_v1, + 'v2': policy_v2, + 'v3': policy_v3, + 'test': policy_vtest + } + if augmentation_name not in available_policies: + raise ValueError('Invalid augmentation_name: {}'.format( + augmentation_name)) + + policy = available_policies[augmentation_name]() + augmentation_hparams = {} + return build_and_apply_nas_policy(policy, image, bboxes, + augmentation_hparams) diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/transform/batch_operators.py b/VisualFL/depends/PaddleDetection/ppdet/data/transform/batch_operators.py new file mode 100644 index 000000000..c0e9bd6e0 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/transform/batch_operators.py @@ -0,0 +1,752 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence + +import logging +import cv2 +import numpy as np +from scipy import ndimage + +from .operators import register_op, BaseOperator +from .op_helper import jaccard_overlap, gaussian2D + +logger = logging.getLogger(__name__) + +__all__ = [ + 'PadBatch', + 'RandomShape', + 'PadMultiScaleTest', + 'Gt2YoloTarget', + 'Gt2FCOSTarget', + 'Gt2TTFTarget', + 'Gt2Solov2Target', +] + + +@register_op +class PadBatch(BaseOperator): + """ + Pad a batch of samples so they can be divisible by a stride. + The layout of each image should be 'CHW'. + Args: + pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure + height and width is divisible by `pad_to_stride`. + """ + + def __init__(self, pad_to_stride=0, use_padded_im_info=True): + super(PadBatch, self).__init__() + self.pad_to_stride = pad_to_stride + self.use_padded_im_info = use_padded_im_info + + def __call__(self, samples, context=None): + """ + Args: + samples (list): a batch of sample, each is dict. 
+ """ + coarsest_stride = self.pad_to_stride + if coarsest_stride == 0: + return samples + max_shape = np.array([data['image'].shape for data in samples]).max( + axis=0) + + if coarsest_stride > 0: + max_shape[1] = int( + np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride) + max_shape[2] = int( + np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride) + + padding_batch = [] + for data in samples: + im = data['image'] + im_c, im_h, im_w = im.shape[:] + padding_im = np.zeros( + (im_c, max_shape[1], max_shape[2]), dtype=np.float32) + padding_im[:, :im_h, :im_w] = im + data['image'] = padding_im + if self.use_padded_im_info: + data['im_info'][:2] = max_shape[1:3] + if 'semantic' in data.keys() and data['semantic'] is not None: + semantic = data['semantic'] + padding_sem = np.zeros( + (1, max_shape[1], max_shape[2]), dtype=np.float32) + padding_sem[:, :im_h, :im_w] = semantic + data['semantic'] = padding_sem + if 'gt_segm' in data.keys() and data['gt_segm'] is not None: + gt_segm = data['gt_segm'] + padding_segm = np.zeros( + (gt_segm.shape[0], max_shape[1], max_shape[2]), + dtype=np.uint8) + padding_segm[:, :im_h, :im_w] = gt_segm + data['gt_segm'] = padding_segm + + return samples + + +@register_op +class RandomShape(BaseOperator): + """ + Randomly reshape a batch. If random_inter is True, also randomly + select one an interpolation algorithm [cv2.INTER_NEAREST, cv2.INTER_LINEAR, + cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]. If random_inter is + False, use cv2.INTER_NEAREST. + Args: + sizes (list): list of int, random choose a size from these + random_inter (bool): whether to randomly interpolation, defalut true. 
+ """ + + def __init__(self, sizes=[], random_inter=False, resize_box=False): + super(RandomShape, self).__init__() + self.sizes = sizes + self.random_inter = random_inter + self.interps = [ + cv2.INTER_NEAREST, + cv2.INTER_LINEAR, + cv2.INTER_AREA, + cv2.INTER_CUBIC, + cv2.INTER_LANCZOS4, + ] if random_inter else [] + self.resize_box = resize_box + + def __call__(self, samples, context=None): + shape = np.random.choice(self.sizes) + method = np.random.choice(self.interps) if self.random_inter \ + else cv2.INTER_NEAREST + for i in range(len(samples)): + im = samples[i]['image'] + h, w = im.shape[:2] + scale_x = float(shape) / w + scale_y = float(shape) / h + im = cv2.resize( + im, None, None, fx=scale_x, fy=scale_y, interpolation=method) + samples[i]['image'] = im + if self.resize_box and 'gt_bbox' in samples[i] and len(samples[0][ + 'gt_bbox']) > 0: + scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32) + samples[i]['gt_bbox'] = np.clip(samples[i]['gt_bbox'] * + scale_array, 0, + float(shape) - 1) + return samples + + +@register_op +class PadMultiScaleTest(BaseOperator): + """ + Pad the image so they can be divisible by a stride for multi-scale testing. + + Args: + pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure + height and width is divisible by `pad_to_stride`. 
+ """ + + def __init__(self, pad_to_stride=0): + super(PadMultiScaleTest, self).__init__() + self.pad_to_stride = pad_to_stride + + def __call__(self, samples, context=None): + coarsest_stride = self.pad_to_stride + if coarsest_stride == 0: + return samples + + batch_input = True + if not isinstance(samples, Sequence): + batch_input = False + samples = [samples] + if len(samples) != 1: + raise ValueError("Batch size must be 1 when using multiscale test, " + "but now batch size is {}".format(len(samples))) + for i in range(len(samples)): + sample = samples[i] + for k in sample.keys(): + # hard code + if k.startswith('image'): + im = sample[k] + im_c, im_h, im_w = im.shape + max_h = int( + np.ceil(im_h / coarsest_stride) * coarsest_stride) + max_w = int( + np.ceil(im_w / coarsest_stride) * coarsest_stride) + padding_im = np.zeros( + (im_c, max_h, max_w), dtype=np.float32) + + padding_im[:, :im_h, :im_w] = im + sample[k] = padding_im + info_name = 'im_info' if k == 'image' else 'im_info_' + k + # update im_info + sample[info_name][:2] = [max_h, max_w] + if not batch_input: + samples = samples[0] + return samples + + +@register_op +class Gt2YoloTarget(BaseOperator): + """ + Generate YOLOv3 targets by groud truth data, this operator is only used in + fine grained YOLOv3 loss mode + """ + + def __init__(self, + anchors, + anchor_masks, + downsample_ratios, + num_classes=80, + iou_thresh=1.): + super(Gt2YoloTarget, self).__init__() + self.anchors = anchors + self.anchor_masks = anchor_masks + self.downsample_ratios = downsample_ratios + self.num_classes = num_classes + self.iou_thresh = iou_thresh + + def __call__(self, samples, context=None): + assert len(self.anchor_masks) == len(self.downsample_ratios), \ + "anchor_masks', and 'downsample_ratios' should have same length." 
+ + h, w = samples[0]['image'].shape[1:3] + an_hw = np.array(self.anchors) / np.array([[w, h]]) + for sample in samples: + # im, gt_bbox, gt_class, gt_score = sample + im = sample['image'] + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + gt_score = sample['gt_score'] + for i, ( + mask, downsample_ratio + ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)): + grid_h = int(h / downsample_ratio) + grid_w = int(w / downsample_ratio) + target = np.zeros( + (len(mask), 6 + self.num_classes, grid_h, grid_w), + dtype=np.float32) + for b in range(gt_bbox.shape[0]): + gx, gy, gw, gh = gt_bbox[b, :] + cls = gt_class[b] + score = gt_score[b] + if gw <= 0. or gh <= 0. or score <= 0.: + continue + + # find best match anchor index + best_iou = 0. + best_idx = -1 + for an_idx in range(an_hw.shape[0]): + iou = jaccard_overlap( + [0., 0., gw, gh], + [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]]) + if iou > best_iou: + best_iou = iou + best_idx = an_idx + + gi = int(gx * grid_w) + gj = int(gy * grid_h) + + # gtbox should be regresed in this layes if best match + # anchor index in anchor mask of this layer + if best_idx in mask: + best_n = mask.index(best_idx) + + # x, y, w, h, scale + target[best_n, 0, gj, gi] = gx * grid_w - gi + target[best_n, 1, gj, gi] = gy * grid_h - gj + target[best_n, 2, gj, gi] = np.log( + gw * w / self.anchors[best_idx][0]) + target[best_n, 3, gj, gi] = np.log( + gh * h / self.anchors[best_idx][1]) + target[best_n, 4, gj, gi] = 2.0 - gw * gh + + # objectness record gt_score + target[best_n, 5, gj, gi] = score + + # classification + target[best_n, 6 + cls, gj, gi] = 1. 
+ + # For non-matched anchors, calculate the target if the iou + # between anchor and gt is larger than iou_thresh + if self.iou_thresh < 1: + for idx, mask_i in enumerate(mask): + if mask_i == best_idx: continue + iou = jaccard_overlap( + [0., 0., gw, gh], + [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]]) + if iou > self.iou_thresh: + # x, y, w, h, scale + target[idx, 0, gj, gi] = gx * grid_w - gi + target[idx, 1, gj, gi] = gy * grid_h - gj + target[idx, 2, gj, gi] = np.log( + gw * w / self.anchors[mask_i][0]) + target[idx, 3, gj, gi] = np.log( + gh * h / self.anchors[mask_i][1]) + target[idx, 4, gj, gi] = 2.0 - gw * gh + + # objectness record gt_score + target[idx, 5, gj, gi] = score + + # classification + target[idx, 6 + cls, gj, gi] = 1. + sample['target{}'.format(i)] = target + return samples + + +@register_op +class Gt2FCOSTarget(BaseOperator): + """ + Generate FCOS targets by groud truth data + """ + + def __init__(self, + object_sizes_boundary, + center_sampling_radius, + downsample_ratios, + norm_reg_targets=False): + super(Gt2FCOSTarget, self).__init__() + self.center_sampling_radius = center_sampling_radius + self.downsample_ratios = downsample_ratios + self.INF = np.inf + self.object_sizes_boundary = [-1] + object_sizes_boundary + [self.INF] + object_sizes_of_interest = [] + for i in range(len(self.object_sizes_boundary) - 1): + object_sizes_of_interest.append([ + self.object_sizes_boundary[i], self.object_sizes_boundary[i + 1] + ]) + self.object_sizes_of_interest = object_sizes_of_interest + self.norm_reg_targets = norm_reg_targets + + def _compute_points(self, w, h): + """ + compute the corresponding points in each feature map + :param h: image height + :param w: image width + :return: points from all feature map + """ + locations = [] + for stride in self.downsample_ratios: + shift_x = np.arange(0, w, stride).astype(np.float32) + shift_y = np.arange(0, h, stride).astype(np.float32) + shift_x, shift_y = np.meshgrid(shift_x, shift_y) + shift_x = 
shift_x.flatten() + shift_y = shift_y.flatten() + location = np.stack([shift_x, shift_y], axis=1) + stride // 2 + locations.append(location) + num_points_each_level = [len(location) for location in locations] + locations = np.concatenate(locations, axis=0) + return locations, num_points_each_level + + def _convert_xywh2xyxy(self, gt_bbox, w, h): + """ + convert the bounding box from style xywh to xyxy + :param gt_bbox: bounding boxes normalized into [0, 1] + :param w: image width + :param h: image height + :return: bounding boxes in xyxy style + """ + bboxes = gt_bbox.copy() + bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * w + bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * h + bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2] + bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3] + return bboxes + + def _check_inside_boxes_limited(self, gt_bbox, xs, ys, + num_points_each_level): + """ + check if points is within the clipped boxes + :param gt_bbox: bounding boxes + :param xs: horizontal coordinate of points + :param ys: vertical coordinate of points + :return: the mask of points is within gt_box or not + """ + bboxes = np.reshape( + gt_bbox, newshape=[1, gt_bbox.shape[0], gt_bbox.shape[1]]) + bboxes = np.tile(bboxes, reps=[xs.shape[0], 1, 1]) + ct_x = (bboxes[:, :, 0] + bboxes[:, :, 2]) / 2 + ct_y = (bboxes[:, :, 1] + bboxes[:, :, 3]) / 2 + beg = 0 + clipped_box = bboxes.copy() + for lvl, stride in enumerate(self.downsample_ratios): + end = beg + num_points_each_level[lvl] + stride_exp = self.center_sampling_radius * stride + clipped_box[beg:end, :, 0] = np.maximum( + bboxes[beg:end, :, 0], ct_x[beg:end, :] - stride_exp) + clipped_box[beg:end, :, 1] = np.maximum( + bboxes[beg:end, :, 1], ct_y[beg:end, :] - stride_exp) + clipped_box[beg:end, :, 2] = np.minimum( + bboxes[beg:end, :, 2], ct_x[beg:end, :] + stride_exp) + clipped_box[beg:end, :, 3] = np.minimum( + bboxes[beg:end, :, 3], ct_y[beg:end, :] + stride_exp) + beg = end + l_res = xs - clipped_box[:, :, 0] + r_res = clipped_box[:, :, 2] - xs + 
t_res = ys - clipped_box[:, :, 1] + b_res = clipped_box[:, :, 3] - ys + clipped_box_reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2) + inside_gt_box = np.min(clipped_box_reg_targets, axis=2) > 0 + return inside_gt_box + + def __call__(self, samples, context=None): + assert len(self.object_sizes_of_interest) == len(self.downsample_ratios), \ + "object_sizes_of_interest', and 'downsample_ratios' should have same length." + + for sample in samples: + # im, gt_bbox, gt_class, gt_score = sample + im = sample['image'] + im_info = sample['im_info'] + bboxes = sample['gt_bbox'] + gt_class = sample['gt_class'] + gt_score = sample['gt_score'] + bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * np.floor(im_info[1]) / \ + np.floor(im_info[1] / im_info[2]) + bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * np.floor(im_info[0]) / \ + np.floor(im_info[0] / im_info[2]) + # calculate the locations + h, w = sample['image'].shape[1:3] + points, num_points_each_level = self._compute_points(w, h) + object_scale_exp = [] + for i, num_pts in enumerate(num_points_each_level): + object_scale_exp.append( + np.tile( + np.array([self.object_sizes_of_interest[i]]), + reps=[num_pts, 1])) + object_scale_exp = np.concatenate(object_scale_exp, axis=0) + + gt_area = (bboxes[:, 2] - bboxes[:, 0]) * ( + bboxes[:, 3] - bboxes[:, 1]) + xs, ys = points[:, 0], points[:, 1] + xs = np.reshape(xs, newshape=[xs.shape[0], 1]) + xs = np.tile(xs, reps=[1, bboxes.shape[0]]) + ys = np.reshape(ys, newshape=[ys.shape[0], 1]) + ys = np.tile(ys, reps=[1, bboxes.shape[0]]) + + l_res = xs - bboxes[:, 0] + r_res = bboxes[:, 2] - xs + t_res = ys - bboxes[:, 1] + b_res = bboxes[:, 3] - ys + reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2) + if self.center_sampling_radius > 0: + is_inside_box = self._check_inside_boxes_limited( + bboxes, xs, ys, num_points_each_level) + else: + is_inside_box = np.min(reg_targets, axis=2) > 0 + # check if the targets is inside the corresponding level + max_reg_targets = 
np.max(reg_targets, axis=2) + lower_bound = np.tile( + np.expand_dims( + object_scale_exp[:, 0], axis=1), + reps=[1, max_reg_targets.shape[1]]) + high_bound = np.tile( + np.expand_dims( + object_scale_exp[:, 1], axis=1), + reps=[1, max_reg_targets.shape[1]]) + is_match_current_level = \ + (max_reg_targets > lower_bound) & \ + (max_reg_targets < high_bound) + points2gtarea = np.tile( + np.expand_dims( + gt_area, axis=0), reps=[xs.shape[0], 1]) + points2gtarea[is_inside_box == 0] = self.INF + points2gtarea[is_match_current_level == 0] = self.INF + points2min_area = points2gtarea.min(axis=1) + points2min_area_ind = points2gtarea.argmin(axis=1) + labels = gt_class[points2min_area_ind] + 1 + labels[points2min_area == self.INF] = 0 + reg_targets = reg_targets[range(xs.shape[0]), points2min_area_ind] + ctn_targets = np.sqrt((reg_targets[:, [0, 2]].min(axis=1) / \ + reg_targets[:, [0, 2]].max(axis=1)) * \ + (reg_targets[:, [1, 3]].min(axis=1) / \ + reg_targets[:, [1, 3]].max(axis=1))).astype(np.float32) + ctn_targets = np.reshape( + ctn_targets, newshape=[ctn_targets.shape[0], 1]) + ctn_targets[labels <= 0] = 0 + pos_ind = np.nonzero(labels != 0) + reg_targets_pos = reg_targets[pos_ind[0], :] + split_sections = [] + beg = 0 + for lvl in range(len(num_points_each_level)): + end = beg + num_points_each_level[lvl] + split_sections.append(end) + beg = end + labels_by_level = np.split(labels, split_sections, axis=0) + reg_targets_by_level = np.split(reg_targets, split_sections, axis=0) + ctn_targets_by_level = np.split(ctn_targets, split_sections, axis=0) + for lvl in range(len(self.downsample_ratios)): + grid_w = int(np.ceil(w / self.downsample_ratios[lvl])) + grid_h = int(np.ceil(h / self.downsample_ratios[lvl])) + if self.norm_reg_targets: + sample['reg_target{}'.format(lvl)] = \ + np.reshape( + reg_targets_by_level[lvl] / \ + self.downsample_ratios[lvl], + newshape=[grid_h, grid_w, 4]) + else: + sample['reg_target{}'.format(lvl)] = np.reshape( + reg_targets_by_level[lvl], + 
newshape=[grid_h, grid_w, 4]) + sample['labels{}'.format(lvl)] = np.reshape( + labels_by_level[lvl], newshape=[grid_h, grid_w, 1]) + sample['centerness{}'.format(lvl)] = np.reshape( + ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1]) + return samples + + +@register_op +class Gt2TTFTarget(BaseOperator): + """ + Gt2TTFTarget + Generate TTFNet targets by ground truth data + + Args: + num_classes(int): the number of classes. + down_ratio(int): the down ratio from images to heatmap, 4 by default. + alpha(float): the alpha parameter to generate gaussian target. + 0.54 by default. + """ + + def __init__(self, num_classes, down_ratio=4, alpha=0.54): + super(Gt2TTFTarget, self).__init__() + self.down_ratio = down_ratio + self.num_classes = num_classes + self.alpha = alpha + + def __call__(self, samples, context=None): + output_size = samples[0]['image'].shape[1] + feat_size = output_size // self.down_ratio + for sample in samples: + heatmap = np.zeros( + (self.num_classes, feat_size, feat_size), dtype='float32') + box_target = np.ones( + (4, feat_size, feat_size), dtype='float32') * -1 + reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32') + + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + + bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1 + bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1 + area = bbox_w * bbox_h + boxes_areas_log = np.log(area) + boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1] + boxes_area_topk_log = boxes_areas_log[boxes_ind] + gt_bbox = gt_bbox[boxes_ind] + gt_class = gt_class[boxes_ind] + + feat_gt_bbox = gt_bbox / self.down_ratio + feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1) + feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1], + feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0]) + + ct_inds = np.stack( + [(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2, + (gt_bbox[:, 1] + gt_bbox[:, 3]) / 2], + axis=1) / self.down_ratio + + h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32') + w_radiuses_alpha = (feat_ws / 2. 
* self.alpha).astype('int32') + + for k in range(len(gt_bbox)): + cls_id = gt_class[k] + fake_heatmap = np.zeros((feat_size, feat_size), dtype='float32') + self.draw_truncate_gaussian(fake_heatmap, ct_inds[k], + h_radiuses_alpha[k], + w_radiuses_alpha[k]) + + heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap) + box_target_inds = fake_heatmap > 0 + box_target[:, box_target_inds] = gt_bbox[k][:, None] + + local_heatmap = fake_heatmap[box_target_inds] + ct_div = np.sum(local_heatmap) + local_heatmap *= boxes_area_topk_log[k] + reg_weight[0, box_target_inds] = local_heatmap / ct_div + sample['ttf_heatmap'] = heatmap + sample['ttf_box_target'] = box_target + sample['ttf_reg_weight'] = reg_weight + return samples + + def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius): + h, w = 2 * h_radius + 1, 2 * w_radius + 1 + sigma_x = w / 6 + sigma_y = h / 6 + gaussian = gaussian2D((h, w), sigma_x, sigma_y) + + x, y = int(center[0]), int(center[1]) + + height, width = heatmap.shape[0:2] + + left, right = min(x, w_radius), min(width - x, w_radius + 1) + top, bottom = min(y, h_radius), min(height - y, h_radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius - + left:w_radius + right] + if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: + heatmap[y - top:y + bottom, x - left:x + right] = np.maximum( + masked_heatmap, masked_gaussian) + return heatmap + + +@register_op +class Gt2Solov2Target(BaseOperator): + """Assign mask target and labels in SOLOv2 network. + Args: + num_grids (list): The list of feature map grids size. + scale_ranges (list): The list of mask boundary range. + coord_sigma (float): The coefficient of coordinate area length. + sampling_ratio (float): The ratio of down sampling. 
+ """ + + def __init__(self, + num_grids=[40, 36, 24, 16, 12], + scale_ranges=[[1, 96], [48, 192], [96, 384], [192, 768], + [384, 2048]], + coord_sigma=0.2, + sampling_ratio=4.0): + super(Gt2Solov2Target, self).__init__() + self.num_grids = num_grids + self.scale_ranges = scale_ranges + self.coord_sigma = coord_sigma + self.sampling_ratio = sampling_ratio + + def _scale_size(self, im, scale): + h, w = im.shape[:2] + new_size = (int(w * float(scale) + 0.5), int(h * float(scale) + 0.5)) + resized_img = cv2.resize( + im, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) + return resized_img + + def __call__(self, samples, context=None): + sample_id = 0 + for sample in samples: + gt_bboxes_raw = sample['gt_bbox'] + gt_labels_raw = sample['gt_class'] + im_c, im_h, im_w = sample['image'].shape[:] + gt_masks_raw = sample['gt_segm'].astype(np.uint8) + mask_feat_size = [ + int(im_h / self.sampling_ratio), int(im_w / self.sampling_ratio) + ] + gt_areas = np.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * + (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1])) + ins_ind_label_list = [] + idx = 0 + for (lower_bound, upper_bound), num_grid \ + in zip(self.scale_ranges, self.num_grids): + + hit_indices = ((gt_areas >= lower_bound) & + (gt_areas <= upper_bound)).nonzero()[0] + num_ins = len(hit_indices) + + ins_label = [] + grid_order = [] + cate_label = np.zeros([num_grid, num_grid], dtype=np.int64) + ins_ind_label = np.zeros([num_grid**2], dtype=np.bool) + + if num_ins == 0: + ins_label = np.zeros( + [1, mask_feat_size[0], mask_feat_size[1]], + dtype=np.uint8) + ins_ind_label_list.append(ins_ind_label) + sample['cate_label{}'.format(idx)] = cate_label.flatten() + sample['ins_label{}'.format(idx)] = ins_label + sample['grid_order{}'.format(idx)] = np.asarray( + [sample_id * num_grid * num_grid + 0]) + idx += 1 + continue + gt_bboxes = gt_bboxes_raw[hit_indices] + gt_labels = gt_labels_raw[hit_indices] + gt_masks = gt_masks_raw[hit_indices, ...] 
+ + half_ws = 0.5 * ( + gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.coord_sigma + half_hs = 0.5 * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.coord_sigma + + for seg_mask, gt_label, half_h, half_w in zip( + gt_masks, gt_labels, half_hs, half_ws): + if seg_mask.sum() == 0: + continue + # mass center + upsampled_size = (mask_feat_size[0] * 4, + mask_feat_size[1] * 4) + center_h, center_w = ndimage.measurements.center_of_mass( + seg_mask) + coord_w = int( + (center_w / upsampled_size[1]) // (1. / num_grid)) + coord_h = int( + (center_h / upsampled_size[0]) // (1. / num_grid)) + + # left, top, right, down + top_box = max(0, + int(((center_h - half_h) / upsampled_size[0]) + // (1. / num_grid))) + down_box = min(num_grid - 1, + int(((center_h + half_h) / upsampled_size[0]) + // (1. / num_grid))) + left_box = max(0, + int(((center_w - half_w) / upsampled_size[1]) + // (1. / num_grid))) + right_box = min(num_grid - 1, + int(((center_w + half_w) / + upsampled_size[1]) // (1. / num_grid))) + + top = max(top_box, coord_h - 1) + down = min(down_box, coord_h + 1) + left = max(coord_w - 1, left_box) + right = min(right_box, coord_w + 1) + + cate_label[top:(down + 1), left:(right + 1)] = gt_label + seg_mask = self._scale_size( + seg_mask, scale=1. 
/ self.sampling_ratio) + for i in range(top, down + 1): + for j in range(left, right + 1): + label = int(i * num_grid + j) + cur_ins_label = np.zeros( + [mask_feat_size[0], mask_feat_size[1]], + dtype=np.uint8) + cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[ + 1]] = seg_mask + ins_label.append(cur_ins_label) + ins_ind_label[label] = True + grid_order.append( + [sample_id * num_grid * num_grid + label]) + if ins_label == []: + ins_label = np.zeros( + [1, mask_feat_size[0], mask_feat_size[1]], + dtype=np.uint8) + ins_ind_label_list.append(ins_ind_label) + sample['cate_label{}'.format(idx)] = cate_label.flatten() + sample['ins_label{}'.format(idx)] = ins_label + sample['grid_order{}'.format(idx)] = np.asarray( + [sample_id * num_grid * num_grid + 0]) + else: + ins_label = np.stack(ins_label, axis=0) + ins_ind_label_list.append(ins_ind_label) + sample['cate_label{}'.format(idx)] = cate_label.flatten() + sample['ins_label{}'.format(idx)] = ins_label + sample['grid_order{}'.format(idx)] = np.asarray(grid_order) + assert len(grid_order) > 0 + idx += 1 + ins_ind_labels = np.concatenate([ + ins_ind_labels_level_img + for ins_ind_labels_level_img in ins_ind_label_list + ]) + fg_num = np.sum(ins_ind_labels) + sample['fg_num'] = fg_num + sample_id += 1 + + return samples diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/transform/gridmask_utils.py b/VisualFL/depends/PaddleDetection/ppdet/data/transform/gridmask_utils.py new file mode 100644 index 000000000..a23e69b20 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/transform/gridmask_utils.py @@ -0,0 +1,83 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np +from PIL import Image + + +class GridMask(object): + def __init__(self, + use_h=True, + use_w=True, + rotate=1, + offset=False, + ratio=0.5, + mode=1, + prob=0.7, + upper_iter=360000): + super(GridMask, self).__init__() + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.prob = prob + self.st_prob = prob + self.upper_iter = upper_iter + + def __call__(self, x, curr_iter): + self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) + if np.random.rand() > self.prob: + return x + _, h, w = x.shape + hh = int(1.5 * h) + ww = int(1.5 * w) + d = np.random.randint(2, h) + self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1) + mask = np.ones((hh, ww), np.float32) + st_h = np.random.randint(d) + st_w = np.random.randint(d) + if self.use_h: + for i in range(hh // d): + s = d * i + st_h + t = min(s + self.l, hh) + mask[s:t, :] *= 0 + if self.use_w: + for i in range(ww // d): + s = d * i + st_w + t = min(s + self.l, ww) + mask[:, s:t] *= 0 + + r = np.random.randint(self.rotate) + mask = Image.fromarray(np.uint8(mask)) + mask = mask.rotate(r) + mask = np.asarray(mask) + mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (ww - w) // 2:(ww - w) // 2 + + w].astype(np.float32) + + if self.mode == 1: + mask = 1 - mask + mask = np.expand_dims(mask, axis=0) + if self.offset: + offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32) + x = (x * 
mask + offset * (1 - mask)).astype(x.dtype) + else: + x = (x * mask).astype(x.dtype) + + return x diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/transform/op_helper.py b/VisualFL/depends/PaddleDetection/ppdet/data/transform/op_helper.py new file mode 100644 index 000000000..02d219546 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/transform/op_helper.py @@ -0,0 +1,464 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# this file contains helper methods for BBOX processing + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import random +import math +import cv2 + + +def meet_emit_constraint(src_bbox, sample_bbox): + center_x = (src_bbox[2] + src_bbox[0]) / 2 + center_y = (src_bbox[3] + src_bbox[1]) / 2 + if center_x >= sample_bbox[0] and \ + center_x <= sample_bbox[2] and \ + center_y >= sample_bbox[1] and \ + center_y <= sample_bbox[3]: + return True + return False + + +def clip_bbox(src_bbox): + src_bbox[0] = max(min(src_bbox[0], 1.0), 0.0) + src_bbox[1] = max(min(src_bbox[1], 1.0), 0.0) + src_bbox[2] = max(min(src_bbox[2], 1.0), 0.0) + src_bbox[3] = max(min(src_bbox[3], 1.0), 0.0) + return src_bbox + + +def bbox_area(src_bbox): + if src_bbox[2] < src_bbox[0] or src_bbox[3] < src_bbox[1]: + return 0. 
+ else: + width = src_bbox[2] - src_bbox[0] + height = src_bbox[3] - src_bbox[1] + return width * height + + +def is_overlap(object_bbox, sample_bbox): + if object_bbox[0] >= sample_bbox[2] or \ + object_bbox[2] <= sample_bbox[0] or \ + object_bbox[1] >= sample_bbox[3] or \ + object_bbox[3] <= sample_bbox[1]: + return False + else: + return True + + +def filter_and_process(sample_bbox, bboxes, labels, scores=None, + keypoints=None): + new_bboxes = [] + new_labels = [] + new_scores = [] + new_keypoints = [] + new_kp_ignore = [] + for i in range(len(bboxes)): + new_bbox = [0, 0, 0, 0] + obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]] + if not meet_emit_constraint(obj_bbox, sample_bbox): + continue + if not is_overlap(obj_bbox, sample_bbox): + continue + sample_width = sample_bbox[2] - sample_bbox[0] + sample_height = sample_bbox[3] - sample_bbox[1] + new_bbox[0] = (obj_bbox[0] - sample_bbox[0]) / sample_width + new_bbox[1] = (obj_bbox[1] - sample_bbox[1]) / sample_height + new_bbox[2] = (obj_bbox[2] - sample_bbox[0]) / sample_width + new_bbox[3] = (obj_bbox[3] - sample_bbox[1]) / sample_height + new_bbox = clip_bbox(new_bbox) + if bbox_area(new_bbox) > 0: + new_bboxes.append(new_bbox) + new_labels.append([labels[i][0]]) + if scores is not None: + new_scores.append([scores[i][0]]) + if keypoints is not None: + sample_keypoint = keypoints[0][i] + for j in range(len(sample_keypoint)): + kp_len = sample_height if j % 2 else sample_width + sample_coord = sample_bbox[1] if j % 2 else sample_bbox[0] + sample_keypoint[j] = ( + sample_keypoint[j] - sample_coord) / kp_len + sample_keypoint[j] = max(min(sample_keypoint[j], 1.0), 0.0) + new_keypoints.append(sample_keypoint) + new_kp_ignore.append(keypoints[1][i]) + + bboxes = np.array(new_bboxes) + labels = np.array(new_labels) + scores = np.array(new_scores) + if keypoints is not None: + keypoints = np.array(new_keypoints) + new_kp_ignore = np.array(new_kp_ignore) + return bboxes, labels, scores, (keypoints, 
new_kp_ignore) + return bboxes, labels, scores + + +def bbox_area_sampling(bboxes, labels, scores, target_size, min_size): + new_bboxes = [] + new_labels = [] + new_scores = [] + for i, bbox in enumerate(bboxes): + w = float((bbox[2] - bbox[0]) * target_size) + h = float((bbox[3] - bbox[1]) * target_size) + if w * h < float(min_size * min_size): + continue + else: + new_bboxes.append(bbox) + new_labels.append(labels[i]) + if scores is not None and scores.size != 0: + new_scores.append(scores[i]) + bboxes = np.array(new_bboxes) + labels = np.array(new_labels) + scores = np.array(new_scores) + return bboxes, labels, scores + + +def generate_sample_bbox(sampler): + scale = np.random.uniform(sampler[2], sampler[3]) + aspect_ratio = np.random.uniform(sampler[4], sampler[5]) + aspect_ratio = max(aspect_ratio, (scale**2.0)) + aspect_ratio = min(aspect_ratio, 1 / (scale**2.0)) + bbox_width = scale * (aspect_ratio**0.5) + bbox_height = scale / (aspect_ratio**0.5) + xmin_bound = 1 - bbox_width + ymin_bound = 1 - bbox_height + xmin = np.random.uniform(0, xmin_bound) + ymin = np.random.uniform(0, ymin_bound) + xmax = xmin + bbox_width + ymax = ymin + bbox_height + sampled_bbox = [xmin, ymin, xmax, ymax] + return sampled_bbox + + +def generate_sample_bbox_square(sampler, image_width, image_height): + scale = np.random.uniform(sampler[2], sampler[3]) + aspect_ratio = np.random.uniform(sampler[4], sampler[5]) + aspect_ratio = max(aspect_ratio, (scale**2.0)) + aspect_ratio = min(aspect_ratio, 1 / (scale**2.0)) + bbox_width = scale * (aspect_ratio**0.5) + bbox_height = scale / (aspect_ratio**0.5) + if image_height < image_width: + bbox_width = bbox_height * image_height / image_width + else: + bbox_height = bbox_width * image_width / image_height + xmin_bound = 1 - bbox_width + ymin_bound = 1 - bbox_height + xmin = np.random.uniform(0, xmin_bound) + ymin = np.random.uniform(0, ymin_bound) + xmax = xmin + bbox_width + ymax = ymin + bbox_height + sampled_bbox = [xmin, ymin, xmax, 
ymax] + return sampled_bbox + + +def data_anchor_sampling(bbox_labels, image_width, image_height, scale_array, + resize_width): + num_gt = len(bbox_labels) + # np.random.randint range: [low, high) + rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0 + + if num_gt != 0: + norm_xmin = bbox_labels[rand_idx][0] + norm_ymin = bbox_labels[rand_idx][1] + norm_xmax = bbox_labels[rand_idx][2] + norm_ymax = bbox_labels[rand_idx][3] + + xmin = norm_xmin * image_width + ymin = norm_ymin * image_height + wid = image_width * (norm_xmax - norm_xmin) + hei = image_height * (norm_ymax - norm_ymin) + range_size = 0 + + area = wid * hei + for scale_ind in range(0, len(scale_array) - 1): + if area > scale_array[scale_ind] ** 2 and area < \ + scale_array[scale_ind + 1] ** 2: + range_size = scale_ind + 1 + break + + if area > scale_array[len(scale_array) - 2]**2: + range_size = len(scale_array) - 2 + + scale_choose = 0.0 + if range_size == 0: + rand_idx_size = 0 + else: + # np.random.randint range: [low, high) + rng_rand_size = np.random.randint(0, range_size + 1) + rand_idx_size = rng_rand_size % (range_size + 1) + + if rand_idx_size == range_size: + min_resize_val = scale_array[rand_idx_size] / 2.0 + max_resize_val = min(2.0 * scale_array[rand_idx_size], + 2 * math.sqrt(wid * hei)) + scale_choose = random.uniform(min_resize_val, max_resize_val) + else: + min_resize_val = scale_array[rand_idx_size] / 2.0 + max_resize_val = 2.0 * scale_array[rand_idx_size] + scale_choose = random.uniform(min_resize_val, max_resize_val) + + sample_bbox_size = wid * resize_width / scale_choose + + w_off_orig = 0.0 + h_off_orig = 0.0 + if sample_bbox_size < max(image_height, image_width): + if wid <= sample_bbox_size: + w_off_orig = np.random.uniform(xmin + wid - sample_bbox_size, + xmin) + else: + w_off_orig = np.random.uniform(xmin, + xmin + wid - sample_bbox_size) + + if hei <= sample_bbox_size: + h_off_orig = np.random.uniform(ymin + hei - sample_bbox_size, + ymin) + else: + h_off_orig = 
np.random.uniform(ymin, + ymin + hei - sample_bbox_size) + + else: + w_off_orig = np.random.uniform(image_width - sample_bbox_size, 0.0) + h_off_orig = np.random.uniform(image_height - sample_bbox_size, 0.0) + + w_off_orig = math.floor(w_off_orig) + h_off_orig = math.floor(h_off_orig) + + # Figure out top left coordinates. + w_off = float(w_off_orig / image_width) + h_off = float(h_off_orig / image_height) + + sampled_bbox = [ + w_off, h_off, w_off + float(sample_bbox_size / image_width), + h_off + float(sample_bbox_size / image_height) + ] + return sampled_bbox + else: + return 0 + + +def jaccard_overlap(sample_bbox, object_bbox): + if sample_bbox[0] >= object_bbox[2] or \ + sample_bbox[2] <= object_bbox[0] or \ + sample_bbox[1] >= object_bbox[3] or \ + sample_bbox[3] <= object_bbox[1]: + return 0 + intersect_xmin = max(sample_bbox[0], object_bbox[0]) + intersect_ymin = max(sample_bbox[1], object_bbox[1]) + intersect_xmax = min(sample_bbox[2], object_bbox[2]) + intersect_ymax = min(sample_bbox[3], object_bbox[3]) + intersect_size = (intersect_xmax - intersect_xmin) * ( + intersect_ymax - intersect_ymin) + sample_bbox_size = bbox_area(sample_bbox) + object_bbox_size = bbox_area(object_bbox) + overlap = intersect_size / ( + sample_bbox_size + object_bbox_size - intersect_size) + return overlap + + +def intersect_bbox(bbox1, bbox2): + if bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0] or \ + bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1]: + intersection_box = [0.0, 0.0, 0.0, 0.0] + else: + intersection_box = [ + max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1]), + min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3]) + ] + return intersection_box + + +def bbox_coverage(bbox1, bbox2): + inter_box = intersect_bbox(bbox1, bbox2) + intersect_size = bbox_area(inter_box) + + if intersect_size > 0: + bbox1_size = bbox_area(bbox1) + return intersect_size / bbox1_size + else: + return 0. 
+ + +def satisfy_sample_constraint(sampler, + sample_bbox, + gt_bboxes, + satisfy_all=False): + if sampler[6] == 0 and sampler[7] == 0: + return True + satisfied = [] + for i in range(len(gt_bboxes)): + object_bbox = [ + gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3] + ] + overlap = jaccard_overlap(sample_bbox, object_bbox) + if sampler[6] != 0 and \ + overlap < sampler[6]: + satisfied.append(False) + continue + if sampler[7] != 0 and \ + overlap > sampler[7]: + satisfied.append(False) + continue + satisfied.append(True) + if not satisfy_all: + return True + + if satisfy_all: + return np.all(satisfied) + else: + return False + + +def satisfy_sample_constraint_coverage(sampler, sample_bbox, gt_bboxes): + if sampler[6] == 0 and sampler[7] == 0: + has_jaccard_overlap = False + else: + has_jaccard_overlap = True + if sampler[8] == 0 and sampler[9] == 0: + has_object_coverage = False + else: + has_object_coverage = True + + if not has_jaccard_overlap and not has_object_coverage: + return True + found = False + for i in range(len(gt_bboxes)): + object_bbox = [ + gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3] + ] + if has_jaccard_overlap: + overlap = jaccard_overlap(sample_bbox, object_bbox) + if sampler[6] != 0 and \ + overlap < sampler[6]: + continue + if sampler[7] != 0 and \ + overlap > sampler[7]: + continue + found = True + if has_object_coverage: + object_coverage = bbox_coverage(object_bbox, sample_bbox) + if sampler[8] != 0 and \ + object_coverage < sampler[8]: + continue + if sampler[9] != 0 and \ + object_coverage > sampler[9]: + continue + found = True + if found: + return True + return found + + +def crop_image_sampling(img, sample_bbox, image_width, image_height, + target_size): + # no clipping here + xmin = int(sample_bbox[0] * image_width) + xmax = int(sample_bbox[2] * image_width) + ymin = int(sample_bbox[1] * image_height) + ymax = int(sample_bbox[3] * image_height) + + w_off = xmin + h_off = ymin + width = xmax - 
xmin + height = ymax - ymin + cross_xmin = max(0.0, float(w_off)) + cross_ymin = max(0.0, float(h_off)) + cross_xmax = min(float(w_off + width - 1.0), float(image_width)) + cross_ymax = min(float(h_off + height - 1.0), float(image_height)) + cross_width = cross_xmax - cross_xmin + cross_height = cross_ymax - cross_ymin + + roi_xmin = 0 if w_off >= 0 else abs(w_off) + roi_ymin = 0 if h_off >= 0 else abs(h_off) + roi_width = cross_width + roi_height = cross_height + + roi_y1 = int(roi_ymin) + roi_y2 = int(roi_ymin + roi_height) + roi_x1 = int(roi_xmin) + roi_x2 = int(roi_xmin + roi_width) + + cross_y1 = int(cross_ymin) + cross_y2 = int(cross_ymin + cross_height) + cross_x1 = int(cross_xmin) + cross_x2 = int(cross_xmin + cross_width) + + sample_img = np.zeros((height, width, 3)) + sample_img[roi_y1: roi_y2, roi_x1: roi_x2] = \ + img[cross_y1: cross_y2, cross_x1: cross_x2] + + sample_img = cv2.resize( + sample_img, (target_size, target_size), interpolation=cv2.INTER_AREA) + + return sample_img + + +def is_poly(segm): + assert isinstance(segm, (list, dict)), \ + "Invalid segm type: {}".format(type(segm)) + return isinstance(segm, list) + + +def gaussian_radius(bbox_size, min_overlap): + height, width = bbox_size + + a1 = 1 + b1 = (height + width) + c1 = width * height * (1 - min_overlap) / (1 + min_overlap) + sq1 = np.sqrt(b1**2 - 4 * a1 * c1) + radius1 = (b1 - sq1) / (2 * a1) + + a2 = 4 + b2 = 2 * (height + width) + c2 = (1 - min_overlap) * width * height + sq2 = np.sqrt(b2**2 - 4 * a2 * c2) + radius2 = (b2 - sq2) / (2 * a2) + + a3 = 4 * min_overlap + b3 = -2 * min_overlap * (height + width) + c3 = (min_overlap - 1) * width * height + sq3 = np.sqrt(b3**2 - 4 * a3 * c3) + radius3 = (b3 + sq3) / (2 * a3) + return min(radius1, radius2, radius3) + + +def draw_gaussian(heatmap, center, radius, k=1, delte=6): + diameter = 2 * radius + 1 + sigma = diameter / delte + gaussian = gaussian2D((diameter, diameter), sigma_x=sigma, sigma_y=sigma) + + x, y = center + + height, width = 
heatmap.shape[0:2] + + left, right = min(x, radius), min(width - x, radius + 1) + top, bottom = min(y, radius), min(height - y, radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian[radius - top:radius + bottom, radius - left: + radius + right] + np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) + + +def gaussian2D(shape, sigma_x=1, sigma_y=1): + m, n = [(ss - 1.) / 2. for ss in shape] + y, x = np.ogrid[-m:m + 1, -n:n + 1] + + h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y * + sigma_y))) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + return h diff --git a/VisualFL/depends/PaddleDetection/ppdet/data/transform/operators.py b/VisualFL/depends/PaddleDetection/ppdet/data/transform/operators.py new file mode 100644 index 000000000..4646e2582 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/data/transform/operators.py @@ -0,0 +1,2679 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# function: +# operators to process sample, +# eg: decode/resize/crop image + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence + +from numbers import Number + +import uuid +import logging +import random +import math +import numpy as np +import os +import six + +import cv2 +from PIL import Image, ImageEnhance, ImageDraw, ImageOps + +from ppdet.core.workspace import serializable +from ppdet.modeling.ops import AnchorGrid + +from .op_helper import (satisfy_sample_constraint, filter_and_process, + generate_sample_bbox, clip_bbox, data_anchor_sampling, + satisfy_sample_constraint_coverage, crop_image_sampling, + generate_sample_bbox_square, bbox_area_sampling, + is_poly, gaussian_radius, draw_gaussian) + +logger = logging.getLogger(__name__) + +registered_ops = [] + + +def register_op(cls): + registered_ops.append(cls.__name__) + if not hasattr(BaseOperator, cls.__name__): + setattr(BaseOperator, cls.__name__, cls) + else: + raise KeyError("The {} class has been registered.".format(cls.__name__)) + return serializable(cls) + + +class BboxError(ValueError): + pass + + +class ImageError(ValueError): + pass + + +class BaseOperator(object): + def __init__(self, name=None): + if name is None: + name = self.__class__.__name__ + self._id = name + '_' + str(uuid.uuid4())[-6:] + + def __call__(self, sample, context=None): + """ Process a sample. + Args: + sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx} + context (dict): info about this sample processing + Returns: + result (dict): a processed sample + """ + return sample + + def __str__(self): + return str(self._id) + + +@register_op +class DecodeImage(BaseOperator): + def __init__(self, to_rgb=True, with_mixup=False, with_cutmix=False): + """ Transform the image data to numpy format. 
+ Args: + to_rgb (bool): whether to convert BGR to RGB + with_mixup (bool): whether or not to mixup image and gt_bbbox/gt_score + with_cutmix (bool): whether or not to cutmix image and gt_bbbox/gt_score + """ + + super(DecodeImage, self).__init__() + self.to_rgb = to_rgb + self.with_mixup = with_mixup + self.with_cutmix = with_cutmix + if not isinstance(self.to_rgb, bool): + raise TypeError("{}: input type is invalid.".format(self)) + if not isinstance(self.with_mixup, bool): + raise TypeError("{}: input type is invalid.".format(self)) + + def __call__(self, sample, context=None): + """ load image if 'im_file' field is not empty but 'image' is""" + if 'image' not in sample: + with open(sample['im_file'], 'rb') as f: + sample['image'] = f.read() + + im = sample['image'] + data = np.frombuffer(im, dtype='uint8') + im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode + + if self.to_rgb: + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + sample['image'] = im + + if 'h' not in sample: + sample['h'] = im.shape[0] + elif sample['h'] != im.shape[0]: + logger.warn( + "The actual image height: {} is not equal to the " + "height: {} in annotation, and update sample['h'] by actual " + "image height.".format(im.shape[0], sample['h'])) + sample['h'] = im.shape[0] + if 'w' not in sample: + sample['w'] = im.shape[1] + elif sample['w'] != im.shape[1]: + logger.warn( + "The actual image width: {} is not equal to the " + "width: {} in annotation, and update sample['w'] by actual " + "image width.".format(im.shape[1], sample['w'])) + sample['w'] = im.shape[1] + + # make default im_info with [h, w, 1] + sample['im_info'] = np.array( + [im.shape[0], im.shape[1], 1.], dtype=np.float32) + + # decode mixup image + if self.with_mixup and 'mixup' in sample: + self.__call__(sample['mixup'], context) + + # decode cutmix image + if self.with_cutmix and 'cutmix' in sample: + self.__call__(sample['cutmix'], context) + + # decode semantic label + if 'semantic' in sample.keys() and sample['semantic'] 
is not None: + sem_file = sample['semantic'] + sem = cv2.imread(sem_file, cv2.IMREAD_GRAYSCALE) + sample['semantic'] = sem.astype('int32') + + return sample + + +@register_op +class MultiscaleTestResize(BaseOperator): + def __init__(self, + origin_target_size=800, + origin_max_size=1333, + target_size=[], + max_size=2000, + interp=cv2.INTER_LINEAR, + use_flip=True): + """ + Rescale image to the each size in target size, and capped at max_size. + Args: + origin_target_size(int): original target size of image's short side. + origin_max_size(int): original max size of image. + target_size (list): A list of target sizes of image's short side. + max_size (int): the max size of image. + interp (int): the interpolation method. + use_flip (bool): whether use flip augmentation. + """ + super(MultiscaleTestResize, self).__init__() + self.origin_target_size = int(origin_target_size) + self.origin_max_size = int(origin_max_size) + self.max_size = int(max_size) + self.interp = int(interp) + self.use_flip = use_flip + + if not isinstance(target_size, list): + raise TypeError( + "Type of target_size is invalid. Must be List, now is {}". + format(type(target_size))) + self.target_size = target_size + if not (isinstance(self.origin_target_size, int) and isinstance( + self.origin_max_size, int) and isinstance(self.max_size, int) + and isinstance(self.interp, int)): + raise TypeError("{}: input type is invalid.".format(self)) + + def __call__(self, sample, context=None): + """ Resize the image numpy for multi-scale test. 
+ """ + origin_ims = {} + im = sample['image'] + if not isinstance(im, np.ndarray): + raise TypeError("{}: image type is not numpy.".format(self)) + if len(im.shape) != 3: + raise ImageError('{}: image is not 3-dimensional.'.format(self)) + im_shape = im.shape + im_size_min = np.min(im_shape[0:2]) + im_size_max = np.max(im_shape[0:2]) + if float(im_size_min) == 0: + raise ZeroDivisionError('{}: min size of image is 0'.format(self)) + base_name_list = ['image'] + origin_ims['image'] = im + if self.use_flip: + sample['image_flip'] = im[:, ::-1, :] + base_name_list.append('image_flip') + origin_ims['image_flip'] = sample['image_flip'] + + for base_name in base_name_list: + im_scale = float(self.origin_target_size) / float(im_size_min) + # Prevent the biggest axis from being more than max_size + if np.round(im_scale * im_size_max) > self.origin_max_size: + im_scale = float(self.origin_max_size) / float(im_size_max) + im_scale_x = im_scale + im_scale_y = im_scale + + resize_w = np.round(im_scale_x * float(im_shape[1])) + resize_h = np.round(im_scale_y * float(im_shape[0])) + im_resize = cv2.resize( + origin_ims[base_name], + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + + sample[base_name] = im_resize + info_name = 'im_info' if base_name == 'image' else 'im_info_image_flip' + sample[base_name] = im_resize + sample[info_name] = np.array( + [resize_h, resize_w, im_scale], dtype=np.float32) + for i, size in enumerate(self.target_size): + im_scale = float(size) / float(im_size_min) + if np.round(im_scale * im_size_max) > self.max_size: + im_scale = float(self.max_size) / float(im_size_max) + im_scale_x = im_scale + im_scale_y = im_scale + resize_w = np.round(im_scale_x * float(im_shape[1])) + resize_h = np.round(im_scale_y * float(im_shape[0])) + im_resize = cv2.resize( + origin_ims[base_name], + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + + im_info = [resize_h, resize_w, im_scale] + # hard-code here, must 
be consistent with + # ppdet/modeling/architectures/input_helper.py + name = base_name + '_scale_' + str(i) + info_name = 'im_info_' + name + sample[name] = im_resize + sample[info_name] = np.array( + [resize_h, resize_w, im_scale], dtype=np.float32) + return sample + + +@register_op +class ResizeImage(BaseOperator): + def __init__(self, + target_size=0, + max_size=0, + interp=cv2.INTER_LINEAR, + use_cv2=True, + resize_box=False): + """ + Rescale image to the specified target size, and capped at max_size + if max_size != 0. + If target_size is list, selected a scale randomly as the specified + target size. + Args: + target_size (int|list): the target size of image's short side, + multi-scale training is adopted when type is list. + max_size (int): the max size of image + interp (int): the interpolation method + use_cv2 (bool): use the cv2 interpolation method or use PIL + interpolation method + resize_box (bool): whether resize ground truth bbox annotations. + """ + super(ResizeImage, self).__init__() + self.max_size = int(max_size) + self.interp = int(interp) + self.use_cv2 = use_cv2 + self.resize_box = resize_box + if not (isinstance(target_size, int) or isinstance(target_size, list)): + raise TypeError( + "Type of target_size is invalid. Must be Integer or List, now is {}". + format(type(target_size))) + self.target_size = target_size + if not (isinstance(self.max_size, int) and isinstance(self.interp, + int)): + raise TypeError("{}: input type is invalid.".format(self)) + + def __call__(self, sample, context=None): + """ Resize the image numpy. 
+ """ + im = sample['image'] + if not isinstance(im, np.ndarray): + raise TypeError("{}: image type is not numpy.".format(self)) + if len(im.shape) != 3: + raise ImageError('{}: image is not 3-dimensional.'.format(self)) + im_shape = im.shape + im_size_min = np.min(im_shape[0:2]) + im_size_max = np.max(im_shape[0:2]) + if isinstance(self.target_size, list): + # Case for multi-scale training + selected_size = random.choice(self.target_size) + else: + selected_size = self.target_size + if float(im_size_min) == 0: + raise ZeroDivisionError('{}: min size of image is 0'.format(self)) + if self.max_size != 0: + im_scale = float(selected_size) / float(im_size_min) + # Prevent the biggest axis from being more than max_size + if np.round(im_scale * im_size_max) > self.max_size: + im_scale = float(self.max_size) / float(im_size_max) + im_scale_x = im_scale + im_scale_y = im_scale + + resize_w = im_scale_x * float(im_shape[1]) + resize_h = im_scale_y * float(im_shape[0]) + im_info = [resize_h, resize_w, im_scale] + if 'im_info' in sample and sample['im_info'][2] != 1.: + sample['im_info'] = np.append( + list(sample['im_info']), im_info).astype(np.float32) + else: + sample['im_info'] = np.array(im_info).astype(np.float32) + else: + im_scale_x = float(selected_size) / float(im_shape[1]) + im_scale_y = float(selected_size) / float(im_shape[0]) + + resize_w = selected_size + resize_h = selected_size + + if self.use_cv2: + im = cv2.resize( + im, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + else: + if self.max_size != 0: + raise TypeError( + 'If you set max_size to cap the maximum size of image,' + 'please set use_cv2 to True to resize the image.') + im = im.astype('uint8') + im = Image.fromarray(im) + im = im.resize((int(resize_w), int(resize_h)), self.interp) + im = np.array(im) + sample['image'] = im + sample['scale_factor'] = [im_scale_x, im_scale_y] * 2 + if 'gt_bbox' in sample and self.resize_box and len(sample[ + 'gt_bbox']) > 0: + bboxes 
= sample['gt_bbox'] * sample['scale_factor'] + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, resize_w - 1) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, resize_h - 1) + sample['gt_bbox'] = bboxes + if 'semantic' in sample.keys() and sample['semantic'] is not None: + semantic = sample['semantic'] + semantic = cv2.resize( + semantic.astype('float32'), + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + semantic = np.asarray(semantic).astype('int32') + semantic = np.expand_dims(semantic, 0) + sample['semantic'] = semantic + if 'gt_segm' in sample and len(sample['gt_segm']) > 0: + masks = [ + cv2.resize( + gt_segm, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=cv2.INTER_NEAREST) + for gt_segm in sample['gt_segm'] + ] + sample['gt_segm'] = np.asarray(masks).astype(np.uint8) + + return sample + + +@register_op +class RandomFlipImage(BaseOperator): + def __init__(self, prob=0.5, is_normalized=False, is_mask_flip=False): + """ + Args: + prob (float): the probability of flipping image + is_normalized (bool): whether the bbox scale to [0,1] + is_mask_flip (bool): whether flip the segmentation + """ + super(RandomFlipImage, self).__init__() + self.prob = prob + self.is_normalized = is_normalized + self.is_mask_flip = is_mask_flip + if not (isinstance(self.prob, float) and + isinstance(self.is_normalized, bool) and + isinstance(self.is_mask_flip, bool)): + raise TypeError("{}: input type is invalid.".format(self)) + + def flip_segms(self, segms, height, width): + def _flip_poly(poly, width): + flipped_poly = np.array(poly) + flipped_poly[0::2] = width - np.array(poly[0::2]) - 1 + return flipped_poly.tolist() + + def _flip_rle(rle, height, width): + if 'counts' in rle and type(rle['counts']) == list: + rle = mask_util.frPyObjects(rle, height, width) + mask = mask_util.decode(rle) + mask = mask[:, ::-1] + rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8)) + return rle + + flipped_segms = [] + for segm in segms: + 
if is_poly(segm): + # Polygon format + flipped_segms.append([_flip_poly(poly, width) for poly in segm]) + else: + # RLE format + import pycocotools.mask as mask_util + flipped_segms.append(_flip_rle(segm, height, width)) + return flipped_segms + + def flip_keypoint(self, gt_keypoint, width): + for i in range(gt_keypoint.shape[1]): + if i % 2 == 0: + old_x = gt_keypoint[:, i].copy() + if self.is_normalized: + gt_keypoint[:, i] = 1 - old_x + else: + gt_keypoint[:, i] = width - old_x - 1 + return gt_keypoint + + def __call__(self, sample, context=None): + """Filp the image and bounding box. + Operators: + 1. Flip the image numpy. + 2. Transform the bboxes' x coordinates. + (Must judge whether the coordinates are normalized!) + 3. Transform the segmentations' x coordinates. + (Must judge whether the coordinates are normalized!) + Output: + sample: the image, bounding box and segmentation part + in sample are flipped. + """ + + samples = sample + batch_input = True + if not isinstance(samples, Sequence): + batch_input = False + samples = [samples] + for sample in samples: + gt_bbox = sample['gt_bbox'] + im = sample['image'] + if not isinstance(im, np.ndarray): + raise TypeError("{}: image is not a numpy array.".format(self)) + if len(im.shape) != 3: + raise ImageError("{}: image is not 3-dimensional.".format(self)) + height, width, _ = im.shape + if np.random.uniform(0, 1) < self.prob: + im = im[:, ::-1, :] + if gt_bbox.shape[0] == 0: + return sample + oldx1 = gt_bbox[:, 0].copy() + oldx2 = gt_bbox[:, 2].copy() + if self.is_normalized: + gt_bbox[:, 0] = 1 - oldx2 + gt_bbox[:, 2] = 1 - oldx1 + else: + gt_bbox[:, 0] = width - oldx2 - 1 + gt_bbox[:, 2] = width - oldx1 - 1 + if gt_bbox.shape[0] != 0 and ( + gt_bbox[:, 2] < gt_bbox[:, 0]).all(): + m = "{}: invalid box, x2 should be greater than x1".format( + self) + raise BboxError(m) + sample['gt_bbox'] = gt_bbox + if self.is_mask_flip and len(sample['gt_poly']) != 0: + sample['gt_poly'] = self.flip_segms(sample['gt_poly'], 
+ height, width) + if 'gt_keypoint' in sample.keys(): + sample['gt_keypoint'] = self.flip_keypoint( + sample['gt_keypoint'], width) + + if 'semantic' in sample.keys() and sample[ + 'semantic'] is not None: + sample['semantic'] = sample['semantic'][:, ::-1] + + if 'gt_segm' in sample.keys() and sample['gt_segm'] is not None: + sample['gt_segm'] = sample['gt_segm'][:, :, ::-1] + + sample['flipped'] = True + sample['image'] = im + sample = samples if batch_input else samples[0] + return sample + + +@register_op +class RandomErasingImage(BaseOperator): + def __init__(self, prob=0.5, sl=0.02, sh=0.4, r1=0.3): + """ + Random Erasing Data Augmentation, see https://arxiv.org/abs/1708.04896 + Args: + prob (float): probability to carry out random erasing + sl (float): lower limit of the erasing area ratio + sh (float): upper limit of the erasing area ratio + r1 (float): aspect ratio of the erasing region + """ + super(RandomErasingImage, self).__init__() + self.prob = prob + self.sl = sl + self.sh = sh + self.r1 = r1 + + def __call__(self, sample, context=None): + samples = sample + batch_input = True + if not isinstance(samples, Sequence): + batch_input = False + samples = [samples] + for sample in samples: + gt_bbox = sample['gt_bbox'] + im = sample['image'] + if not isinstance(im, np.ndarray): + raise TypeError("{}: image is not a numpy array.".format(self)) + if len(im.shape) != 3: + raise ImageError("{}: image is not 3-dimensional.".format(self)) + + for idx in range(gt_bbox.shape[0]): + if self.prob <= np.random.rand(): + continue + + x1, y1, x2, y2 = gt_bbox[idx, :] + w_bbox = x2 - x1 + 1 + h_bbox = y2 - y1 + 1 + area = w_bbox * h_bbox + + target_area = random.uniform(self.sl, self.sh) * area + aspect_ratio = random.uniform(self.r1, 1 / self.r1) + + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + + if w < w_bbox and h < h_bbox: + off_y1 = random.randint(0, int(h_bbox - h)) + off_x1 = random.randint(0, 
int(w_bbox - w)) + im[int(y1 + off_y1):int(y1 + off_y1 + h), int(x1 + off_x1): + int(x1 + off_x1 + w), :] = 0 + sample['image'] = im + + sample = samples if batch_input else samples[0] + return sample + + +@register_op +class GridMaskOp(BaseOperator): + def __init__(self, + use_h=True, + use_w=True, + rotate=1, + offset=False, + ratio=0.5, + mode=1, + prob=0.7, + upper_iter=360000): + """ + GridMask Data Augmentation, see https://arxiv.org/abs/2001.04086 + Args: + use_h (bool): whether to mask vertically + use_w (boo;): whether to mask horizontally + rotate (float): angle for the mask to rotate + offset (float): mask offset + ratio (float): mask ratio + mode (int): gridmask mode + prob (float): max probability to carry out gridmask + upper_iter (int): suggested to be equal to global max_iter + """ + super(GridMaskOp, self).__init__() + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.prob = prob + self.upper_iter = upper_iter + + from .gridmask_utils import GridMask + self.gridmask_op = GridMask( + use_h, + use_w, + rotate=rotate, + offset=offset, + ratio=ratio, + mode=mode, + prob=prob, + upper_iter=upper_iter) + + def __call__(self, sample, context=None): + samples = sample + batch_input = True + if not isinstance(samples, Sequence): + batch_input = False + samples = [samples] + for sample in samples: + sample['image'] = self.gridmask_op(sample['image'], + sample['curr_iter']) + if not batch_input: + samples = samples[0] + return sample + + +@register_op +class AutoAugmentImage(BaseOperator): + def __init__(self, is_normalized=False, autoaug_type="v1"): + """ + Args: + is_normalized (bool): whether the bbox scale to [0,1] + autoaug_type (str): autoaug type, support v0, v1, v2, v3, test + """ + super(AutoAugmentImage, self).__init__() + self.is_normalized = is_normalized + self.autoaug_type = autoaug_type + if not isinstance(self.is_normalized, bool): + raise TypeError("{}: input 
type is invalid.".format(self)) + + def __call__(self, sample, context=None): + """ + Learning Data Augmentation Strategies for Object Detection, see https://arxiv.org/abs/1906.11172 + """ + samples = sample + batch_input = True + if not isinstance(samples, Sequence): + batch_input = False + samples = [samples] + for sample in samples: + gt_bbox = sample['gt_bbox'] + im = sample['image'] + if not isinstance(im, np.ndarray): + raise TypeError("{}: image is not a numpy array.".format(self)) + if len(im.shape) != 3: + raise ImageError("{}: image is not 3-dimensional.".format(self)) + if len(gt_bbox) == 0: + continue + + # gt_boxes : [x1, y1, x2, y2] + # norm_gt_boxes: [y1, x1, y2, x2] + height, width, _ = im.shape + norm_gt_bbox = np.ones_like(gt_bbox, dtype=np.float32) + if not self.is_normalized: + norm_gt_bbox[:, 0] = gt_bbox[:, 1] / float(height) + norm_gt_bbox[:, 1] = gt_bbox[:, 0] / float(width) + norm_gt_bbox[:, 2] = gt_bbox[:, 3] / float(height) + norm_gt_bbox[:, 3] = gt_bbox[:, 2] / float(width) + else: + norm_gt_bbox[:, 0] = gt_bbox[:, 1] + norm_gt_bbox[:, 1] = gt_bbox[:, 0] + norm_gt_bbox[:, 2] = gt_bbox[:, 3] + norm_gt_bbox[:, 3] = gt_bbox[:, 2] + + from .autoaugment_utils import distort_image_with_autoaugment + im, norm_gt_bbox = distort_image_with_autoaugment(im, norm_gt_bbox, + self.autoaug_type) + if not self.is_normalized: + gt_bbox[:, 0] = norm_gt_bbox[:, 1] * float(width) + gt_bbox[:, 1] = norm_gt_bbox[:, 0] * float(height) + gt_bbox[:, 2] = norm_gt_bbox[:, 3] * float(width) + gt_bbox[:, 3] = norm_gt_bbox[:, 2] * float(height) + else: + gt_bbox[:, 0] = norm_gt_bbox[:, 1] + gt_bbox[:, 1] = norm_gt_bbox[:, 0] + gt_bbox[:, 2] = norm_gt_bbox[:, 3] + gt_bbox[:, 3] = norm_gt_bbox[:, 2] + + sample['gt_bbox'] = gt_bbox + sample['image'] = im + + sample = samples if batch_input else samples[0] + return sample + + +@register_op +class NormalizeImage(BaseOperator): + def __init__(self, + mean=[0.485, 0.456, 0.406], + std=[1, 1, 1], + is_scale=True, + 
is_channel_first=True): + """ + Args: + mean (list): the pixel mean + std (list): the pixel variance + """ + super(NormalizeImage, self).__init__() + self.mean = mean + self.std = std + self.is_scale = is_scale + self.is_channel_first = is_channel_first + if not (isinstance(self.mean, list) and isinstance(self.std, list) and + isinstance(self.is_scale, bool)): + raise TypeError("{}: input type is invalid.".format(self)) + from functools import reduce + if reduce(lambda x, y: x * y, self.std) == 0: + raise ValueError('{}: std is invalid!'.format(self)) + + def __call__(self, sample, context=None): + """Normalize the image. + Operators: + 1.(optional) Scale the image to [0,1] + 2. Each pixel minus mean and is divided by std + """ + samples = sample + batch_input = True + if not isinstance(samples, Sequence): + batch_input = False + samples = [samples] + for sample in samples: + for k in sample.keys(): + # hard code + if k.startswith('image'): + im = sample[k] + im = im.astype(np.float32, copy=False) + if self.is_channel_first: + mean = np.array(self.mean)[:, np.newaxis, np.newaxis] + std = np.array(self.std)[:, np.newaxis, np.newaxis] + else: + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + if self.is_scale: + im = im / 255.0 + im -= mean + im /= std + sample[k] = im + if not batch_input: + samples = samples[0] + return samples + + +@register_op +class RandomDistort(BaseOperator): + def __init__(self, + brightness_lower=0.5, + brightness_upper=1.5, + contrast_lower=0.5, + contrast_upper=1.5, + saturation_lower=0.5, + saturation_upper=1.5, + hue_lower=-18, + hue_upper=18, + brightness_prob=0.5, + contrast_prob=0.5, + saturation_prob=0.5, + hue_prob=0.5, + count=4, + is_order=False): + """ + Args: + brightness_lower/ brightness_upper (float): the brightness + between brightness_lower and brightness_upper + contrast_lower/ contrast_upper (float): the contrast between + contrast_lower and contrast_lower + 
saturation_lower/ saturation_upper (float): the saturation + between saturation_lower and saturation_upper + hue_lower/ hue_upper (float): the hue between + hue_lower and hue_upper + brightness_prob (float): the probability of changing brightness + contrast_prob (float): the probability of changing contrast + saturation_prob (float): the probability of changing saturation + hue_prob (float): the probability of changing hue + count (int): the kinds of doing distrot + is_order (bool): whether determine the order of distortion + """ + super(RandomDistort, self).__init__() + self.brightness_lower = brightness_lower + self.brightness_upper = brightness_upper + self.contrast_lower = contrast_lower + self.contrast_upper = contrast_upper + self.saturation_lower = saturation_lower + self.saturation_upper = saturation_upper + self.hue_lower = hue_lower + self.hue_upper = hue_upper + self.brightness_prob = brightness_prob + self.contrast_prob = contrast_prob + self.saturation_prob = saturation_prob + self.hue_prob = hue_prob + self.count = count + self.is_order = is_order + + def random_brightness(self, img): + brightness_delta = np.random.uniform(self.brightness_lower, + self.brightness_upper) + prob = np.random.uniform(0, 1) + if prob < self.brightness_prob: + img = ImageEnhance.Brightness(img).enhance(brightness_delta) + return img + + def random_contrast(self, img): + contrast_delta = np.random.uniform(self.contrast_lower, + self.contrast_upper) + prob = np.random.uniform(0, 1) + if prob < self.contrast_prob: + img = ImageEnhance.Contrast(img).enhance(contrast_delta) + return img + + def random_saturation(self, img): + saturation_delta = np.random.uniform(self.saturation_lower, + self.saturation_upper) + prob = np.random.uniform(0, 1) + if prob < self.saturation_prob: + img = ImageEnhance.Color(img).enhance(saturation_delta) + return img + + def random_hue(self, img): + hue_delta = np.random.uniform(self.hue_lower, self.hue_upper) + prob = np.random.uniform(0, 1) + if 
prob < self.hue_prob: + img = np.array(img.convert('HSV')) + img[:, :, 0] = img[:, :, 0] + hue_delta + img = Image.fromarray(img, mode='HSV').convert('RGB') + return img + + def __call__(self, sample, context): + """random distort the image""" + ops = [ + self.random_brightness, self.random_contrast, + self.random_saturation, self.random_hue + ] + if self.is_order: + prob = np.random.uniform(0, 1) + if prob < 0.5: + ops = [ + self.random_brightness, + self.random_saturation, + self.random_hue, + self.random_contrast, + ] + else: + ops = random.sample(ops, self.count) + assert 'image' in sample, "image data not found" + im = sample['image'] + im = Image.fromarray(im) + for id in range(self.count): + im = ops[id](im) + im = np.asarray(im) + sample['image'] = im + return sample + + +@register_op +class ExpandImage(BaseOperator): + def __init__(self, max_ratio, prob, mean=[127.5, 127.5, 127.5]): + """ + Args: + max_ratio (float): the ratio of expanding + prob (float): the probability of expanding image + mean (list): the pixel mean + """ + super(ExpandImage, self).__init__() + self.max_ratio = max_ratio + self.mean = mean + self.prob = prob + + def __call__(self, sample, context): + """ + Expand the image and modify bounding box. + Operators: + 1. Scale the image width and height. + 2. Construct new images with new height and width. + 3. Fill the new image with the mean. + 4. Put original imge into new image. + 5. Rescale the bounding box. + 6. Determine if the new bbox is satisfied in the new image. + Returns: + sample: the image, bounding box are replaced. 
+ """ + + prob = np.random.uniform(0, 1) + assert 'image' in sample, 'not found image data' + im = sample['image'] + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + im_width = sample['w'] + im_height = sample['h'] + if prob < self.prob: + if self.max_ratio - 1 >= 0.01: + expand_ratio = np.random.uniform(1, self.max_ratio) + height = int(im_height * expand_ratio) + width = int(im_width * expand_ratio) + h_off = math.floor(np.random.uniform(0, height - im_height)) + w_off = math.floor(np.random.uniform(0, width - im_width)) + expand_bbox = [ + -w_off / im_width, -h_off / im_height, + (width - w_off) / im_width, (height - h_off) / im_height + ] + expand_im = np.ones((height, width, 3)) + expand_im = np.uint8(expand_im * np.squeeze(self.mean)) + expand_im = Image.fromarray(expand_im) + im = Image.fromarray(im) + expand_im.paste(im, (int(w_off), int(h_off))) + expand_im = np.asarray(expand_im) + if 'gt_keypoint' in sample.keys( + ) and 'keypoint_ignore' in sample.keys(): + keypoints = (sample['gt_keypoint'], + sample['keypoint_ignore']) + gt_bbox, gt_class, _, gt_keypoints = filter_and_process( + expand_bbox, gt_bbox, gt_class, keypoints=keypoints) + sample['gt_keypoint'] = gt_keypoints[0] + sample['keypoint_ignore'] = gt_keypoints[1] + else: + gt_bbox, gt_class, _ = filter_and_process(expand_bbox, + gt_bbox, gt_class) + sample['image'] = expand_im + sample['gt_bbox'] = gt_bbox + sample['gt_class'] = gt_class + sample['w'] = width + sample['h'] = height + + return sample + + +@register_op +class CropImage(BaseOperator): + def __init__(self, batch_sampler, satisfy_all=False, avoid_no_bbox=True): + """ + Args: + batch_sampler (list): Multiple sets of different + parameters for cropping. + satisfy_all (bool): whether all boxes must satisfy. 
+ e.g.[[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0], + [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0], + [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0], + [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0], + [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0], + [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0], + [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]] + [max sample, max trial, min scale, max scale, + min aspect ratio, max aspect ratio, + min overlap, max overlap] + avoid_no_bbox (bool): whether to to avoid the + situation where the box does not appear. + """ + super(CropImage, self).__init__() + self.batch_sampler = batch_sampler + self.satisfy_all = satisfy_all + self.avoid_no_bbox = avoid_no_bbox + + def __call__(self, sample, context): + """ + Crop the image and modify bounding box. + Operators: + 1. Scale the image width and height. + 2. Crop the image according to a radom sample. + 3. Rescale the bounding box. + 4. Determine if the new bbox is satisfied in the new image. + Returns: + sample: the image, bounding box are replaced. + """ + assert 'image' in sample, "image data not found" + im = sample['image'] + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + im_width = sample['w'] + im_height = sample['h'] + gt_score = None + if 'gt_score' in sample: + gt_score = sample['gt_score'] + sampled_bbox = [] + gt_bbox = gt_bbox.tolist() + for sampler in self.batch_sampler: + found = 0 + for i in range(sampler[1]): + if found >= sampler[0]: + break + sample_bbox = generate_sample_bbox(sampler) + if satisfy_sample_constraint(sampler, sample_bbox, gt_bbox, + self.satisfy_all): + sampled_bbox.append(sample_bbox) + found = found + 1 + im = np.array(im) + while sampled_bbox: + idx = int(np.random.uniform(0, len(sampled_bbox))) + sample_bbox = sampled_bbox.pop(idx) + sample_bbox = clip_bbox(sample_bbox) + crop_bbox, crop_class, crop_score = \ + filter_and_process(sample_bbox, gt_bbox, gt_class, scores=gt_score) + if self.avoid_no_bbox: + if len(crop_bbox) < 1: + continue + xmin = int(sample_bbox[0] * im_width) + xmax = 
int(sample_bbox[2] * im_width) + ymin = int(sample_bbox[1] * im_height) + ymax = int(sample_bbox[3] * im_height) + im = im[ymin:ymax, xmin:xmax] + sample['image'] = im + sample['gt_bbox'] = crop_bbox + sample['gt_class'] = crop_class + sample['gt_score'] = crop_score + return sample + return sample + + +@register_op +class CropImageWithDataAchorSampling(BaseOperator): + def __init__(self, + batch_sampler, + anchor_sampler=None, + target_size=None, + das_anchor_scales=[16, 32, 64, 128], + sampling_prob=0.5, + min_size=8., + avoid_no_bbox=True): + """ + Args: + anchor_sampler (list): anchor_sampling sets of different + parameters for cropping. + batch_sampler (list): Multiple sets of different + parameters for cropping. + e.g.[[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]] + [[1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]] + [max sample, max trial, min scale, max scale, + min aspect ratio, max aspect ratio, + min overlap, max overlap, min coverage, max coverage] + target_size (bool): target image size. + das_anchor_scales (list[float]): a list of anchor scales in data + anchor smapling. + min_size (float): minimum size of sampled bbox. + avoid_no_bbox (bool): whether to to avoid the + situation where the box does not appear. + """ + super(CropImageWithDataAchorSampling, self).__init__() + self.anchor_sampler = anchor_sampler + self.batch_sampler = batch_sampler + self.target_size = target_size + self.sampling_prob = sampling_prob + self.min_size = min_size + self.avoid_no_bbox = avoid_no_bbox + self.das_anchor_scales = np.array(das_anchor_scales) + + def __call__(self, sample, context): + """ + Crop the image and modify bounding box. + Operators: + 1. Scale the image width and height. + 2. Crop the image according to a radom sample. + 3. Rescale the bounding box. 
+ 4. Determine if the new bbox is satisfied in the new image. + Returns: + sample: the image, bounding box are replaced. + """ + assert 'image' in sample, "image data not found" + im = sample['image'] + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + image_width = sample['w'] + image_height = sample['h'] + gt_score = None + if 'gt_score' in sample: + gt_score = sample['gt_score'] + sampled_bbox = [] + gt_bbox = gt_bbox.tolist() + + prob = np.random.uniform(0., 1.) + if prob > self.sampling_prob: # anchor sampling + assert self.anchor_sampler + for sampler in self.anchor_sampler: + found = 0 + for i in range(sampler[1]): + if found >= sampler[0]: + break + sample_bbox = data_anchor_sampling( + gt_bbox, image_width, image_height, + self.das_anchor_scales, self.target_size) + if sample_bbox == 0: + break + if satisfy_sample_constraint_coverage(sampler, sample_bbox, + gt_bbox): + sampled_bbox.append(sample_bbox) + found = found + 1 + im = np.array(im) + while sampled_bbox: + idx = int(np.random.uniform(0, len(sampled_bbox))) + sample_bbox = sampled_bbox.pop(idx) + + if 'gt_keypoint' in sample.keys(): + keypoints = (sample['gt_keypoint'], + sample['keypoint_ignore']) + crop_bbox, crop_class, crop_score, gt_keypoints = \ + filter_and_process(sample_bbox, gt_bbox, gt_class, + scores=gt_score, + keypoints=keypoints) + else: + crop_bbox, crop_class, crop_score = filter_and_process( + sample_bbox, gt_bbox, gt_class, scores=gt_score) + crop_bbox, crop_class, crop_score = bbox_area_sampling( + crop_bbox, crop_class, crop_score, self.target_size, + self.min_size) + + if self.avoid_no_bbox: + if len(crop_bbox) < 1: + continue + im = crop_image_sampling(im, sample_bbox, image_width, + image_height, self.target_size) + sample['image'] = im + sample['gt_bbox'] = crop_bbox + sample['gt_class'] = crop_class + sample['gt_score'] = crop_score + if 'gt_keypoint' in sample.keys(): + sample['gt_keypoint'] = gt_keypoints[0] + sample['keypoint_ignore'] = gt_keypoints[1] + 
return sample + return sample + + else: + for sampler in self.batch_sampler: + found = 0 + for i in range(sampler[1]): + if found >= sampler[0]: + break + sample_bbox = generate_sample_bbox_square( + sampler, image_width, image_height) + if satisfy_sample_constraint_coverage(sampler, sample_bbox, + gt_bbox): + sampled_bbox.append(sample_bbox) + found = found + 1 + im = np.array(im) + while sampled_bbox: + idx = int(np.random.uniform(0, len(sampled_bbox))) + sample_bbox = sampled_bbox.pop(idx) + sample_bbox = clip_bbox(sample_bbox) + + if 'gt_keypoint' in sample.keys(): + keypoints = (sample['gt_keypoint'], + sample['keypoint_ignore']) + crop_bbox, crop_class, crop_score, gt_keypoints = \ + filter_and_process(sample_bbox, gt_bbox, gt_class, + scores=gt_score, + keypoints=keypoints) + else: + crop_bbox, crop_class, crop_score = filter_and_process( + sample_bbox, gt_bbox, gt_class, scores=gt_score) + # sampling bbox according the bbox area + crop_bbox, crop_class, crop_score = bbox_area_sampling( + crop_bbox, crop_class, crop_score, self.target_size, + self.min_size) + + if self.avoid_no_bbox: + if len(crop_bbox) < 1: + continue + xmin = int(sample_bbox[0] * image_width) + xmax = int(sample_bbox[2] * image_width) + ymin = int(sample_bbox[1] * image_height) + ymax = int(sample_bbox[3] * image_height) + im = im[ymin:ymax, xmin:xmax] + sample['image'] = im + sample['gt_bbox'] = crop_bbox + sample['gt_class'] = crop_class + sample['gt_score'] = crop_score + if 'gt_keypoint' in sample.keys(): + sample['gt_keypoint'] = gt_keypoints[0] + sample['keypoint_ignore'] = gt_keypoints[1] + return sample + return sample + + +@register_op +class NormalizeBox(BaseOperator): + """Transform the bounding box's coornidates to [0,1].""" + + def __init__(self): + super(NormalizeBox, self).__init__() + + def __call__(self, sample, context): + gt_bbox = sample['gt_bbox'] + width = sample['w'] + height = sample['h'] + for i in range(gt_bbox.shape[0]): + gt_bbox[i][0] = gt_bbox[i][0] / width + 
gt_bbox[i][1] = gt_bbox[i][1] / height + gt_bbox[i][2] = gt_bbox[i][2] / width + gt_bbox[i][3] = gt_bbox[i][3] / height + sample['gt_bbox'] = gt_bbox + + if 'gt_keypoint' in sample.keys(): + gt_keypoint = sample['gt_keypoint'] + + for i in range(gt_keypoint.shape[1]): + if i % 2: + gt_keypoint[:, i] = gt_keypoint[:, i] / height + else: + gt_keypoint[:, i] = gt_keypoint[:, i] / width + sample['gt_keypoint'] = gt_keypoint + + return sample + + +@register_op +class Permute(BaseOperator): + def __init__(self, to_bgr=True, channel_first=True): + """ + Change the channel. + Args: + to_bgr (bool): confirm whether to convert RGB to BGR + channel_first (bool): confirm whether to change channel + """ + super(Permute, self).__init__() + self.to_bgr = to_bgr + self.channel_first = channel_first + if not (isinstance(self.to_bgr, bool) and + isinstance(self.channel_first, bool)): + raise TypeError("{}: input type is invalid.".format(self)) + + def __call__(self, sample, context=None): + samples = sample + batch_input = True + if not isinstance(samples, Sequence): + batch_input = False + samples = [samples] + for sample in samples: + assert 'image' in sample, "image data not found" + for k in sample.keys(): + # hard code + if k.startswith('image'): + im = sample[k] + if self.channel_first: + im = np.swapaxes(im, 1, 2) + im = np.swapaxes(im, 1, 0) + if self.to_bgr: + im = im[[2, 1, 0], :, :] + sample[k] = im + if not batch_input: + samples = samples[0] + return samples + + +@register_op +class MixupImage(BaseOperator): + def __init__(self, alpha=1.5, beta=1.5): + """ Mixup image and gt_bbbox/gt_score + Args: + alpha (float): alpha parameter of beta distribute + beta (float): beta parameter of beta distribute + """ + super(MixupImage, self).__init__() + self.alpha = alpha + self.beta = beta + if self.alpha <= 0.0: + raise ValueError("alpha shold be positive in {}".format(self)) + if self.beta <= 0.0: + raise ValueError("beta shold be positive in {}".format(self)) + + def 
_mixup_img(self, img1, img2, factor): + h = max(img1.shape[0], img2.shape[0]) + w = max(img1.shape[1], img2.shape[1]) + img = np.zeros((h, w, img1.shape[2]), 'float32') + img[:img1.shape[0], :img1.shape[1], :] = \ + img1.astype('float32') * factor + img[:img2.shape[0], :img2.shape[1], :] += \ + img2.astype('float32') * (1.0 - factor) + return img.astype('uint8') + + def __call__(self, sample, context=None): + if 'mixup' not in sample: + return sample + factor = np.random.beta(self.alpha, self.beta) + factor = max(0.0, min(1.0, factor)) + if factor >= 1.0: + sample.pop('mixup') + return sample + if factor <= 0.0: + return sample['mixup'] + im = self._mixup_img(sample['image'], sample['mixup']['image'], factor) + gt_bbox1 = sample['gt_bbox'].reshape((-1, 4)) + gt_bbox2 = sample['mixup']['gt_bbox'].reshape((-1, 4)) + gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0) + gt_class1 = sample['gt_class'] + gt_class2 = sample['mixup']['gt_class'] + gt_class = np.concatenate((gt_class1, gt_class2), axis=0) + + gt_score1 = sample['gt_score'] + gt_score2 = sample['mixup']['gt_score'] + gt_score = np.concatenate( + (gt_score1 * factor, gt_score2 * (1. 
- factor)), axis=0) + + is_crowd1 = sample['is_crowd'] + is_crowd2 = sample['mixup']['is_crowd'] + is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0) + + sample['image'] = im + sample['gt_bbox'] = gt_bbox + sample['gt_score'] = gt_score + sample['gt_class'] = gt_class + sample['is_crowd'] = is_crowd + sample['h'] = im.shape[0] + sample['w'] = im.shape[1] + sample.pop('mixup') + return sample + + +@register_op +class CutmixImage(BaseOperator): + def __init__(self, alpha=1.5, beta=1.5): + """ + CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, see https://https://arxiv.org/abs/1905.04899 + Cutmix image and gt_bbbox/gt_score + Args: + alpha (float): alpha parameter of beta distribute + beta (float): beta parameter of beta distribute + """ + super(CutmixImage, self).__init__() + self.alpha = alpha + self.beta = beta + if self.alpha <= 0.0: + raise ValueError("alpha shold be positive in {}".format(self)) + if self.beta <= 0.0: + raise ValueError("beta shold be positive in {}".format(self)) + + def _rand_bbox(self, img1, img2, factor): + """ _rand_bbox """ + h = max(img1.shape[0], img2.shape[0]) + w = max(img1.shape[1], img2.shape[1]) + cut_rat = np.sqrt(1. 
- factor) + + cut_w = np.int(w * cut_rat) + cut_h = np.int(h * cut_rat) + + # uniform + cx = np.random.randint(w) + cy = np.random.randint(h) + + bbx1 = np.clip(cx - cut_w // 2, 0, w) + bby1 = np.clip(cy - cut_h // 2, 0, h) + bbx2 = np.clip(cx + cut_w // 2, 0, w) + bby2 = np.clip(cy + cut_h // 2, 0, h) + + img_1 = np.zeros((h, w, img1.shape[2]), 'float32') + img_1[:img1.shape[0], :img1.shape[1], :] = \ + img1.astype('float32') + img_2 = np.zeros((h, w, img2.shape[2]), 'float32') + img_2[:img2.shape[0], :img2.shape[1], :] = \ + img2.astype('float32') + img_1[bby1:bby2, bbx1:bbx2, :] = img2[bby1:bby2, bbx1:bbx2, :] + return img_1 + + def __call__(self, sample, context=None): + if 'cutmix' not in sample: + return sample + factor = np.random.beta(self.alpha, self.beta) + factor = max(0.0, min(1.0, factor)) + if factor >= 1.0: + sample.pop('cutmix') + return sample + if factor <= 0.0: + return sample['cutmix'] + img1 = sample['image'] + img2 = sample['cutmix']['image'] + img = self._rand_bbox(img1, img2, factor) + gt_bbox1 = sample['gt_bbox'] + gt_bbox2 = sample['cutmix']['gt_bbox'] + gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0) + gt_class1 = sample['gt_class'] + gt_class2 = sample['cutmix']['gt_class'] + gt_class = np.concatenate((gt_class1, gt_class2), axis=0) + gt_score1 = sample['gt_score'] + gt_score2 = sample['cutmix']['gt_score'] + gt_score = np.concatenate( + (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0) + sample['image'] = img + sample['gt_bbox'] = gt_bbox + sample['gt_score'] = gt_score + sample['gt_class'] = gt_class + sample['h'] = img.shape[0] + sample['w'] = img.shape[1] + sample.pop('cutmix') + return sample + + +@register_op +class RandomInterpImage(BaseOperator): + def __init__(self, target_size=0, max_size=0): + """ + Random reisze image by multiply interpolate method. 
+ Args: + target_size (int): the taregt size of image's short side + max_size (int): the max size of image + """ + super(RandomInterpImage, self).__init__() + self.target_size = target_size + self.max_size = max_size + if not (isinstance(self.target_size, int) and + isinstance(self.max_size, int)): + raise TypeError('{}: input type is invalid.'.format(self)) + interps = [ + cv2.INTER_NEAREST, + cv2.INTER_LINEAR, + cv2.INTER_AREA, + cv2.INTER_CUBIC, + cv2.INTER_LANCZOS4, + ] + self.resizers = [] + for interp in interps: + self.resizers.append(ResizeImage(target_size, max_size, interp)) + + def __call__(self, sample, context=None): + """Resise the image numpy by random resizer.""" + resizer = random.choice(self.resizers) + return resizer(sample, context) + + +@register_op +class Resize(BaseOperator): + """Resize image and bbox. + Args: + target_dim (int or list): target size, can be a single number or a list + (for random shape). + interp (int or str): interpolation method, can be an integer or + 'random' (for randomized interpolation). + default to `cv2.INTER_LINEAR`. 
+ """ + + def __init__(self, target_dim=[], interp=cv2.INTER_LINEAR): + super(Resize, self).__init__() + self.target_dim = target_dim + self.interp = interp # 'random' for yolov3 + + def __call__(self, sample, context=None): + w = sample['w'] + h = sample['h'] + + interp = self.interp + if interp == 'random': + interp = np.random.choice(range(5)) + + if isinstance(self.target_dim, Sequence): + dim = np.random.choice(self.target_dim) + else: + dim = self.target_dim + resize_w = resize_h = dim + scale_x = dim / w + scale_y = dim / h + if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: + scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32) + sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0, + dim - 1) + sample['scale_factor'] = [scale_x, scale_y] * 2 + sample['h'] = resize_h + sample['w'] = resize_w + + sample['image'] = cv2.resize( + sample['image'], (resize_w, resize_h), interpolation=interp) + return sample + + +@register_op +class ColorDistort(BaseOperator): + """Random color distortion. + Args: + hue (list): hue settings. + in [lower, upper, probability] format. + saturation (list): saturation settings. + in [lower, upper, probability] format. + contrast (list): contrast settings. + in [lower, upper, probability] format. + brightness (list): brightness settings. + in [lower, upper, probability] format. + random_apply (bool): whether to apply in random (yolo) or fixed (SSD) + order. 
+ hsv_format (bool): whether to convert color from BGR to HSV + random_channel (bool): whether to swap channels randomly + """ + + def __init__(self, + hue=[-18, 18, 0.5], + saturation=[0.5, 1.5, 0.5], + contrast=[0.5, 1.5, 0.5], + brightness=[0.5, 1.5, 0.5], + random_apply=True, + hsv_format=False, + random_channel=False): + super(ColorDistort, self).__init__() + self.hue = hue + self.saturation = saturation + self.contrast = contrast + self.brightness = brightness + self.random_apply = random_apply + self.hsv_format = hsv_format + self.random_channel = random_channel + + def apply_hue(self, img): + low, high, prob = self.hue + if np.random.uniform(0., 1.) < prob: + return img + + img = img.astype(np.float32) + if self.hsv_format: + img[..., 0] += random.uniform(low, high) + img[..., 0][img[..., 0] > 360] -= 360 + img[..., 0][img[..., 0] < 0] += 360 + return img + + # XXX works, but result differ from HSV version + delta = np.random.uniform(low, high) + u = np.cos(delta * np.pi) + w = np.sin(delta * np.pi) + bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]]) + tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321], + [0.211, -0.523, 0.311]]) + ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647], + [1.0, -1.107, 1.705]]) + t = np.dot(np.dot(ityiq, bt), tyiq).T + img = np.dot(img, t) + return img + + def apply_saturation(self, img): + low, high, prob = self.saturation + if np.random.uniform(0., 1.) < prob: + return img + delta = np.random.uniform(low, high) + img = img.astype(np.float32) + if self.hsv_format: + img[..., 1] *= delta + return img + gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32) + gray = gray.sum(axis=2, keepdims=True) + gray *= (1.0 - delta) + img *= delta + img += gray + return img + + def apply_contrast(self, img): + low, high, prob = self.contrast + if np.random.uniform(0., 1.) 
< prob: + return img + delta = np.random.uniform(low, high) + + img = img.astype(np.float32) + img *= delta + return img + + def apply_brightness(self, img): + low, high, prob = self.brightness + if np.random.uniform(0., 1.) < prob: + return img + delta = np.random.uniform(low, high) + + img = img.astype(np.float32) + img += delta + return img + + def __call__(self, sample, context=None): + img = sample['image'] + if self.random_apply: + functions = [ + self.apply_brightness, + self.apply_contrast, + self.apply_saturation, + self.apply_hue, + ] + distortions = np.random.permutation(functions) + for func in distortions: + img = func(img) + sample['image'] = img + return sample + + img = self.apply_brightness(img) + + if np.random.randint(0, 2): + img = self.apply_contrast(img) + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + img = self.apply_saturation(img) + img = self.apply_hue(img) + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + else: + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + img = self.apply_saturation(img) + img = self.apply_hue(img) + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + img = self.apply_contrast(img) + + if self.random_channel: + if np.random.randint(0, 2): + img = img[..., np.random.permutation(3)] + sample['image'] = img + return sample + + +@register_op +class CornerRandColor(ColorDistort): + """Random color for CornerNet series models. + Args: + saturation (float): saturation settings. + contrast (float): contrast settings. + brightness (float): brightness settings. + is_scale (bool): whether to scale the input image. + """ + + def __init__(self, + saturation=0.4, + contrast=0.4, + brightness=0.4, + is_scale=True): + super(CornerRandColor, self).__init__( + saturation=saturation, contrast=contrast, brightness=brightness) + self.is_scale = is_scale + + def apply_saturation(self, img, img_gray): + alpha = 1. 
+ np.random.uniform( + low=-self.saturation, high=self.saturation) + self._blend(alpha, img, img_gray[:, :, None]) + return img + + def apply_contrast(self, img, img_gray): + alpha = 1. + np.random.uniform(low=-self.contrast, high=self.contrast) + img_mean = img_gray.mean() + self._blend(alpha, img, img_mean) + return img + + def apply_brightness(self, img, img_gray): + alpha = 1 + np.random.uniform( + low=-self.brightness, high=self.brightness) + img *= alpha + return img + + def _blend(self, alpha, img, img_mean): + img *= alpha + img_mean *= (1 - alpha) + img += img_mean + + def __call__(self, sample, context=None): + img = sample['image'] + if self.is_scale: + img = img.astype(np.float32, copy=False) + img /= 255. + img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + functions = [ + self.apply_brightness, + self.apply_contrast, + self.apply_saturation, + ] + distortions = np.random.permutation(functions) + for func in distortions: + img = func(img, img_gray) + sample['image'] = img + return sample + + +@register_op +class NormalizePermute(BaseOperator): + """Normalize and permute channel order. + Args: + mean (list): mean values in RGB order. + std (list): std values in RGB order. + """ + + def __init__(self, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.120, 57.375]): + super(NormalizePermute, self).__init__() + self.mean = mean + self.std = std + + def __call__(self, sample, context=None): + img = sample['image'] + img = img.astype(np.float32) + + img = img.transpose((2, 0, 1)) + mean = np.array(self.mean, dtype=np.float32) + std = np.array(self.std, dtype=np.float32) + invstd = 1. / std + for v, m, s in zip(img, mean, invstd): + v.__isub__(m).__imul__(s) + sample['image'] = img + return sample + + +@register_op +class RandomExpand(BaseOperator): + """Random expand the canvas. + Args: + ratio (float): maximum expansion ratio. + prob (float): probability to expand. + fill_value (list): color value used to fill the canvas. in RGB order. 
+ is_mask_expand(bool): whether expand the segmentation. + """ + + def __init__(self, + ratio=4., + prob=0.5, + fill_value=(127.5, ) * 3, + is_mask_expand=False): + super(RandomExpand, self).__init__() + assert ratio > 1.01, "expand ratio must be larger than 1.01" + self.ratio = ratio + self.prob = prob + assert isinstance(fill_value, (Number, Sequence)), \ + "fill value must be either float or sequence" + if isinstance(fill_value, Number): + fill_value = (fill_value, ) * 3 + if not isinstance(fill_value, tuple): + fill_value = tuple(fill_value) + self.fill_value = fill_value + self.is_mask_expand = is_mask_expand + + def expand_segms(self, segms, x, y, height, width, ratio): + def _expand_poly(poly, x, y): + expanded_poly = np.array(poly) + expanded_poly[0::2] += x + expanded_poly[1::2] += y + return expanded_poly.tolist() + + def _expand_rle(rle, x, y, height, width, ratio): + if 'counts' in rle and type(rle['counts']) == list: + rle = mask_util.frPyObjects(rle, height, width) + mask = mask_util.decode(rle) + expanded_mask = np.full((int(height * ratio), int(width * ratio)), + 0).astype(mask.dtype) + expanded_mask[y:y + height, x:x + width] = mask + rle = mask_util.encode( + np.array( + expanded_mask, order='F', dtype=np.uint8)) + return rle + + expanded_segms = [] + for segm in segms: + if is_poly(segm): + # Polygon format + expanded_segms.append( + [_expand_poly(poly, x, y) for poly in segm]) + else: + # RLE format + import pycocotools.mask as mask_util + expanded_segms.append( + _expand_rle(segm, x, y, height, width, ratio)) + return expanded_segms + + def __call__(self, sample, context=None): + if np.random.uniform(0., 1.) 
< self.prob: + return sample + + img = sample['image'] + height = int(sample['h']) + width = int(sample['w']) + + expand_ratio = np.random.uniform(1., self.ratio) + h = int(height * expand_ratio) + w = int(width * expand_ratio) + if not h > height or not w > width: + return sample + y = np.random.randint(0, h - height) + x = np.random.randint(0, w - width) + canvas = np.ones((h, w, 3), dtype=np.uint8) + canvas *= np.array(self.fill_value, dtype=np.uint8) + canvas[y:y + height, x:x + width, :] = img.astype(np.uint8) + + sample['h'] = h + sample['w'] = w + sample['image'] = canvas + if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: + sample['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32) + if self.is_mask_expand and 'gt_poly' in sample and len(sample[ + 'gt_poly']) > 0: + sample['gt_poly'] = self.expand_segms(sample['gt_poly'], x, y, + height, width, expand_ratio) + return sample + + +@register_op +class RandomCrop(BaseOperator): + """Random crop image and bboxes. + Args: + aspect_ratio (list): aspect ratio of cropped region. + in [min, max] format. + thresholds (list): iou thresholds for decide a valid bbox crop. + scaling (list): ratio between a cropped region and the original image. + in [min, max] format. + num_attempts (int): number of tries before giving up. + allow_no_crop (bool): allow return without actually cropping them. + cover_all_box (bool): ensure all bboxes are covered in the final crop. + is_mask_crop(bool): whether crop the segmentation. 
+ """ + + def __init__(self, + aspect_ratio=[.5, 2.], + thresholds=[.0, .1, .3, .5, .7, .9], + scaling=[.3, 1.], + num_attempts=50, + allow_no_crop=True, + cover_all_box=False, + is_mask_crop=False): + super(RandomCrop, self).__init__() + self.aspect_ratio = aspect_ratio + self.thresholds = thresholds + self.scaling = scaling + self.num_attempts = num_attempts + self.allow_no_crop = allow_no_crop + self.cover_all_box = cover_all_box + self.is_mask_crop = is_mask_crop + + def crop_segms(self, segms, valid_ids, crop, height, width): + def _crop_poly(segm, crop): + xmin, ymin, xmax, ymax = crop + crop_coord = [xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin] + crop_p = np.array(crop_coord).reshape(4, 2) + crop_p = Polygon(crop_p) + + crop_segm = list() + for poly in segm: + poly = np.array(poly).reshape(len(poly) // 2, 2) + polygon = Polygon(poly) + if not polygon.is_valid: + exterior = polygon.exterior + multi_lines = exterior.intersection(exterior) + polygons = shapely.ops.polygonize(multi_lines) + polygon = MultiPolygon(polygons) + multi_polygon = list() + if isinstance(polygon, MultiPolygon): + multi_polygon = copy.deepcopy(polygon) + else: + multi_polygon.append(copy.deepcopy(polygon)) + for per_polygon in multi_polygon: + inter = per_polygon.intersection(crop_p) + if not inter: + continue + if isinstance(inter, (MultiPolygon, GeometryCollection)): + for part in inter: + if not isinstance(part, Polygon): + continue + part = np.squeeze( + np.array(part.exterior.coords[:-1]).reshape(1, + -1)) + part[0::2] -= xmin + part[1::2] -= ymin + crop_segm.append(part.tolist()) + elif isinstance(inter, Polygon): + crop_poly = np.squeeze( + np.array(inter.exterior.coords[:-1]).reshape(1, -1)) + crop_poly[0::2] -= xmin + crop_poly[1::2] -= ymin + crop_segm.append(crop_poly.tolist()) + else: + continue + return crop_segm + + def _crop_rle(rle, crop, height, width): + if 'counts' in rle and type(rle['counts']) == list: + rle = mask_util.frPyObjects(rle, height, width) + mask = 
mask_util.decode(rle) + mask = mask[crop[1]:crop[3], crop[0]:crop[2]] + rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8)) + return rle + + crop_segms = [] + for id in valid_ids: + segm = segms[id] + if is_poly(segm): + import copy + import shapely.ops + from shapely.geometry import Polygon, MultiPolygon, GeometryCollection + logging.getLogger("shapely").setLevel(logging.WARNING) + # Polygon format + crop_segms.append(_crop_poly(segm, crop)) + else: + # RLE format + import pycocotools.mask as mask_util + crop_segms.append(_crop_rle(segm, crop, height, width)) + return crop_segms + + def __call__(self, sample, context=None): + if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0: + return sample + + h = sample['h'] + w = sample['w'] + gt_bbox = sample['gt_bbox'] + + # NOTE Original method attempts to generate one candidate for each + # threshold then randomly sample one from the resulting list. + # Here a short circuit approach is taken, i.e., randomly choose a + # threshold and attempt to find a valid crop, and simply return the + # first one found. + # The probability is not exactly the same, kinda resembling the + # "Monty Hall" problem. Actually carrying out the attempts will affect + # observability (just like opening doors in the "Monty Hall" game). 
+ thresholds = list(self.thresholds) + if self.allow_no_crop: + thresholds.append('no_crop') + np.random.shuffle(thresholds) + + for thresh in thresholds: + if thresh == 'no_crop': + return sample + + found = False + for i in range(self.num_attempts): + scale = np.random.uniform(*self.scaling) + if self.aspect_ratio is not None: + min_ar, max_ar = self.aspect_ratio + aspect_ratio = np.random.uniform( + max(min_ar, scale**2), min(max_ar, scale**-2)) + h_scale = scale / np.sqrt(aspect_ratio) + w_scale = scale * np.sqrt(aspect_ratio) + else: + h_scale = np.random.uniform(*self.scaling) + w_scale = np.random.uniform(*self.scaling) + crop_h = h * h_scale + crop_w = w * w_scale + if self.aspect_ratio is None: + if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0: + continue + + crop_h = int(crop_h) + crop_w = int(crop_w) + crop_y = np.random.randint(0, h - crop_h) + crop_x = np.random.randint(0, w - crop_w) + crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] + iou = self._iou_matrix( + gt_bbox, np.array( + [crop_box], dtype=np.float32)) + if iou.max() < thresh: + continue + + if self.cover_all_box and iou.min() < thresh: + continue + + cropped_box, valid_ids = self._crop_box_with_center_constraint( + gt_bbox, np.array( + crop_box, dtype=np.float32)) + if valid_ids.size > 0: + found = True + break + + if found: + if self.is_mask_crop and 'gt_poly' in sample and len(sample[ + 'gt_poly']) > 0: + crop_polys = self.crop_segms( + sample['gt_poly'], + valid_ids, + np.array( + crop_box, dtype=np.int64), + h, + w) + if [] in crop_polys: + delete_id = list() + valid_polys = list() + for id, crop_poly in enumerate(crop_polys): + if crop_poly == []: + delete_id.append(id) + else: + valid_polys.append(crop_poly) + valid_ids = np.delete(valid_ids, delete_id) + if len(valid_polys) == 0: + return sample + sample['gt_poly'] = valid_polys + else: + sample['gt_poly'] = crop_polys + + if 'gt_segm' in sample: + sample['gt_segm'] = self._crop_segm(sample['gt_segm'], + crop_box) + 
sample['gt_segm'] = np.take( + sample['gt_segm'], valid_ids, axis=0) + sample['image'] = self._crop_image(sample['image'], crop_box) + sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0) + sample['gt_class'] = np.take( + sample['gt_class'], valid_ids, axis=0) + sample['w'] = crop_box[2] - crop_box[0] + sample['h'] = crop_box[3] - crop_box[1] + if 'gt_score' in sample: + sample['gt_score'] = np.take( + sample['gt_score'], valid_ids, axis=0) + + if 'is_crowd' in sample: + sample['is_crowd'] = np.take( + sample['is_crowd'], valid_ids, axis=0) + return sample + + return sample + + def _iou_matrix(self, a, b): + tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + area_o = (area_a[:, np.newaxis] + area_b - area_i) + return area_i / (area_o + 1e-10) + + def _crop_box_with_center_constraint(self, box, crop): + cropped_box = box.copy() + + cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2]) + cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:]) + cropped_box[:, :2] -= crop[:2] + cropped_box[:, 2:] -= crop[:2] + + centers = (box[:, :2] + box[:, 2:]) / 2 + valid = np.logical_and(crop[:2] <= centers, + centers < crop[2:]).all(axis=1) + valid = np.logical_and( + valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1)) + + return cropped_box, np.where(valid)[0] + + def _crop_image(self, img, crop): + x1, y1, x2, y2 = crop + return img[y1:y2, x1:x2, :] + + def _crop_segm(self, segm, crop): + x1, y1, x2, y2 = crop + return segm[:, y1:y2, x1:x2] + + +@register_op +class PadBox(BaseOperator): + def __init__(self, num_max_boxes=50): + """ + Pad zeros to bboxes if number of bboxes is less than num_max_boxes. 
+ Args: + num_max_boxes (int): the max number of bboxes + """ + self.num_max_boxes = num_max_boxes + super(PadBox, self).__init__() + + def __call__(self, sample, context=None): + assert 'gt_bbox' in sample + bbox = sample['gt_bbox'] + gt_num = min(self.num_max_boxes, len(bbox)) + num_max = self.num_max_boxes + fields = context['fields'] if context else [] + pad_bbox = np.zeros((num_max, 4), dtype=np.float32) + if gt_num > 0: + pad_bbox[:gt_num, :] = bbox[:gt_num, :] + sample['gt_bbox'] = pad_bbox + if 'gt_class' in fields: + pad_class = np.zeros((num_max), dtype=np.int32) + if gt_num > 0: + pad_class[:gt_num] = sample['gt_class'][:gt_num, 0] + sample['gt_class'] = pad_class + if 'gt_score' in fields: + pad_score = np.zeros((num_max), dtype=np.float32) + if gt_num > 0: + pad_score[:gt_num] = sample['gt_score'][:gt_num, 0] + sample['gt_score'] = pad_score + # in training, for example in op ExpandImage, + # the bbox and gt_class is expandded, but the difficult is not, + # so, judging by it's length + if 'is_difficult' in fields: + pad_diff = np.zeros((num_max), dtype=np.int32) + if gt_num > 0: + pad_diff[:gt_num] = sample['difficult'][:gt_num, 0] + sample['difficult'] = pad_diff + return sample + + +@register_op +class BboxXYXY2XYWH(BaseOperator): + """ + Convert bbox XYXY format to XYWH format. + """ + + def __init__(self): + super(BboxXYXY2XYWH, self).__init__() + + def __call__(self, sample, context=None): + assert 'gt_bbox' in sample + bbox = sample['gt_bbox'] + bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2] + bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2. 
+ sample['gt_bbox'] = bbox + return sample + + +@register_op +class Lighting(BaseOperator): + """ + Lighting the imagen by eigenvalues and eigenvectors + Args: + eigval (list): eigenvalues + eigvec (list): eigenvectors + alphastd (float): random weight of lighting, 0.1 by default + """ + + def __init__(self, eigval, eigvec, alphastd=0.1): + super(Lighting, self).__init__() + self.alphastd = alphastd + self.eigval = np.array(eigval).astype('float32') + self.eigvec = np.array(eigvec).astype('float32') + + def __call__(self, sample, context=None): + alpha = np.random.normal(scale=self.alphastd, size=(3, )) + sample['image'] += np.dot(self.eigvec, self.eigval * alpha) + return sample + + +@register_op +class CornerTarget(BaseOperator): + """ + Generate targets for CornerNet by ground truth data. + Args: + output_size (int): the size of output heatmaps. + num_classes (int): num of classes. + gaussian_bump (bool): whether to apply gaussian bump on gt targets. + True by default. + gaussian_rad (int): radius of gaussian bump. If it is set to -1, the + radius will be calculated by iou. -1 by default. + gaussian_iou (float): the threshold iou of predicted bbox to gt bbox. + If the iou is larger than threshold, the predicted bboox seems as + positive sample. 0.3 by default + max_tag_len (int): max num of gt box per image. 
+ """ + + def __init__(self, + output_size, + num_classes, + gaussian_bump=True, + gaussian_rad=-1, + gaussian_iou=0.3, + max_tag_len=128): + super(CornerTarget, self).__init__() + self.num_classes = num_classes + self.output_size = output_size + self.gaussian_bump = gaussian_bump + self.gaussian_rad = gaussian_rad + self.gaussian_iou = gaussian_iou + self.max_tag_len = max_tag_len + + def __call__(self, sample, context=None): + tl_heatmaps = np.zeros( + (self.num_classes, self.output_size[0], self.output_size[1]), + dtype=np.float32) + br_heatmaps = np.zeros( + (self.num_classes, self.output_size[0], self.output_size[1]), + dtype=np.float32) + + tl_regrs = np.zeros((self.max_tag_len, 2), dtype=np.float32) + br_regrs = np.zeros((self.max_tag_len, 2), dtype=np.float32) + tl_tags = np.zeros((self.max_tag_len), dtype=np.int64) + br_tags = np.zeros((self.max_tag_len), dtype=np.int64) + tag_masks = np.zeros((self.max_tag_len), dtype=np.uint8) + tag_lens = np.zeros((), dtype=np.int32) + tag_nums = np.zeros((1), dtype=np.int32) + + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + keep_inds = ((gt_bbox[:, 2] - gt_bbox[:, 0]) > 0) & \ + ((gt_bbox[:, 3] - gt_bbox[:, 1]) > 0) + gt_bbox = gt_bbox[keep_inds] + gt_class = gt_class[keep_inds] + sample['gt_bbox'] = gt_bbox + sample['gt_class'] = gt_class + width_ratio = self.output_size[1] / sample['w'] + height_ratio = self.output_size[0] / sample['h'] + for i in range(gt_bbox.shape[0]): + width = gt_bbox[i][2] - gt_bbox[i][0] + height = gt_bbox[i][3] - gt_bbox[i][1] + + xtl, ytl = gt_bbox[i][0], gt_bbox[i][1] + xbr, ybr = gt_bbox[i][2], gt_bbox[i][3] + + fxtl = (xtl * width_ratio) + fytl = (ytl * height_ratio) + fxbr = (xbr * width_ratio) + fybr = (ybr * height_ratio) + + xtl = int(fxtl) + ytl = int(fytl) + xbr = int(fxbr) + ybr = int(fybr) + if self.gaussian_bump: + width = math.ceil(width * width_ratio) + height = math.ceil(height * height_ratio) + if self.gaussian_rad == -1: + radius = gaussian_radius((height, 
width), self.gaussian_iou) + radius = max(0, int(radius)) + else: + radius = self.gaussian_rad + draw_gaussian(tl_heatmaps[gt_class[i][0]], [xtl, ytl], radius) + draw_gaussian(br_heatmaps[gt_class[i][0]], [xbr, ybr], radius) + else: + tl_heatmaps[gt_class[i][0], ytl, xtl] = 1 + br_heatmaps[gt_class[i][0], ybr, xbr] = 1 + + tl_regrs[i, :] = [fxtl - xtl, fytl - ytl] + br_regrs[i, :] = [fxbr - xbr, fybr - ybr] + tl_tags[tag_lens] = ytl * self.output_size[1] + xtl + br_tags[tag_lens] = ybr * self.output_size[1] + xbr + tag_lens += 1 + + tag_masks[:tag_lens] = 1 + + sample['tl_heatmaps'] = tl_heatmaps + sample['br_heatmaps'] = br_heatmaps + sample['tl_regrs'] = tl_regrs + sample['br_regrs'] = br_regrs + sample['tl_tags'] = tl_tags + sample['br_tags'] = br_tags + sample['tag_masks'] = tag_masks + + return sample + + +@register_op +class CornerCrop(BaseOperator): + """ + Random crop for CornerNet + Args: + random_scales (list): scales of output_size to input_size. + border (int): border of corp center + is_train (bool): train or test + input_size (int): size of input image + """ + + def __init__(self, + random_scales=[0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2, 1.3], + border=128, + is_train=True, + input_size=511): + super(CornerCrop, self).__init__() + self.random_scales = random_scales + self.border = border + self.is_train = is_train + self.input_size = input_size + + def __call__(self, sample, context=None): + im_h, im_w = int(sample['h']), int(sample['w']) + if self.is_train: + scale = np.random.choice(self.random_scales) + height = int(self.input_size * scale) + width = int(self.input_size * scale) + + w_border = self._get_border(self.border, im_w) + h_border = self._get_border(self.border, im_h) + + ctx = np.random.randint(low=w_border, high=im_w - w_border) + cty = np.random.randint(low=h_border, high=im_h - h_border) + + else: + cty, ctx = im_h // 2, im_w // 2 + height = im_h | 127 + width = im_w | 127 + + cropped_image = np.zeros( + (height, width, 3), 
dtype=sample['image'].dtype) + + x0, x1 = max(ctx - width // 2, 0), min(ctx + width // 2, im_w) + y0, y1 = max(cty - height // 2, 0), min(cty + height // 2, im_h) + + left_w, right_w = ctx - x0, x1 - ctx + top_h, bottom_h = cty - y0, y1 - cty + + # crop image + cropped_ctx, cropped_cty = width // 2, height // 2 + x_slice = slice(int(cropped_ctx - left_w), int(cropped_ctx + right_w)) + y_slice = slice(int(cropped_cty - top_h), int(cropped_cty + bottom_h)) + cropped_image[y_slice, x_slice, :] = sample['image'][y0:y1, x0:x1, :] + + sample['image'] = cropped_image + sample['h'], sample['w'] = height, width + + if self.is_train: + # crop detections + gt_bbox = sample['gt_bbox'] + gt_bbox[:, 0:4:2] -= x0 + gt_bbox[:, 1:4:2] -= y0 + gt_bbox[:, 0:4:2] += cropped_ctx - left_w + gt_bbox[:, 1:4:2] += cropped_cty - top_h + else: + sample['borders'] = np.array( + [ + cropped_cty - top_h, cropped_cty + bottom_h, + cropped_ctx - left_w, cropped_ctx + right_w + ], + dtype=np.float32) + + return sample + + def _get_border(self, border, size): + i = 1 + while size - border // i <= border // i: + i *= 2 + return border // i + + +@register_op +class CornerRatio(BaseOperator): + """ + Ratio of output size to image size + Args: + input_size (int): the size of input size + output_size (int): the size of heatmap + """ + + def __init__(self, input_size=511, output_size=64): + super(CornerRatio, self).__init__() + self.input_size = input_size + self.output_size = output_size + + def __call__(self, sample, context=None): + scale = (self.input_size + 1) // self.output_size + out_height, out_width = (sample['h'] + 1) // scale, ( + sample['w'] + 1) // scale + height_ratio = out_height / float(sample['h']) + width_ratio = out_width / float(sample['w']) + sample['ratios'] = np.array([height_ratio, width_ratio]) + + return sample + + +@register_op +class RandomScaledCrop(BaseOperator): + """Resize image and bbox based on long side (with optional random scaling), + then crop or pad image to target 
size. + Args: + target_dim (int): target size. + scale_range (list): random scale range. + interp (int): interpolation method, default to `cv2.INTER_LINEAR`. + """ + + def __init__(self, + target_dim=512, + scale_range=[.1, 2.], + interp=cv2.INTER_LINEAR): + super(RandomScaledCrop, self).__init__() + self.target_dim = target_dim + self.scale_range = scale_range + self.interp = interp + + def __call__(self, sample, context=None): + w = sample['w'] + h = sample['h'] + random_scale = np.random.uniform(*self.scale_range) + dim = self.target_dim + random_dim = int(dim * random_scale) + dim_max = max(h, w) + scale = random_dim / dim_max + resize_w = int(round(w * scale)) + resize_h = int(round(h * scale)) + offset_x = int(max(0, np.random.uniform(0., resize_w - dim))) + offset_y = int(max(0, np.random.uniform(0., resize_h - dim))) + if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: + scale_array = np.array([scale, scale] * 2, dtype=np.float32) + shift_array = np.array([offset_x, offset_y] * 2, dtype=np.float32) + boxes = sample['gt_bbox'] * scale_array - shift_array + boxes = np.clip(boxes, 0, dim - 1) + # filter boxes with no area + area = np.prod(boxes[..., 2:] - boxes[..., :2], axis=1) + valid = (area > 1.).nonzero()[0] + sample['gt_bbox'] = boxes[valid] + sample['gt_class'] = sample['gt_class'][valid] + + img = sample['image'] + img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interp) + img = np.array(img) + canvas = np.zeros((dim, dim, 3), dtype=img.dtype) + canvas[:min(dim, resize_h), :min(dim, resize_w), :] = img[ + offset_y:offset_y + dim, offset_x:offset_x + dim, :] + sample['h'] = dim + sample['w'] = dim + sample['image'] = canvas + sample['im_info'] = [resize_h, resize_w, scale] + return sample + + +@register_op +class ResizeAndPad(BaseOperator): + """Resize image and bbox, then pad image to target size. + Args: + target_dim (int): target size + interp (int): interpolation method, default to `cv2.INTER_LINEAR`. 
+ """ + + def __init__(self, target_dim=512, interp=cv2.INTER_LINEAR): + super(ResizeAndPad, self).__init__() + self.target_dim = target_dim + self.interp = interp + + def __call__(self, sample, context=None): + w = sample['w'] + h = sample['h'] + interp = self.interp + dim = self.target_dim + dim_max = max(h, w) + scale = self.target_dim / dim_max + resize_w = int(round(w * scale)) + resize_h = int(round(h * scale)) + if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: + scale_array = np.array([scale, scale] * 2, dtype=np.float32) + sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0, + dim - 1) + img = sample['image'] + img = cv2.resize(img, (resize_w, resize_h), interpolation=interp) + img = np.array(img) + canvas = np.zeros((dim, dim, 3), dtype=img.dtype) + canvas[:resize_h, :resize_w, :] = img + sample['h'] = dim + sample['w'] = dim + sample['image'] = canvas + sample['im_info'] = [resize_h, resize_w, scale] + return sample + + +@register_op +class TargetAssign(BaseOperator): + """Assign regression target and labels. + Args: + image_size (int or list): input image size, a single integer or list of + [h, w]. Default: 512 + min_level (int): min level of the feature pyramid. Default: 3 + max_level (int): max level of the feature pyramid. Default: 7 + anchor_base_scale (int): base anchor scale. Default: 4 + num_scales (int): number of anchor scales. Default: 3 + aspect_ratios (list): aspect ratios. + Default: [(1, 1), (1.4, 0.7), (0.7, 1.4)] + match_threshold (float): threshold for foreground IoU. 
Default: 0.5 + """ + + def __init__(self, + image_size=512, + min_level=3, + max_level=7, + anchor_base_scale=4, + num_scales=3, + aspect_ratios=[(1, 1), (1.4, 0.7), (0.7, 1.4)], + match_threshold=0.5): + super(TargetAssign, self).__init__() + assert image_size % 2 ** max_level == 0, \ + "image size should be multiple of the max level stride" + self.image_size = image_size + self.min_level = min_level + self.max_level = max_level + self.anchor_base_scale = anchor_base_scale + self.num_scales = num_scales + self.aspect_ratios = aspect_ratios + self.match_threshold = match_threshold + + @property + def anchors(self): + if not hasattr(self, '_anchors'): + anchor_grid = AnchorGrid(self.image_size, self.min_level, + self.max_level, self.anchor_base_scale, + self.num_scales, self.aspect_ratios) + self._anchors = np.concatenate(anchor_grid.generate()) + return self._anchors + + def iou_matrix(self, a, b): + tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + area_o = (area_a[:, np.newaxis] + area_b - area_i) + # return area_i / (area_o + 1e-10) + return np.where(area_i == 0., np.zeros_like(area_i), area_i / area_o) + + def match(self, anchors, gt_boxes): + # XXX put smaller matrix first would be a little bit faster + mat = self.iou_matrix(gt_boxes, anchors) + max_anchor_for_each_gt = mat.argmax(axis=1) + max_for_each_anchor = mat.max(axis=0) + anchor_to_gt = mat.argmax(axis=0) + anchor_to_gt[max_for_each_anchor < self.match_threshold] = -1 + # XXX ensure each gt has at least one anchor assigned, + # see `force_match_for_each_row` in TF implementation + one_hot = np.zeros_like(mat) + one_hot[np.arange(mat.shape[0]), max_anchor_for_each_gt] = 1. 
+ max_anchor_indices = one_hot.sum(axis=0).nonzero()[0] + max_gt_indices = one_hot.argmax(axis=0)[max_anchor_indices] + anchor_to_gt[max_anchor_indices] = max_gt_indices + return anchor_to_gt + + def encode(self, anchors, boxes): + wha = anchors[..., 2:] - anchors[..., :2] + 1 + ca = anchors[..., :2] + wha * .5 + whb = boxes[..., 2:] - boxes[..., :2] + 1 + cb = boxes[..., :2] + whb * .5 + offsets = np.empty_like(anchors) + offsets[..., :2] = (cb - ca) / wha + offsets[..., 2:] = np.log(whb / wha) + return offsets + + def __call__(self, sample, context=None): + gt_boxes = sample['gt_bbox'] + gt_labels = sample['gt_class'] + labels = np.full((self.anchors.shape[0], 1), 0, dtype=np.int32) + targets = np.full((self.anchors.shape[0], 4), 0., dtype=np.float32) + sample['gt_label'] = labels + sample['gt_target'] = targets + + if len(gt_boxes) < 1: + sample['fg_num'] = np.array(0, dtype=np.int32) + return sample + + anchor_to_gt = self.match(self.anchors, gt_boxes) + matched_indices = (anchor_to_gt >= 0).nonzero()[0] + labels[matched_indices] = gt_labels[anchor_to_gt[matched_indices]] + + matched_boxes = gt_boxes[anchor_to_gt[matched_indices]] + matched_anchors = self.anchors[matched_indices] + matched_targets = self.encode(matched_anchors, matched_boxes) + targets[matched_indices] = matched_targets + sample['fg_num'] = np.array(len(matched_targets), dtype=np.int32) + return sample + + +@register_op +class DebugVisibleImage(BaseOperator): + """ + In debug mode, visualize images according to `gt_box`. + (Currently only supported when not cropping and flipping image.) 
+ """ + + def __init__(self, + output_dir='output/debug', + use_vdl=False, + is_normalized=False): + super(DebugVisibleImage, self).__init__() + self.is_normalized = is_normalized + self.output_dir = output_dir + self.use_vdl = use_vdl + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + if not isinstance(self.is_normalized, bool): + raise TypeError("{}: input type is invalid.".format(self)) + if self.use_vdl: + assert six.PY3, "VisualDL requires Python >= 3.5" + from visualdl import LogWriter + self.vdl_writer = LogWriter(self.output_dir) + + def __call__(self, sample, context=None): + out_file_name = sample['im_file'].split('/')[-1] + if self.use_vdl: + origin_image = Image.open(sample['im_file']).convert('RGB') + origin_image = ImageOps.exif_transpose(origin_image) + image_np = np.array(origin_image) + self.vdl_writer.add_image("original/{}".format(out_file_name), + image_np, 0) + + if not isinstance(sample['image'], np.ndarray): + raise TypeError("{}: sample[image] type is not numpy.".format(self)) + image = Image.fromarray(np.uint8(sample['image'])) + + width = sample['w'] + height = sample['h'] + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + + if 'gt_poly' in sample.keys(): + poly_to_mask = Poly2Mask() + sample = poly_to_mask(sample) + + if 'gt_segm' in sample.keys(): + import pycocotools.mask as mask_util + from ppdet.utils.colormap import colormap + image_np = np.array(image).astype('float32') + mask_color_id = 0 + w_ratio = .4 + alpha = 0.7 + color_list = colormap(rgb=True) + gt_segm = sample['gt_segm'] + for mask in gt_segm: + color_mask = color_list[mask_color_id % len(color_list), 0:3] + mask_color_id += 1 + for c in range(3): + color_mask[c] = color_mask[c] * (1 - w_ratio + ) + w_ratio * 255 + idx = np.nonzero(mask) + image_np[idx[0], idx[1], :] *= 1.0 - alpha + image_np[idx[0], idx[1], :] += alpha * color_mask + image = Image.fromarray(np.uint8(image_np)) + + draw = ImageDraw.Draw(image) + for i in range(gt_bbox.shape[0]): + 
if self.is_normalized: + gt_bbox[i][0] = gt_bbox[i][0] * width + gt_bbox[i][1] = gt_bbox[i][1] * height + gt_bbox[i][2] = gt_bbox[i][2] * width + gt_bbox[i][3] = gt_bbox[i][3] * height + + xmin, ymin, xmax, ymax = gt_bbox[i] + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=2, + fill='green') + # draw label + text = 'id' + str(gt_class[i][0]) + tw, th = draw.textsize(text) + draw.rectangle( + [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green') + draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) + + if 'gt_keypoint' in sample.keys(): + gt_keypoint = sample['gt_keypoint'] + if self.is_normalized: + for i in range(gt_keypoint.shape[1]): + if i % 2: + gt_keypoint[:, i] = gt_keypoint[:, i] * height + else: + gt_keypoint[:, i] = gt_keypoint[:, i] * width + for i in range(gt_keypoint.shape[0]): + keypoint = gt_keypoint[i] + for j in range(int(keypoint.shape[0] / 2)): + x1 = round(keypoint[2 * j]) + y1 = round(keypoint[2 * j + 1]) + draw.ellipse( + (x1, y1, x1 + 5, y1 + 5), fill='green', outline='green') + save_path = os.path.join(self.output_dir, out_file_name) + if self.use_vdl: + preprocess_image_np = np.array(image) + self.vdl_writer.add_image("preprocess/{}".format(out_file_name), + preprocess_image_np, 0) + else: + image.save(save_path, quality=95) + return sample + + +@register_op +class Poly2Mask(BaseOperator): + """ + gt poly to mask annotations + """ + + def __init__(self): + super(Poly2Mask, self).__init__() + import pycocotools.mask as maskUtils + self.maskutils = maskUtils + + def _poly2mask(self, mask_ann, img_h, img_w): + if isinstance(mask_ann, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = self.maskutils.frPyObjects(mask_ann, img_h, img_w) + rle = self.maskutils.merge(rles) + elif isinstance(mask_ann['counts'], list): + # uncompressed RLE + rle = self.maskutils.frPyObjects(mask_ann, img_h, img_w) + else: + # rle 
+ rle = mask_ann + mask = self.maskutils.decode(rle) + return mask + + def __call__(self, sample, context=None): + assert 'gt_poly' in sample + im_h = sample['h'] + im_w = sample['w'] + masks = [ + self._poly2mask(gt_poly, im_h, im_w) + for gt_poly in sample['gt_poly'] + ] + sample['gt_segm'] = np.asarray(masks).astype(np.uint8) + return sample diff --git a/VisualFL/depends/PaddleDetection/ppdet/experimental/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/experimental/__init__.py new file mode 100755 index 000000000..f70396193 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/experimental/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from .mixed_precision import * +from . import mixed_precision + +__all__ = mixed_precision.__all__ diff --git a/VisualFL/depends/PaddleDetection/ppdet/experimental/mixed_precision.py b/VisualFL/depends/PaddleDetection/ppdet/experimental/mixed_precision.py new file mode 100755 index 000000000..e13a72142 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/experimental/mixed_precision.py @@ -0,0 +1,333 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import print_function + +import six +from paddle.fluid.framework import Parameter +from paddle.fluid import layers +from paddle.fluid import core +from paddle.fluid import unique_name +import paddle.fluid.layer_helper_base as lhb +import paddle.fluid.optimizer as optim + +__all__ = [ + 'mixed_precision_global_state', 'mixed_precision_context', + 'StaticLossScale', 'DynamicLossScale' +] + +_mixed_precision_global_state = None + + +def mixed_precision_global_state(): + return _mixed_precision_global_state + + +class LossScale(object): + def __init__(self): + super(LossScale, self).__init__() + + def get_loss_scale_var(self): + return self.scale + + def increment(self): + raise NotImplementedError() + + def decrement(self): + raise NotImplementedError() + + +class StaticLossScale(LossScale): + """ + Static (fixed) loss scale manager. + + Args: + init_loss_scale (float): initial loss scale value. + + Examples: + + .. code-block:: python + + from paddle import fluid + from ppdet.experimental import (mixed_precision_context, + StaticLossScale) + + with mixed_precision_context(StaticLossScale(8.), True) as ctx: + # ... + # scale loss + loss_scale = ctx.get_loss_scale_var() + + """ + + def __init__(self, init_loss_scale=1.): + super(StaticLossScale, self).__init__() + self.scale = layers.create_global_var( + name=unique_name.generate("loss_scale"), + shape=[1], + value=init_loss_scale, + dtype='float32', + persistable=True) + + +class DynamicLossScale(LossScale): + """ + Dynamic loss scale manager. 
it works as follows: + if gradients is valid for `increment_every` steps, loss scale values is + increased by `factor`, otherwise loss scale values is decreased by `factor` + + Args: + init_loss_scale (float): initial loss scale value. + increment_every (int): minimum 'good' steps before loss scale increase. + factor (float): increase/decrease loss scale by this much. + + Examples: + + .. code-block:: python + + from paddle import fluid + from ppdet.experimental import (mixed_precision_context, + DynamicLossScale) + + loss_scale = DynamicLossScale(8., 1000, 4.) + with mixed_precision_context(loss_scale, True) as ctx: + # ... + # scale loss + loss_scale = ctx.get_loss_scale_var() + + """ + + def __init__(self, init_loss_scale=2**15, increment_every=2000, factor=2.): + super(DynamicLossScale, self).__init__() + self.scale = layers.create_global_var( + name=unique_name.generate("loss_scale"), + shape=[1], + value=init_loss_scale, + dtype='float32', + persistable=True) + self.good_steps = layers.create_global_var( + name=unique_name.generate("good_steps"), + shape=[1], + value=0, + dtype='int32', + persistable=True) + self.increment_every = layers.fill_constant( + shape=[1], dtype='int32', value=increment_every) + self.factor = factor + + def increment(self): + enough_steps = layers.less_than(self.increment_every, + self.good_steps + 1) + + def increment_step(): + layers.increment(self.good_steps) + + def maybe_update(): + new_scale = self.scale * self.factor + scale_valid = layers.isfinite(new_scale) + + def update_scale_and_step(): + layers.assign(new_scale, self.scale) + layers.assign( + layers.zeros_like(self.good_steps), self.good_steps) + + layers.cond(scale_valid, update_scale_and_step) + + layers.cond(enough_steps, maybe_update, increment_step) + + def decrement(self): + new_scale = self.scale / self.factor + one = layers.fill_constant(shape=[1], dtype='float32', value=1.0) + layers.assign(layers.elementwise_max(new_scale, one), self.scale) + 
layers.assign(layers.zeros_like(self.good_steps), self.good_steps) + + +class mixed_precision_context(object): + """ + Context manager for mixed precision training. + + Args: + loss_scale (float, str or obj): loss scale settings, can be: + 1. an number: use fixed loss scale. + 2. 'dynamic': use a default `DynamicLossScale`. + 3. `DynamicLossScale` or `StaticLossScale` instance. + enabled (bool): enable mixed precision training. + + Examples: + + .. code-block:: python + + from paddle import fluid + from ppdet.experimental import mixed_precision_context + + with mixed_precision_context('dynamic', True) as ctx: + # cast inputs to float16 + inputs = fluid.layers.cast(inputs, "float16") + # build model here + logits = model(inputs) + # use float32 for softmax + logits = fluid.layers.cast(logits, "float32") + softmax = fluid.layers.softmax(logits) + loss = fluid.layers.cross_entropy(input=softmax, label=label) + avg_loss = fluid.layers.mean(loss) + # scale loss + loss_scale = ctx.get_loss_scale_var() + avg_loss *= loss_scale + optimizer = fluid.optimizer.Momentum(...) 
+ optimizer.minimize(avg_loss) + + """ + + def __init__(self, loss_scale=1., enabled=True): + super(mixed_precision_context, self).__init__() + self.enabled = enabled + if not enabled: + return + monkey_patch() + if isinstance(loss_scale, six.integer_types + (float, )): + self.loss_scale = StaticLossScale(loss_scale) + elif loss_scale == 'dynamic': + self.loss_scale = DynamicLossScale() + else: + assert isinstance(loss_scale, LossScale), \ + "Invalid loss scale argument" + self.loss_scale = loss_scale + + @property + def dynamic_scaling(self): + return isinstance(self.loss_scale, DynamicLossScale) + + def __getattr__(self, attr): + if attr in ['get_loss_scale_var', 'increment', 'decrement']: + return getattr(self.loss_scale, attr) + + def __enter__(self): + if not self.enabled: + return + global _mixed_precision_global_state + _mixed_precision_global_state = self + return mixed_precision_global_state() + + def __exit__(self, *args): + if not self.enabled: + return + global _mixed_precision_global_state + _mixed_precision_global_state = None + return mixed_precision_global_state() + + +def create_parameter(self, + attr, + shape, + dtype, + is_bias=False, + default_initializer=None): + mp_state = mixed_precision_global_state() + is_half = (isinstance(dtype, str) and dtype == 'float16') \ + or (isinstance(dtype, core.VarDesc.VarType) + and dtype == core.VarDesc.VarType.FP16) + + if is_half and mp_state is not None: + dtype = 'float32' + + param = self._create_parameter(attr, shape, dtype, is_bias, + default_initializer) + if not is_half or mp_state is None: + return param + + param16 = self.main_program.current_block().create_var( + name=param.name + '.fp16', + dtype='float16', + type=param.type, + persistable=False) + self.append_op( + type='cast', + inputs={'X': [param]}, + outputs={'Out': [param16]}, + attrs={'in_dtype': param.dtype, + 'out_dtype': param16.dtype}) + return param16 + + +def scale_gradient(block, context): + state = mixed_precision_global_state() + 
if state is None: + return + scale = state.get_loss_scale_var() + op_desc = block.desc.op(block.desc.op_size() - 1) + op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() + bwd_role = core.op_proto_and_checker_maker.OpRole.Backward + for name in [n for n in op_desc.output_arg_names() if n in context]: + fwd_var = block._var_recursive(context[name]) + if not isinstance(fwd_var, Parameter): + continue # TODO verify all use cases + scale_op_desc = block.desc.append_op() + scale_op_desc.set_type("elementwise_div") + scale_op_desc.set_input("X", [name]) + scale_op_desc.set_input("Y", [scale.name]) + scale_op_desc.set_output("Out", [name]) + scale_op_desc._set_attr("axis", -1) + scale_op_desc._set_attr(op_role_attr_name, bwd_role) + + +def update_loss_scale(grads): + state = mixed_precision_global_state() + if state is None or not state.dynamic_scaling: + return + per_grad_check = layers.stack([layers.reduce_sum(g) for g in grads]) + grad_valid = layers.isfinite(per_grad_check) + layers.cond(grad_valid, lambda: state.increment(), + lambda: state.decrement()) + return grad_valid + + +def backward(self, loss, **kwargs): + state = mixed_precision_global_state() + callbacks = 'callbacks' in kwargs and kwargs['callbacks'] or None + if callbacks is None: + from paddle.fluid.clip import error_clip_callback + callbacks = [error_clip_callback] # XXX what if gradient is zero? 
+    if state is not None:
+        kwargs['callbacks'] = [scale_gradient] + callbacks
+    else:
+        kwargs['callbacks'] = callbacks
+    param_grads = self._backward(loss, **kwargs)
+
+    def zero_grad():
+        for _, g in param_grads:
+            layers.assign(layers.zeros_like(g), g)
+
+    if state is not None:
+        grad_valid = update_loss_scale(v for k, v in param_grads)
+        if state.dynamic_scaling:
+            layers.cond(grad_valid, None, zero_grad)
+
+    return param_grads
+
+
+mixed_precision_patched = False
+
+
+# XXX this is a temporary measure, until thoroughly evaluated
+def monkey_patch():
+    global mixed_precision_patched
+    if mixed_precision_patched:
+        return
+    create_parameter_orig = lhb.LayerHelperBase.create_parameter
+    lhb.LayerHelperBase.create_parameter = create_parameter
+    lhb.LayerHelperBase._create_parameter = create_parameter_orig
+    backward_orig = optim.Optimizer.backward
+    optim.Optimizer.backward = backward
+    optim.Optimizer._backward = backward_orig
+    mixed_precision_patched = True
diff --git a/VisualFL/depends/PaddleDetection/ppdet/ext_op/README.md b/VisualFL/depends/PaddleDetection/ppdet/ext_op/README.md
new file mode 100755
index 000000000..0516cc702
--- /dev/null
+++ b/VisualFL/depends/PaddleDetection/ppdet/ext_op/README.md
@@ -0,0 +1,77 @@
+# 自定义OP的编译过程
+
+**注意:** 编译自定义OP使用的gcc版本须与Paddle编译使用gcc版本一致,Paddle develop每日版本目前采用**gcc 4.8.2**版本编译,若使用每日版本,请使用**gcc 4.8.2**版本编译自定义OP,否则可能出现兼容性问题。
+
+## 代码结构
+
+  - src: 扩展OP C++/CUDA 源码
+  - cornerpool_lib.py: Python API封装
+  - tests: 各OP单测程序
+
+
+## 编译自定义OP
+
+自定义op需要将实现的C++、CUDA代码编译成动态库,```src/make.sh```中通过g++/nvcc编译,当然您也可以写Makefile或者CMake。
+
+编译需要include PaddlePaddle的相关头文件,链接PaddlePaddle的lib库。 头文件和lib库可通过下面命令获取到:
+
+```
+# python
+>>> import paddle
+>>> print(paddle.sysconfig.get_include())
+/paddle/pyenv/local/lib/python2.7/site-packages/paddle/include
+>>> print(paddle.sysconfig.get_lib())
+/paddle/pyenv/local/lib/python2.7/site-packages/paddle/libs
+```
+
+我们提供动态库编译脚本如下:
+
+```
+cd src
+sh make.sh
+```
+
+最终编译会产出`cornerpool_lib.so`
+
+**说明:** 
若使用源码编译安装PaddlePaddle的方式,编译过程中`cmake`未设置`WITH_MKLDNN`的方式,
+编译自定义OP时会报错找不到`mkldnn.h`等文件,可在`make.sh`中删除编译命令中的`-DPADDLE_WITH_MKLDNN`选项。
+
+
+## 设置环境变量
+
+需要将Paddle的核心库设置到`LD_LIBRARY_PATH`里, 先运行下面程序获取路径:
+
+```
+import paddle
+print(paddle.sysconfig.get_lib())
+```
+
+可通过如下方式添加动态库路径:
+
+```
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c 'import paddle; print(paddle.sysconfig.get_lib())'`
+```
+
+
+
+## 执行单测
+
+执行下列单测,确保自定义算子可在网络中正确使用:
+
+```
+# 回到 ext_op 目录,运行单测
+cd ..
+python tests/test_corner_pool.py
+```
+
+单测运行成功会输出提示信息,如下所示:
+
+```
+.
+----------------------------------------------------------------------
+Ran 4 tests in 2.858s
+
+OK
+```
+
+更多关于如何在框架外部自定义 C++ OP,可阅读[官网说明文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/index_cn.html)
diff --git a/VisualFL/depends/PaddleDetection/ppdet/ext_op/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/ext_op/__init__.py
new file mode 100755
index 000000000..5d38f757f
--- /dev/null
+++ b/VisualFL/depends/PaddleDetection/ppdet/ext_op/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from . 
def cornerpool_op(layer_type, input, name):
    """Append a corner-pooling operator of ``layer_type`` to the program.

    Args:
        layer_type (str): registered op name, e.g. ``"bottom_pool"``.
        input (Variable): 4-D input tensor with shape [N, C, H, W].
        name (str|None): optional name prefix for the created variables.

    Returns:
        Variable: the pooled output. The auxiliary ``MaxMap`` variable is
        wired up as a second op output but not returned to the caller.
    """
    helper = LayerHelper(layer_type, input=input, name=name)
    dtype = helper.input_dtype()
    pooled = helper.create_variable_for_type_inference(dtype)
    max_index_map = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type=layer_type,
        inputs={"X": input},
        outputs={"Output": pooled, "MaxMap": max_index_map},
    )
    return pooled
def top_pool(input, is_test=False, name=None):
    """
    This layer calculates the top pooling output based on the input.
    Scan the input from bottom to top for the vertical max-pooling.
    The output has the same shape with input.
    Args:
        input(Variable): This input is a Tensor with shape [N, C, H, W].
            The data type is float32 or float64.
    Returns:
        Variable(Tensor): The output of top_pool, with shape [N, C, H, W].
            The data type is float32 or float64.
    Examples:
        .. code-block:: python
            import paddle.fluid as fluid
            import cornerpool_lib
            input = fluid.data(
                name='input', shape=[2, 64, 10, 10], dtype='float32')
            output = cornerpool_lib.top_pool(input)
    """
    if is_test:
        # Inference path: prefer the compiled custom op when the .so loaded.
        if use_cpp:
            output = cornerpool_op("top_pool", input, name)
            return output

        # Dynamic-shape fallback: doubling-stride max scan expressed as a
        # while_loop, usable when H is only known at graph-execution time.
        def cond(i, output):
            # Keep looping while the scan stride is smaller than the height.
            return i < H

        def body(i, output):
            # Each row takes the max with the row `i` steps below it,
            # propagating maxima upward (bottom-to-top scan).
            cur = fluid.layers.slice(output, [2], [0], [H - i])
            next = fluid.layers.slice(output, [2], [i], [H])
            max_v = fluid.layers.elementwise_max(cur, next)
            # The bottom `i` rows are already final; re-attach unchanged.
            orig = fluid.layers.slice(output, [2], [H - i], [H])
            output = fluid.layers.concat([max_v, orig], axis=2)
            i = i * 2  # doubling stride -> O(log H) iterations
            return [i, output]

        H = fluid.layers.shape(input)[2]
        i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=1)
        output = input
        output = fluid.layers.while_loop(cond, body, [i, output])
        # while_loop returns the final loop vars [i, output]; take output.
        return output[-1]

    # Training path with statically known height: same doubling scan
    # unrolled in Python (H is a plain int here).
    H = input.shape[2]
    i = 1
    output = input
    while i < H:
        cur = output[:, :, :H - i, :]
        next = output[:, :, i:, :]
        max_v = fluid.layers.elementwise_max(cur, next)
        output = fluid.layers.concat([max_v, output[:, :, H - i:, :]], axis=2)
        i *= 2

    return output
def left_pool(input, is_test=False, name=None):
    """
    This layer calculates the left pooling output based on the input.
    Scan the input from right to left for the horizontal max-pooling.
    The output has the same shape with input.
    Args:
        input(Variable): This input is a Tensor with shape [N, C, H, W].
            The data type is float32 or float64.
    Returns:
        Variable(Tensor): The output of left_pool, with shape [N, C, H, W].
            The data type is float32 or float64.
    Examples:
        .. code-block:: python
            import paddle.fluid as fluid
            import cornerpool_lib
            input = fluid.data(
                name='input', shape=[2, 64, 10, 10], dtype='float32')
            output = cornerpool_lib.left_pool(input)
    """
    if is_test:
        # Inference path: prefer the compiled custom op when the .so loaded.
        if use_cpp:
            output = cornerpool_op("left_pool", input, name)
            return output

        # Dynamic-shape fallback: doubling-stride max scan over the width
        # axis, expressed as a while_loop for runtime-only shapes.
        def cond(i, output):
            # Keep looping while the scan stride is smaller than the width.
            return i < W

        def body(i, output):
            # Each column takes the max with the column `i` steps to its
            # right, propagating maxima leftward (right-to-left scan).
            cur = fluid.layers.slice(output, [3], [0], [W - i])
            next = fluid.layers.slice(output, [3], [i], [W])
            max_v = fluid.layers.elementwise_max(cur, next)
            # The rightmost `i` columns are already final; keep unchanged.
            orig = fluid.layers.slice(output, [3], [W - i], [W])
            output = fluid.layers.concat([max_v, orig], axis=-1)
            i = i * 2  # doubling stride -> O(log W) iterations
            return [i, output]

        W = fluid.layers.shape(input)[3]
        i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=1)
        output = input
        output = fluid.layers.while_loop(cond, body, [i, output])
        # while_loop returns the final loop vars [i, output]; take output.
        return output[-1]

    # Training path with statically known width: same doubling scan
    # unrolled in Python (W is a plain int here).
    W = input.shape[3]
    i = 1
    output = input
    while i < W:
        cur = output[:, :, :, :W - i]
        next = output[:, :, :, i:]
        max_v = fluid.layers.elementwise_max(cur, next)
        output = fluid.layers.concat([max_v, output[:, :, :, W - i:]], axis=-1)
        i *= 2

    return output
// Forward operator definition for bottom corner pooling.
// Both outputs ("Output" and the argmax index map "MaxMap") share the
// dimensions of input "X".
// NOTE(review): template arguments appear to have been stripped from this
// listing (e.g. `ctx.Input("X")` is presumably `ctx.Input<Tensor>("X")`) —
// confirm against the upstream PaddleDetection source before building.
class BottomPoolOp : public framework::OperatorWithKernel {
public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Shape inference: fail fast if X is absent, then propagate X's dims
  // to both outputs.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    ctx->ShareDim("X", /*->*/ "MaxMap");
    ctx->ShareDim("X", /*->*/ "Output");
  }

protected:
  // Kernel dtype and placement follow the input tensor.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input("X")->type(),
                                   ctx.GetPlace());
  }
};
// Gradient operator for bottom pooling: the input gradient dX takes its
// shape from dOutput; MaxMap (saved in forward) routes the gradients.
// NOTE(review): template arguments appear stripped in this listing
// (e.g. `ctx.Input(...)` is presumably `ctx.Input<Tensor>(...)`) —
// confirm against the upstream PaddleDetection source.
class BottomPoolOpGrad : public framework::OperatorWithKernel {
public:
  using framework::OperatorWithKernel::OperatorWithKernel;

protected:
  // Require X, MaxMap and Output@GRAD; X@GRAD mirrors Output@GRAD's dims.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")),
                   "Input(Output@GRAD) should not be null");
    auto out_grad_name = framework::GradVarName("Output");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
  }

  // Kernel dtype/place follow the incoming output gradient.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        ctx.Input(framework::GradVarName("Output"))->type(),
        ctx.GetPlace());
  }
};
Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/memory/memory.h" +#include +#include "util.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +class BottomPoolOpCUDAKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *x = ctx.Input("X"); + auto *max_map = ctx.Output("MaxMap"); + auto *output = ctx.Output("Output"); + auto *x_data = x->data(); + auto x_dims = x->dims(); + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int num = x->numel(); + auto& dev_ctx = ctx.cuda_device_context(); + + int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int blocks = NumBlocks(num / height); + + auto max_val_ptr = memory::Alloc(gpu_place, num / height * sizeof(T)); + T* max_val_data = 
reinterpret_cast(max_val_ptr->ptr()); + auto max_ind_ptr = memory::Alloc(gpu_place, num / height * sizeof(int)); + int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); + + GetMaxInfo<<>>(x->data(), NC_num, height, width, 2, false, max_val_data, max_ind_data, max_map_data); + + blocks = NumBlocks(num); + ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 2, output_data); + } +}; + +template +class BottomPoolGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* max_map = ctx.Input("MaxMap"); + auto* out_grad = ctx.Input(framework::GradVarName("Output")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + auto x_dims = x->dims(); + + auto& dev_ctx = ctx.cuda_device_context(); + T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int grad_num = in_grad->numel(); + int blocks = NumBlocks(grad_num); + FillConstant<<>>(in_grad_data, 0, grad_num); + + ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 2, in_grad_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(bottom_pool, + ops::BottomPoolOpCUDAKernel, + ops::BottomPoolOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(bottom_pool_grad, + ops::BottomPoolGradOpCUDAKernel, + ops::BottomPoolGradOpCUDAKernel); diff --git a/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/left_pool_op.cc b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/left_pool_op.cc new file mode 100755 index 000000000..c2a8f169f --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/left_pool_op.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class LeftPoolOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + ctx->ShareDim("X", /*->*/ "MaxMap"); + ctx->ShareDim("X", /*->*/ "Output"); + } + +protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class LeftPoolOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + AddInput("X", + "Input with shape (batch, C, H, W)"); + AddOutput("MaxMap", "Max map with index of maximum value of input"); + AddOutput("Output", "output with same shape as input(X)"); + AddComment( + R"Doc( +This operatio calculates the left pooling output based on the input. +Scan the input from right to left for the horizontal max-pooling. +The output has the same shape with input. 
+ )Doc"); + } +}; + +class LeftPoolOpGrad : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + +protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")), + "Input(Output@GRAD) should not be null"); + auto out_grad_name = framework::GradVarName("Output"); + ctx->ShareDim(out_grad_name, framework::GradVarName("X")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Output"))->type(), + ctx.GetPlace()); + } +}; + +template +class LeftPoolGradDescMaker : public framework::SingleGradOpMaker { +public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + +protected: + void Apply(GradOpPtr op) const override { + op->SetType("left_pool_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); + op->SetInput("MaxMap", this->Output("MaxMap")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(left_pool, + ops::LeftPoolOp, + ops::LeftPoolOpMaker, + ops::LeftPoolGradDescMaker, + ops::LeftPoolGradDescMaker); +REGISTER_OPERATOR(left_pool_grad, ops::LeftPoolOpGrad); diff --git a/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/left_pool_op.cu b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/left_pool_op.cu new file mode 100755 index 000000000..a5e9323ad --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/left_pool_op.cu @@ -0,0 +1,106 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/memory/memory.h" +#include +#include "util.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +class LeftPoolOpCUDAKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *x = ctx.Input("X"); + auto *max_map = ctx.Output("MaxMap"); + auto *output = ctx.Output("Output"); + auto *x_data = x->data(); + auto x_dims = x->dims(); + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int num = x->numel(); + auto& dev_ctx = ctx.cuda_device_context(); + + int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int blocks = NumBlocks(num / width); + + auto max_val_ptr = memory::Alloc(gpu_place, num / width * sizeof(T)); + T* max_val_data = 
#!/bin/bash
# Build the custom corner-pooling operators into cornerpool_lib.so.
# Requires paddle importable in the active python, nvcc on PATH, and a
# g++ matching the compiler Paddle was built with (see ../README.md).
#
# Fixes vs. previous version:
#  - `set -e`: a failed nvcc/g++ invocation no longer falls through to
#    the final `rm *.cu.o`, which silently destroyed the evidence.
#  - removed the duplicated `-Xcompiler -fPIC` in the nvcc command line.
set -e

# Locate Paddle's headers and libs from the active interpreter.
include_dir=$( python -c 'import paddle; print(paddle.sysconfig.get_include())' )
lib_dir=$( python -c 'import paddle; print(paddle.sysconfig.get_lib())' )

echo "$include_dir"
echo "$lib_dir"

OPS='bottom_pool_op top_pool_op right_pool_op left_pool_op'
for op in ${OPS}
do
# Compile each CUDA kernel to an object file (-O0 -g: debug build).
nvcc ${op}.cu -c -o ${op}.cu.o -ccbin cc -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -DPADDLE_WITH_MKLDNN -Xcompiler -fPIC -std=c++11 -w --expt-relaxed-constexpr -O0 -g -DNVCC \
    -I ${include_dir}/third_party/ \
    -I ${include_dir}
done

# Link the CPU op definitions with the CUDA objects into one shared lib.
g++ bottom_pool_op.cc bottom_pool_op.cu.o top_pool_op.cc top_pool_op.cu.o right_pool_op.cc right_pool_op.cu.o left_pool_op.cc left_pool_op.cu.o -o cornerpool_lib.so -DPADDLE_WITH_MKLDNN -shared -fPIC -std=c++11 -O0 -g \
    -I ${include_dir}/third_party/ \
    -I ${include_dir} \
    -L ${lib_dir} \
    -L /usr/local/cuda/lib64 -lpaddle_framework -lcudart

# Intermediate objects are no longer needed once the .so links.
rm *.cu.o
*/ + +#include "paddle/fluid/framework/op_registry.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class RightPoolOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + ctx->ShareDim("X", /*->*/ "MaxMap"); + ctx->ShareDim("X", /*->*/ "Output"); + } + +protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class RightPoolOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + AddInput("X", + "Input with shape (batch, C, H, W)"); + AddOutput("MaxMap", "Max map with index of maximum value of input"); + AddOutput("Output", "output with same shape as input(X)"); + AddComment( + R"Doc( +This operatio calculates the right pooling output based on the input. +Scan the input from left to right or the horizontal max-pooling. +The output has the same shape with input. 
+ )Doc"); + } +}; + +class RightPoolOpGrad : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + +protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")), + "Input(Output@GRAD) should not be null"); + auto out_grad_name = framework::GradVarName("Output"); + ctx->ShareDim(out_grad_name, framework::GradVarName("X")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Output"))->type(), + ctx.GetPlace()); + } +}; + +template +class RightPoolGradDescMaker : public framework::SingleGradOpMaker { +public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + +protected: + void Apply(GradOpPtr op) const override { + op->SetType("right_pool_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); + op->SetInput("MaxMap", this->Output("MaxMap")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(right_pool, + ops::RightPoolOp, + ops::RightPoolOpMaker, + ops::RightPoolGradDescMaker, + ops::RightPoolGradDescMaker); +REGISTER_OPERATOR(right_pool_grad, ops::RightPoolOpGrad); diff --git a/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/right_pool_op.cu b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/right_pool_op.cu new file mode 100755 index 000000000..08a52ecf1 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/right_pool_op.cu @@ -0,0 +1,105 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/memory/memory.h" +#include +#include "util.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +class RightPoolOpCUDAKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *x = ctx.Input("X"); + auto *max_map = ctx.Output("MaxMap"); + auto *output = ctx.Output("Output"); + auto *x_data = x->data(); + auto x_dims = x->dims(); + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int num = x->numel(); + auto& dev_ctx = ctx.cuda_device_context(); + + int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int blocks = NumBlocks(num / width); + + auto max_val_ptr = memory::Alloc(gpu_place, num / width * sizeof(T)); + T* max_val_data = 
reinterpret_cast(max_val_ptr->ptr()); + auto max_ind_ptr = memory::Alloc(gpu_place, num / width * sizeof(int)); + int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); + + GetMaxInfo<<>>(x->data(), NC_num, height, width, 3, false, max_val_data, max_ind_data, max_map_data); + + blocks = NumBlocks(num); + ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 3, output_data); + + } +}; + +template +class RightPoolGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* max_map = ctx.Input("MaxMap"); + auto* out_grad = ctx.Input(framework::GradVarName("Output")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + auto x_dims = x->dims(); + + auto& dev_ctx = ctx.cuda_device_context(); + T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int grad_num = in_grad->numel(); + int blocks = NumBlocks(grad_num); + FillConstant<<>>(in_grad_data, 0, grad_num); + + ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 3, in_grad_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(right_pool, + ops::RightPoolOpCUDAKernel, + ops::RightPoolOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(right_pool_grad, + ops::RightPoolGradOpCUDAKernel, + ops::RightPoolGradOpCUDAKernel); diff --git a/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/top_pool_op.cc b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/top_pool_op.cc new file mode 100755 index 000000000..29cba6660 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/top_pool_op.cc @@ -0,0 +1,102 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class TopPoolOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + ctx->ShareDim("X", /*->*/ "MaxMap"); + ctx->ShareDim("X", /*->*/ "Output"); + } + +protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class TopPoolOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + AddInput("X", + "Input with shape (batch, C, H, W)"); + AddOutput("MaxMap", "Max map with index of maximum value of input"); + AddOutput("Output", "Output with same shape as input(X)"); + AddComment( + R"Doc( +This operatio calculates the top pooling output based on the input. +Scan the input from bottom to top for the vertical max-pooling. +The output has the same shape with input. 
+ )Doc"); + } +}; + +class TopPoolOpGrad : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + +protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")), + "Input(Output@GRAD) should not be null"); + + auto out_grad_name = framework::GradVarName("Output"); + ctx->ShareDim(out_grad_name, framework::GradVarName("X")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Output"))->type(), + ctx.GetPlace()); + } +}; + +template +class TopPoolGradDescMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("top_pool_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); + op->SetInput("MaxMap", this->Output("MaxMap")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(top_pool, + ops::TopPoolOp, + ops::TopPoolOpMaker, + ops::TopPoolGradDescMaker, + ops::TopPoolGradDescMaker); +REGISTER_OPERATOR(top_pool_grad, ops::TopPoolOpGrad); diff --git a/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/top_pool_op.cu b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/top_pool_op.cu new file mode 100755 index 000000000..f6237fe79 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/top_pool_op.cu @@ -0,0 +1,104 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +GUnless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/memory/memory.h" +#include +#include "util.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +class TopPoolOpCUDAKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *x = ctx.Input("X"); + auto *max_map = ctx.Output("MaxMap"); + auto *output = ctx.Output("Output"); + auto *x_data = x->data(); + auto x_dims = x->dims(); + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int num = x->numel(); + auto& dev_ctx = ctx.cuda_device_context(); + + int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int blocks = NumBlocks(num / height); + + auto max_val_ptr = memory::Alloc(gpu_place, num / height * sizeof(T)); + T* max_val_data = 
reinterpret_cast(max_val_ptr->ptr()); + auto max_ind_ptr = memory::Alloc(gpu_place, num / height * sizeof(int)); + int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); + + GetMaxInfo<<>>(x->data(), NC_num, height, width, 2, true, max_val_data, max_ind_data, max_map_data); + + blocks = NumBlocks(num); + ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 2, output_data); + } +}; + +template +class TopPoolGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* max_map = ctx.Input("MaxMap"); + auto* out_grad = ctx.Input(framework::GradVarName("Output")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + auto x_dims = x->dims(); + auto& dev_ctx = ctx.cuda_device_context(); + T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int grad_num = in_grad->numel(); + int blocks = NumBlocks(grad_num); + FillConstant<<>>(in_grad_data, 0, grad_num); + + ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 2, in_grad_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(top_pool, + ops::TopPoolOpCUDAKernel, + ops::TopPoolOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(top_pool_grad, + ops::TopPoolGradOpCUDAKernel, + ops::TopPoolGradOpCUDAKernel); diff --git a/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/util.cu.h b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/util.cu.h new file mode 100755 index 000000000..615e45a78 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/ext_op/src/util.cu.h @@ -0,0 +1,223 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/memory/memory.h" +#include + +namespace paddle { +namespace operators { + +using framework::Tensor; + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void FillConstant(T* x, int num, int fill_num) { + CUDA_1D_KERNEL_LOOP(i, fill_num) { + x[i] = static_cast(num); + } +} + +template +__global__ void SliceOnAxis(const T* x, const int NC_num, const int H, const int W, + const int axis, const int start, const int end, + T* output) { + int HW_num = H * W; + int length = axis == 2 ? W : H; + int sliced_len = end - start; + int cur_HW_num = length * sliced_len; + // slice input on H or W (axis is 2 or 3) + CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) { + int NC_id = i / cur_HW_num; + int HW_id = i % cur_HW_num; + if (axis == 2){ + output[i] = x[NC_id * HW_num + start * W + HW_id]; + } else if (axis == 3) { + int col = HW_id % sliced_len; + int row = HW_id / sliced_len; + output[i] = x[NC_id * HW_num + row * W + start + col]; + } + } +} + +template +__global__ void MaxOut(const T* input, const int next_ind, const int NC_num, + const int H, const int W, const int axis, + const int start, const int end, T* output) { + int HW_num = H * W; + int length = axis == 2 ? 
W : H; + T cur = static_cast(0.); + T next = static_cast(0.); + T max_v = static_cast(0.); + int sliced_len = end - start; + int cur_HW_num = length * sliced_len; + // compare cur and next and assign max values to output + CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) { + int NC_id = i / cur_HW_num; + int HW_id = i % cur_HW_num; + + if (axis == 2){ + cur = input[NC_id * HW_num + start * W + HW_id]; + next = input[NC_id * HW_num + next_ind * W + HW_id]; + max_v = cur > next ? cur : next; + output[NC_id * HW_num + start * W + HW_id] = max_v; + } else if (axis == 3) { + int col = HW_id % sliced_len; + int row = HW_id / sliced_len; + cur = input[NC_id * HW_num + row * W + start + col]; + next = input[NC_id * HW_num + row * W + next_ind + col]; + max_v = cur > next ? cur : next; + output[NC_id * HW_num + row * W + start + col] = max_v; + } + __syncthreads(); + } +} + +template +__global__ void UpdateMaxInfo(const T* input, const int NC_num, + const int H, const int W, const int axis, + const int index, T* max_val, int* max_ind) { + int length = axis == 2 ? W : H; + int HW_num = H * W; + T val = static_cast(0.); + CUDA_1D_KERNEL_LOOP(i, NC_num * length) { + int NC_id = i / length; + int length_id = i % length; + if (axis == 2) { + val = input[NC_id * HW_num + index * W + length_id]; + } else if (axis == 3) { + val = input[NC_id * HW_num + length_id * W + index]; + } + if (val > max_val[i]) { + max_val[i] = val; + max_ind[i] = index; + } + __syncthreads(); + } +} + +template +__global__ void ScatterAddOnAxis(const T* input, const int start, const int* max_ind, const int NC_num, const int H, const int W, const int axis, T* output) { + int length = axis == 2 ? 
W : H; + int HW_num = H * W; + CUDA_1D_KERNEL_LOOP(i, NC_num * length) { + int NC_id = i / length; + int length_id = i % length; + int id_ = max_ind[i]; + if (axis == 2) { + platform::CudaAtomicAdd(output + NC_id * HW_num + id_ * W + length_id, input[NC_id * HW_num + start * W + length_id]); + //output[NC_id * HW_num + id_ * W + length_id] += input[NC_id * HW_num + start * W + length_id]; + } else if (axis == 3) { + platform::CudaAtomicAdd(output + NC_id * HW_num + length_id * W + id_, input[NC_id * HW_num + length_id * W + start]); + //output[NC_id * HW_num + length_id * W + id_] += input[NC_id * HW_num + length_id * W + start]; + } + __syncthreads(); + } +} + +template +__global__ void GetMaxInfo(const T* input, const int NC_num, + const int H, const int W, const int axis, + const bool reverse, T* max_val, int* max_ind, + int* max_map) { + int start = 0; + int end = axis == 2 ? H: W; + int s = reverse ? end-1 : start; + int e = reverse ? start-1 : end; + int step = reverse ? -1 : 1; + int len = axis == 2 ? 
W : H; + int loc = 0; + T val = static_cast(0.); + for (int i = s; ; ) { + if (i == s) { + CUDA_1D_KERNEL_LOOP(j, NC_num * len) { + int NC_id = j / len; + int len_id = j % len; + if (axis == 2) { + loc = NC_id * H * W + i * W + len_id; + } else if (axis == 3){ + loc = NC_id * H * W + len_id * W + i; + } + max_ind[j] = i; + max_map[loc] = max_ind[j]; + max_val[j] = input[loc]; + __syncthreads(); + } + } else { + CUDA_1D_KERNEL_LOOP(j, NC_num * len) { + int NC_id = j / len; + int len_id = j % len; + + if (axis == 2) { + loc = NC_id * H * W + i * W + len_id; + } else if (axis == 3){ + loc = NC_id * H * W + len_id * W + i; + } + val = input[loc]; + T max_v = max_val[j]; + if (val > max_v) { + max_val[j] = val; + max_map[loc] = i; + max_ind[j] = i; + } else { + max_map[loc] = max_ind[j]; + } + __syncthreads(); + } + } + i += step; + if (s < e && i >= e) break; + if (s > e && i <= e) break; + } +} + +template +__global__ void ScatterAddFw(const T* input, const int* max_map, const int NC_num, const int H, const int W, const int axis, T* output){ + CUDA_1D_KERNEL_LOOP(i, NC_num * H * W) { + int loc = max_map[i]; + int NC_id = i / (H * W); + int len_id = 0; + if (axis == 2) { + len_id = i % W; + output[i] = input[NC_id * H * W + loc * W + len_id]; + } else { + len_id = i % (H * W) / W; + output[i] = input[NC_id * H * W + len_id * W + loc]; + } + } +} + +template +__global__ void ScatterAddBw(const T* input, const int* max_map, const int NC_num, const int H, const int W, const int axis, T* output){ + CUDA_1D_KERNEL_LOOP(i, NC_num * H * W) { + int loc = max_map[i]; + int NC_id = i / (H * W); + int len_id = 0; + int offset = 0; + if (axis == 2) { + len_id = i % W; + offset = NC_id * H * W + loc * W + len_id; + } else { + len_id = i % (H * W) / W; + offset = NC_id * H * W + len_id * W + loc; + } + platform::CudaAtomicAdd(output + offset, input[i]); + } +} + +} // namespace operators +} // namespace paddle diff --git 
a/VisualFL/depends/PaddleDetection/ppdet/ext_op/test/test_corner_pool.py b/VisualFL/depends/PaddleDetection/ppdet/ext_op/test/test_corner_pool.py new file mode 100755 index 000000000..adb5d3c74 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/ext_op/test/test_corner_pool.py @@ -0,0 +1,125 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +import os +import sys +# add python path of PadleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +from ppdet.ext_op import cornerpool_lib +from ppdet.utils.check import enable_static_mode + + +def bottom_pool_np(x): + height = x.shape[2] + output = x.copy() + for ind in range(height): + cur = output[:, :, ind:height, :] + next = output[:, :, :height - ind, :] + output[:, :, ind:height, :] = np.maximum(cur, next) + return output + + +def top_pool_np(x): + height = x.shape[2] + output = x.copy() + for ind in range(height): + cur = output[:, :, :height - ind, :] + next = output[:, :, ind:height, :] + output[:, :, :height - ind, :] = np.maximum(cur, next) + return output + + +def right_pool_np(x): + width = x.shape[3] + output = x.copy() + for ind in range(width): + cur = output[:, :, :, ind:width] + next = output[:, :, :, :width - ind] + output[:, 
:, :, ind:width] = np.maximum(cur, next) + return output + + +def left_pool_np(x): + width = x.shape[3] + output = x.copy() + for ind in range(width): + cur = output[:, :, :, :width - ind] + next = output[:, :, :, ind:width] + output[:, :, :, :width - ind] = np.maximum(cur, next) + return output + + +class TestRightPoolOp(unittest.TestCase): + def funcmap(self): + self.func_map = { + 'bottom_x': [cornerpool_lib.bottom_pool, bottom_pool_np], + 'top_x': [cornerpool_lib.top_pool, top_pool_np], + 'right_x': [cornerpool_lib.right_pool, right_pool_np], + 'left_x': [cornerpool_lib.left_pool, left_pool_np] + } + + def setup(self): + self.name = 'right_x' + + def test_check_output(self): + self.funcmap() + self.setup() + x_shape = (2, 10, 16, 16) + x_type = "float64" + + sp = fluid.Program() + tp = fluid.Program() + place = fluid.CUDAPlace(0) + + with fluid.program_guard(tp, sp): + x = fluid.data(name=self.name, shape=x_shape, dtype=x_type) + y = self.func_map[self.name][0](x) + + np.random.seed(0) + x_np = np.random.uniform(-1000, 1000, x_shape).astype(x_type) + + out_np = self.func_map[self.name][1](x_np) + + exe = fluid.Executor(place) + outs = exe.run(tp, feed={self.name: x_np}, fetch_list=[y]) + + self.assertTrue(np.allclose(outs, out_np)) + + +class TestTopPoolOp(TestRightPoolOp): + def setup(self): + self.name = 'top_x' + + +class TestBottomPoolOp(TestRightPoolOp): + def setup(self): + self.name = 'bottom_x' + + +class TestLeftPoolOp(TestRightPoolOp): + def setup(self): + self.name = 'left_x' + + +if __name__ == "__main__": + enable_static_mode() + unittest.main() diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/__init__.py new file mode 100755 index 000000000..423a4802e --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +# XXX for triggering decorators +from . import anchor_heads +from . import architectures +from . import backbones +from . import roi_extractors +from . import roi_heads +from . import ops +from . import target_assigners +from . import mask_head + +from .anchor_heads import * +from .architectures import * +from .backbones import * +from .roi_extractors import * +from .roi_heads import * +from .ops import * +from .target_assigners import * +from .mask_head import * diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/__init__.py new file mode 100755 index 000000000..547999cb3 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from . import rpn_head +from . import yolo_head +from . import retina_head +from . import fcos_head +from . import corner_head +from . import efficient_head +from . import ttf_head +from . import solov2_head + +from .rpn_head import * +from .yolo_head import * +from .retina_head import * +from .fcos_head import * +from .corner_head import * +from .efficient_head import * +from .ttf_head import * +from .solov2_head import * diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/corner_head.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/corner_head.py new file mode 100755 index 000000000..de504ccad --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/corner_head.py @@ -0,0 +1,496 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Constant + +from ..backbones.hourglass import _conv_norm, kaiming_init +from ppdet.core.workspace import register +import numpy as np +import logging +logger = logging.getLogger(__name__) + +__all__ = ['CornerHead'] + + +def corner_output(x, pool1, pool2, dim, name=None): + p_conv1 = fluid.layers.conv2d( + pool1 + pool2, + filter_size=3, + num_filters=dim, + padding=1, + param_attr=ParamAttr( + name=name + "_p_conv1_weight", + initializer=kaiming_init(pool1 + pool2, 3)), + bias_attr=False, + name=name + '_p_conv1') + p_bn1 = fluid.layers.batch_norm( + p_conv1, + param_attr=ParamAttr(name=name + '_p_bn1_weight'), + bias_attr=ParamAttr(name=name + '_p_bn1_bias'), + moving_mean_name=name + '_p_bn1_running_mean', + moving_variance_name=name + '_p_bn1_running_var', + name=name + '_p_bn1') + + conv1 = fluid.layers.conv2d( + x, + filter_size=1, + num_filters=dim, + param_attr=ParamAttr( + name=name + "_conv1_weight", initializer=kaiming_init(x, 1)), + bias_attr=False, + name=name + '_conv1') + bn1 = fluid.layers.batch_norm( + conv1, + param_attr=ParamAttr(name=name + '_bn1_weight'), + bias_attr=ParamAttr(name=name + '_bn1_bias'), + moving_mean_name=name + '_bn1_running_mean', + moving_variance_name=name + '_bn1_running_var', + name=name + '_bn1') + + relu1 = fluid.layers.relu(p_bn1 + bn1) + conv2 = _conv_norm( + relu1, 3, dim, pad=1, bn_act='relu', name=name + '_conv2') + return conv2 + + +def corner_pool(x, dim, pool1, pool2, is_test=False, name=None): + p1_conv1 = _conv_norm( + x, 3, 128, pad=1, bn_act='relu', name=name + '_p1_conv1') + pool1 = pool1(p1_conv1, is_test=is_test, name=name + '_pool1') + p2_conv1 = _conv_norm( + x, 3, 128, pad=1, bn_act='relu', name=name + '_p2_conv1') + pool2 = pool2(p2_conv1, is_test=is_test, name=name + '_pool2') + 
+ conv2 = corner_output(x, pool1, pool2, dim, name) + return conv2 + + +def gather_feat(feat, ind, batch_size=1): + feats = [] + for bind in range(batch_size): + feat_b = feat[bind] + ind_b = ind[bind] + ind_b.stop_gradient = True + feat_bg = fluid.layers.gather(feat_b, ind_b) + feats.append(fluid.layers.unsqueeze(feat_bg, axes=[0])) + feat_g = fluid.layers.concat(feats, axis=0) + return feat_g + + +def mask_feat(feat, ind, batch_size=1): + feat_t = fluid.layers.transpose(feat, [0, 2, 3, 1]) + C = feat_t.shape[3] + feat_r = fluid.layers.reshape(feat_t, [0, -1, C]) + return gather_feat(feat_r, ind, batch_size) + + +def nms(heat): + hmax = fluid.layers.pool2d(heat, pool_size=3, pool_padding=1) + keep = fluid.layers.cast(heat == hmax, 'float32') + return heat * keep + + +def _topk(scores, batch_size, height, width, K): + scores_r = fluid.layers.reshape(scores, [batch_size, -1]) + topk_scores, topk_inds = fluid.layers.topk(scores_r, K) + topk_inds = fluid.layers.cast(topk_inds, 'int32') + topk_clses = topk_inds // (height * width) + topk_inds = topk_inds % (height * width) + topk_ys = fluid.layers.cast(topk_inds // width, 'float32') + topk_xs = fluid.layers.cast(topk_inds % width, 'float32') + return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs + + +def filter_scores(scores, index_list): + for ind in index_list: + tmp = scores * fluid.layers.cast((1 - ind), 'float32') + scores = tmp - fluid.layers.cast(ind, 'float32') + return scores + + +def decode(tl_heat, + br_heat, + tl_tag, + br_tag, + tl_regr, + br_regr, + ae_threshold=1, + num_dets=1000, + K=100, + batch_size=1): + shape = fluid.layers.shape(tl_heat) + H, W = shape[2], shape[3] + + tl_heat = fluid.layers.sigmoid(tl_heat) + br_heat = fluid.layers.sigmoid(br_heat) + + tl_heat_nms = nms(tl_heat) + br_heat_nms = nms(br_heat) + + tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat_nms, batch_size, + H, W, K) + br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat_nms, batch_size, + H, W, K) + 
tl_ys = fluid.layers.expand( + fluid.layers.reshape(tl_ys, [-1, K, 1]), [1, 1, K]) + tl_xs = fluid.layers.expand( + fluid.layers.reshape(tl_xs, [-1, K, 1]), [1, 1, K]) + br_ys = fluid.layers.expand( + fluid.layers.reshape(br_ys, [-1, 1, K]), [1, K, 1]) + br_xs = fluid.layers.expand( + fluid.layers.reshape(br_xs, [-1, 1, K]), [1, K, 1]) + + tl_regr = mask_feat(tl_regr, tl_inds, batch_size) + br_regr = mask_feat(br_regr, br_inds, batch_size) + tl_regr = fluid.layers.reshape(tl_regr, [-1, K, 1, 2]) + br_regr = fluid.layers.reshape(br_regr, [-1, 1, K, 2]) + + tl_xs = tl_xs + tl_regr[:, :, :, 0] + tl_ys = tl_ys + tl_regr[:, :, :, 1] + br_xs = br_xs + br_regr[:, :, :, 0] + br_ys = br_ys + br_regr[:, :, :, 1] + + bboxes = fluid.layers.stack([tl_xs, tl_ys, br_xs, br_ys], axis=-1) + + tl_tag = mask_feat(tl_tag, tl_inds, batch_size) + br_tag = mask_feat(br_tag, br_inds, batch_size) + tl_tag = fluid.layers.expand( + fluid.layers.reshape(tl_tag, [-1, K, 1]), [1, 1, K]) + br_tag = fluid.layers.expand( + fluid.layers.reshape(br_tag, [-1, 1, K]), [1, K, 1]) + dists = fluid.layers.abs(tl_tag - br_tag) + + tl_scores = fluid.layers.expand( + fluid.layers.reshape(tl_scores, [-1, K, 1]), [1, 1, K]) + br_scores = fluid.layers.expand( + fluid.layers.reshape(br_scores, [-1, 1, K]), [1, K, 1]) + scores = (tl_scores + br_scores) / 2. 
+ + tl_clses = fluid.layers.expand( + fluid.layers.reshape(tl_clses, [-1, K, 1]), [1, 1, K]) + br_clses = fluid.layers.expand( + fluid.layers.reshape(br_clses, [-1, 1, K]), [1, K, 1]) + cls_inds = fluid.layers.cast(tl_clses != br_clses, 'int32') + dist_inds = fluid.layers.cast(dists > ae_threshold, 'int32') + + width_inds = fluid.layers.cast(br_xs < tl_xs, 'int32') + height_inds = fluid.layers.cast(br_ys < tl_ys, 'int32') + + scores = filter_scores(scores, + [cls_inds, dist_inds, width_inds, height_inds]) + scores = fluid.layers.reshape(scores, [-1, K * K]) + + scores, inds = fluid.layers.topk(scores, num_dets) + scores = fluid.layers.reshape(scores, [-1, num_dets, 1]) + + bboxes = fluid.layers.reshape(bboxes, [batch_size, -1, 4]) + bboxes = gather_feat(bboxes, inds, batch_size) + + clses = fluid.layers.reshape(tl_clses, [batch_size, -1, 1]) + clses = gather_feat(clses, inds, batch_size) + + tl_scores = fluid.layers.reshape(tl_scores, [batch_size, -1, 1]) + tl_scores = gather_feat(tl_scores, inds, batch_size) + br_scores = fluid.layers.reshape(br_scores, [batch_size, -1, 1]) + br_scores = gather_feat(br_scores, inds, batch_size) + + bboxes = fluid.layers.cast(bboxes, 'float32') + clses = fluid.layers.cast(clses, 'float32') + return bboxes, scores, tl_scores, br_scores, clses + + +@register +class CornerHead(object): + """ + CornerNet head with corner_pooling + + Args: + train_batch_size(int): batch_size in training process + test_batch_size(int): batch_size in test process, 1 by default + num_classes(int): num of classes, 80 by default + stack(int): stack of backbone, 2 by default + pull_weight(float): weight of pull_loss, 0.1 by default + push_weight(float): weight of push_loss, 0.1 by default + ae_threshold(float|int): threshold for valid distance of predicted tags, 1 by default + num_dets(int): num of detections, 1000 by default + top_k(int): choose top_k pair of corners in prediction, 100 by default + """ + __shared__ = ['num_classes', 'stack', 
'train_batch_size'] + + def __init__(self, + train_batch_size=14, + test_batch_size=1, + num_classes=80, + stack=2, + pull_weight=0.1, + push_weight=0.1, + ae_threshold=1, + num_dets=1000, + top_k=100): + self.train_batch_size = train_batch_size + self.test_batch_size = test_batch_size + self.num_classes = num_classes + self.stack = stack + self.pull_weight = pull_weight + self.push_weight = push_weight + self.ae_threshold = ae_threshold + self.num_dets = num_dets + self.K = top_k + self.tl_heats = [] + self.br_heats = [] + self.tl_tags = [] + self.br_tags = [] + self.tl_offs = [] + self.br_offs = [] + + def pred_mod(self, x, dim, name=None): + conv0 = _conv_norm( + x, 1, 256, with_bn=False, bn_act='relu', name=name + '_0') + conv1 = fluid.layers.conv2d( + input=conv0, + filter_size=1, + num_filters=dim, + param_attr=ParamAttr( + name=name + "_1_weight", initializer=kaiming_init(conv0, 1)), + bias_attr=ParamAttr( + name=name + "_1_bias", initializer=Constant(-2.19)), + name=name + '_1') + return conv1 + + def get_output(self, input): + try: + from ppdet.ext_op import cornerpool_lib + except: + logger.error( + "cornerpool_lib not found, compile in ppdet/ext_op at first") + for ind in range(self.stack): + cnv = input[ind] + tl_modules = corner_pool( + cnv, + 256, + cornerpool_lib.top_pool, + cornerpool_lib.left_pool, + name='tl_modules_' + str(ind)) + br_modules = corner_pool( + cnv, + 256, + cornerpool_lib.bottom_pool, + cornerpool_lib.right_pool, + name='br_modules_' + str(ind)) + + tl_heat = self.pred_mod( + tl_modules, self.num_classes, name='tl_heats_' + str(ind)) + br_heat = self.pred_mod( + br_modules, self.num_classes, name='br_heats_' + str(ind)) + + tl_tag = self.pred_mod(tl_modules, 1, name='tl_tags_' + str(ind)) + br_tag = self.pred_mod(br_modules, 1, name='br_tags_' + str(ind)) + + tl_off = self.pred_mod(tl_modules, 2, name='tl_offs_' + str(ind)) + br_off = self.pred_mod(br_modules, 2, name='br_offs_' + str(ind)) + + self.tl_heats.append(tl_heat) + 
self.br_heats.append(br_heat) + self.tl_tags.append(tl_tag) + self.br_tags.append(br_tag) + self.tl_offs.append(tl_off) + self.br_offs.append(br_off) + + def focal_loss(self, preds, gt, gt_masks): + preds_clip = [] + none_pos = fluid.layers.cast( + fluid.layers.reduce_sum(gt_masks) == 0, 'float32') + none_pos.stop_gradient = True + min = fluid.layers.assign(np.array([1e-4], dtype='float32')) + max = fluid.layers.assign(np.array([1 - 1e-4], dtype='float32')) + for pred in preds: + pred_s = fluid.layers.sigmoid(pred) + pred_min = fluid.layers.elementwise_max(pred_s, min) + pred_max = fluid.layers.elementwise_min(pred_min, max) + preds_clip.append(pred_max) + + ones = fluid.layers.ones_like(gt) + + fg_map = fluid.layers.cast(gt == ones, 'float32') + fg_map.stop_gradient = True + num_pos = fluid.layers.reduce_sum(fg_map) + min_num = fluid.layers.ones_like(num_pos) + num_pos = fluid.layers.elementwise_max(num_pos, min_num) + num_pos.stop_gradient = True + bg_map = fluid.layers.cast(gt < ones, 'float32') + bg_map.stop_gradient = True + neg_weights = fluid.layers.pow(1 - gt, 4) * bg_map + neg_weights.stop_gradient = True + loss = fluid.layers.assign(np.array([0], dtype='float32')) + for ind, pred in enumerate(preds_clip): + pos_loss = fluid.layers.log(pred) * fluid.layers.pow(1 - pred, + 2) * fg_map + + neg_loss = fluid.layers.log(1 - pred) * fluid.layers.pow( + pred, 2) * neg_weights + + pos_loss = fluid.layers.reduce_sum(pos_loss) + neg_loss = fluid.layers.reduce_sum(neg_loss) + focal_loss_ = (neg_loss + pos_loss) / (num_pos + none_pos) + loss -= focal_loss_ + return loss + + def ae_loss(self, tl_tag, br_tag, gt_masks): + num = fluid.layers.reduce_sum(gt_masks, dim=1) + num_stop_gradient = True + tag0 = fluid.layers.squeeze(tl_tag, [2]) + tag1 = fluid.layers.squeeze(br_tag, [2]) + tag_mean = (tag0 + tag1) / 2 + + tag0 = fluid.layers.pow(tag0 - tag_mean, 2) + tag1 = fluid.layers.pow(tag1 - tag_mean, 2) + + tag0 = fluid.layers.elementwise_div(tag0, num + 1e-4, axis=0) + 
tag1 = fluid.layers.elementwise_div(tag1, num + 1e-4, axis=0) + tag0 = tag0 * gt_masks + tag1 = tag1 * gt_masks + tag0 = fluid.layers.reduce_sum(tag0) + tag1 = fluid.layers.reduce_sum(tag1) + + pull = tag0 + tag1 + + mask_1 = fluid.layers.expand( + fluid.layers.unsqueeze(gt_masks, [1]), [1, gt_masks.shape[1], 1]) + mask_2 = fluid.layers.expand( + fluid.layers.unsqueeze(gt_masks, [2]), [1, 1, gt_masks.shape[1]]) + mask = fluid.layers.cast((mask_1 + mask_2) == 2, 'float32') + mask.stop_gradient = True + + num2 = (num - 1) * num + num2.stop_gradient = True + tag_mean_1 = fluid.layers.expand( + fluid.layers.unsqueeze(tag_mean, [1]), [1, tag_mean.shape[1], 1]) + tag_mean_2 = fluid.layers.expand( + fluid.layers.unsqueeze(tag_mean, [2]), [1, 1, tag_mean.shape[1]]) + dist = tag_mean_1 - tag_mean_2 + dist = 1 - fluid.layers.abs(dist) + dist = fluid.layers.relu(dist) + dist = fluid.layers.elementwise_sub(dist, 1 / (num + 1e-4), axis=0) + dist = fluid.layers.elementwise_div(dist, (num2 + 1e-4), axis=0) + dist = dist * mask + push = fluid.layers.reduce_sum(dist) + return pull, push + + def off_loss(self, off, gt_off, gt_masks): + mask = fluid.layers.unsqueeze(gt_masks, [2]) + mask = fluid.layers.expand_as(mask, gt_off) + mask.stop_gradient = True + off_loss = fluid.layers.smooth_l1(off, gt_off, mask, mask) + off_loss = fluid.layers.reduce_sum(off_loss) + total_num = fluid.layers.reduce_sum(gt_masks) + total_num.stop_gradient = True + return off_loss / (total_num + 1e-4) + + def get_loss(self, targets): + gt_tl_heat = targets['tl_heatmaps'] + gt_br_heat = targets['br_heatmaps'] + gt_masks = targets['tag_masks'] + gt_tl_off = targets['tl_regrs'] + gt_br_off = targets['br_regrs'] + gt_tl_ind = targets['tl_tags'] + gt_br_ind = targets['br_tags'] + gt_masks = fluid.layers.cast(gt_masks, 'float32') + + focal_loss = 0 + focal_loss_ = self.focal_loss(self.tl_heats, gt_tl_heat, gt_masks) + focal_loss += focal_loss_ + focal_loss_ = self.focal_loss(self.br_heats, gt_br_heat, gt_masks) + 
focal_loss += focal_loss_ + + pull_loss = 0 + push_loss = 0 + + ones = fluid.layers.assign(np.array([1], dtype='float32')) + tl_tags = [ + mask_feat(tl_tag, gt_tl_ind, self.train_batch_size) + for tl_tag in self.tl_tags + ] + br_tags = [ + mask_feat(br_tag, gt_br_ind, self.train_batch_size) + for br_tag in self.br_tags + ] + + pull_loss, push_loss = 0, 0 + + for tl_tag, br_tag in zip(tl_tags, br_tags): + pull, push = self.ae_loss(tl_tag, br_tag, gt_masks) + pull_loss += pull + push_loss += push + + tl_offs = [ + mask_feat(tl_off, gt_tl_ind, self.train_batch_size) + for tl_off in self.tl_offs + ] + br_offs = [ + mask_feat(br_off, gt_br_ind, self.train_batch_size) + for br_off in self.br_offs + ] + + off_loss = 0 + for tl_off, br_off in zip(tl_offs, br_offs): + off_loss += self.off_loss(tl_off, gt_tl_off, gt_masks) + off_loss += self.off_loss(br_off, gt_br_off, gt_masks) + + pull_loss = self.pull_weight * pull_loss + push_loss = self.push_weight * push_loss + + loss = ( + focal_loss + pull_loss + push_loss + off_loss) / len(self.tl_heats) + return {'loss': loss} + + def get_prediction(self, input): + try: + from ppdet.ext_op import cornerpool_lib + except: + logger.error( + "cornerpool_lib not found, compile in ppdet/ext_op at first") + ind = self.stack - 1 + tl_modules = corner_pool( + input, + 256, + cornerpool_lib.top_pool, + cornerpool_lib.left_pool, + is_test=True, + name='tl_modules_' + str(ind)) + br_modules = corner_pool( + input, + 256, + cornerpool_lib.bottom_pool, + cornerpool_lib.right_pool, + is_test=True, + name='br_modules_' + str(ind)) + tl_heat = self.pred_mod( + tl_modules, self.num_classes, name='tl_heats_' + str(ind)) + br_heat = self.pred_mod( + br_modules, self.num_classes, name='br_heats_' + str(ind)) + tl_tag = self.pred_mod(tl_modules, 1, name='tl_tags_' + str(ind)) + br_tag = self.pred_mod(br_modules, 1, name='br_tags_' + str(ind)) + + tl_off = self.pred_mod(tl_modules, 2, name='tl_offs_' + str(ind)) + br_off = self.pred_mod(br_modules, 2, 
name='br_offs_' + str(ind)) + + return decode(tl_heat, br_heat, tl_tag, br_tag, tl_off, br_off, + self.ae_threshold, self.num_dets, self.K, + self.test_batch_size) diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/efficient_head.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/efficient_head.py new file mode 100755 index 000000000..5db9ffb95 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/efficient_head.py @@ -0,0 +1,189 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import TruncatedNormal, Constant +from paddle.fluid.regularizer import L2Decay +from ppdet.modeling.ops import RetinaOutputDecoder + +from ppdet.core.workspace import register + +__all__ = ['EfficientHead'] + + +@register +class EfficientHead(object): + """ + EfficientDet Head + + Args: + output_decoder (object): `RetinaOutputDecoder` instance. + repeat (int): Number of convolution layers. + num_chan (int): Number of octave output channels. + prior_prob (float): Initial value of the class prediction layer bias. + num_anchors (int): Number of anchors per cell. + num_classes (int): Number of classes. 
+ gamma (float): Gamma parameter for focal loss. + alpha (float): Alpha parameter for focal loss. + sigma (float): Sigma parameter for smooth l1 loss. + """ + __inject__ = ['output_decoder'] + __shared__ = ['num_classes'] + + def __init__(self, + output_decoder=RetinaOutputDecoder().__dict__, + repeat=3, + num_chan=64, + prior_prob=0.01, + num_anchors=9, + num_classes=81, + gamma=1.5, + alpha=0.25, + delta=0.1): + super(EfficientHead, self).__init__() + self.output_decoder = output_decoder + self.repeat = repeat + self.num_chan = num_chan + self.prior_prob = prior_prob + self.num_anchors = num_anchors + self.num_classes = num_classes + self.gamma = gamma + self.alpha = alpha + self.delta = delta + if isinstance(output_decoder, dict): + self.output_decoder = RetinaOutputDecoder(**output_decoder) + + def _get_output(self, body_feats): + def separable_conv(inputs, num_chan, bias_init=None, name=''): + dw_conv_name = name + '_dw' + pw_conv_name = name + '_pw' + in_chan = inputs.shape[1] + fan_in = np.sqrt(1. / (in_chan * 3 * 3)) + feat = fluid.layers.conv2d( + input=inputs, + num_filters=in_chan, + groups=in_chan, + filter_size=3, + stride=1, + padding='SAME', + param_attr=ParamAttr( + name=dw_conv_name + '_w', + initializer=TruncatedNormal(scale=fan_in)), + bias_attr=False) + fan_in = np.sqrt(1. 
/ in_chan) + feat = fluid.layers.conv2d( + input=feat, + num_filters=num_chan, + filter_size=1, + stride=1, + param_attr=ParamAttr( + name=pw_conv_name + '_w', + initializer=TruncatedNormal(scale=fan_in)), + bias_attr=ParamAttr( + name=pw_conv_name + '_b', + initializer=bias_init, + regularizer=L2Decay(0.))) + return feat + + def subnet(inputs, prefix, level): + feat = inputs + for i in range(self.repeat): + # NOTE share weight across FPN levels + conv_name = '{}_pred_conv_{}'.format(prefix, i) + feat = separable_conv(feat, self.num_chan, name=conv_name) + # NOTE batch norm params are not shared + bn_name = '{}_pred_bn_{}_{}'.format(prefix, level, i) + feat = fluid.layers.batch_norm( + input=feat, + act='swish', + momentum=0.997, + epsilon=1e-4, + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + param_attr=ParamAttr( + name=bn_name + '_w', + initializer=Constant(value=1.), + regularizer=L2Decay(0.)), + bias_attr=ParamAttr( + name=bn_name + '_b', regularizer=L2Decay(0.))) + return feat + + cls_preds = [] + box_preds = [] + for l, feat in enumerate(body_feats): + cls_out = subnet(feat, 'cls', l) + box_out = subnet(feat, 'box', l) + + bias_init = float(-np.log((1 - self.prior_prob) / self.prior_prob)) + bias_init = Constant(value=bias_init) + cls_pred = separable_conv( + cls_out, + self.num_anchors * (self.num_classes - 1), + bias_init=bias_init, + name='cls_pred') + cls_pred = fluid.layers.transpose(cls_pred, perm=[0, 2, 3, 1]) + cls_pred = fluid.layers.reshape( + cls_pred, shape=(0, -1, self.num_classes - 1)) + cls_preds.append(cls_pred) + + box_pred = separable_conv( + box_out, self.num_anchors * 4, name='box_pred') + box_pred = fluid.layers.transpose(box_pred, perm=[0, 2, 3, 1]) + box_pred = fluid.layers.reshape(box_pred, shape=(0, -1, 4)) + box_preds.append(box_pred) + + return cls_preds, box_preds + + def get_prediction(self, body_feats, anchors, im_info): + cls_preds, box_preds = self._get_output(body_feats) + cls_preds = 
[fluid.layers.sigmoid(pred) for pred in cls_preds] + pred_result = self.output_decoder( + bboxes=box_preds, + scores=cls_preds, + anchors=anchors, + im_info=im_info) + return {'bbox': pred_result} + + def get_loss(self, body_feats, gt_labels, gt_targets, fg_num): + cls_preds, box_preds = self._get_output(body_feats) + fg_num = fluid.layers.reduce_sum(fg_num, name='fg_num') + fg_num.stop_gradient = True + + cls_pred = fluid.layers.concat(cls_preds, axis=1) + box_pred = fluid.layers.concat(box_preds, axis=1) + cls_pred_reshape = fluid.layers.reshape( + cls_pred, shape=(-1, self.num_classes - 1)) + gt_labels_reshape = fluid.layers.reshape(gt_labels, shape=(-1, 1)) + loss_cls = fluid.layers.sigmoid_focal_loss( + x=cls_pred_reshape, + label=gt_labels_reshape, + fg_num=fg_num, + gamma=self.gamma, + alpha=self.alpha) + loss_cls = fluid.layers.reduce_sum(loss_cls) + + loss_bbox = fluid.layers.huber_loss( + input=box_pred, label=gt_targets, delta=self.delta) + mask = fluid.layers.expand(gt_labels, expand_times=[1, 1, 4]) > 0 + loss_bbox *= fluid.layers.cast(mask, 'float32') + loss_bbox = fluid.layers.reduce_sum(loss_bbox) / (fg_num * 4) + + return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox} diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/fcos_head.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/fcos_head.py new file mode 100755 index 000000000..a08e02976 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/fcos_head.py @@ -0,0 +1,379 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer +from paddle.fluid.regularizer import L2Decay +from ppdet.modeling.ops import ConvNorm, DeformConvNorm +from ppdet.modeling.ops import MultiClassNMS + +from ppdet.core.workspace import register + +__all__ = ['FCOSHead'] + + +@register +class FCOSHead(object): + """ + FCOSHead + Args: + num_classes (int): Number of classes + fpn_stride (list): The stride of each FPN Layer + prior_prob (float): Used to set the bias init for the class prediction layer + num_convs (int): The layer number in fcos head + norm_type (str): Normalization type, 'bn'/'sync_bn'/'affine_channel' + fcos_loss (object): Instance of 'FCOSLoss' + norm_reg_targets (bool): Normalization the regression target if true + centerness_on_reg(bool): The prediction of centerness on regression or clssification branch + use_dcn_in_tower (bool): Ues deformable conv on FCOSHead if true + nms (object): Instance of 'MultiClassNMS' + """ + __inject__ = ['fcos_loss', 'nms'] + __shared__ = ['num_classes'] + + def __init__(self, + num_classes=80, + fpn_stride=[8, 16, 32, 64, 128], + prior_prob=0.01, + num_convs=4, + norm_type="gn", + fcos_loss=None, + norm_reg_targets=False, + centerness_on_reg=False, + use_dcn_in_tower=False, + nms=MultiClassNMS( + score_threshold=0.01, + nms_top_k=1000, + keep_top_k=100, + nms_threshold=0.45, + 
background_label=-1).__dict__): + self.num_classes = num_classes + self.fpn_stride = fpn_stride[::-1] + self.prior_prob = prior_prob + self.num_convs = num_convs + self.norm_reg_targets = norm_reg_targets + self.centerness_on_reg = centerness_on_reg + self.use_dcn_in_tower = use_dcn_in_tower + self.norm_type = norm_type + self.fcos_loss = fcos_loss + self.batch_size = 8 + self.nms = nms + if isinstance(nms, dict): + self.nms = MultiClassNMS(**nms) + + def _fcos_head(self, features, fpn_stride, fpn_scale, is_training=False): + """ + Args: + features (Variables): feature map from FPN + fpn_stride (int): the stride of current feature map + is_training (bool): whether is train or test mode + """ + subnet_blob_cls = features + subnet_blob_reg = features + in_channles = features.shape[1] + if self.use_dcn_in_tower: + conv_norm = DeformConvNorm + else: + conv_norm = ConvNorm + for lvl in range(0, self.num_convs): + conv_cls_name = 'fcos_head_cls_tower_conv_{}'.format(lvl) + subnet_blob_cls = conv_norm( + input=subnet_blob_cls, + num_filters=in_channles, + filter_size=3, + stride=1, + norm_type=self.norm_type, + act='relu', + initializer=Normal( + loc=0., scale=0.01), + bias_attr=True, + norm_name=conv_cls_name + "_norm", + name=conv_cls_name) + conv_reg_name = 'fcos_head_reg_tower_conv_{}'.format(lvl) + subnet_blob_reg = conv_norm( + input=subnet_blob_reg, + num_filters=in_channles, + filter_size=3, + stride=1, + norm_type=self.norm_type, + act='relu', + initializer=Normal( + loc=0., scale=0.01), + bias_attr=True, + norm_name=conv_reg_name + "_norm", + name=conv_reg_name) + conv_cls_name = "fcos_head_cls" + bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob) + cls_logits = fluid.layers.conv2d( + input=subnet_blob_cls, + num_filters=self.num_classes, + filter_size=3, + stride=1, + padding=1, + param_attr=ParamAttr( + name=conv_cls_name + "_weights", + initializer=Normal( + loc=0., scale=0.01)), + bias_attr=ParamAttr( + name=conv_cls_name + "_bias", + 
initializer=Constant(value=bias_init_value)), + name=conv_cls_name) + conv_reg_name = "fcos_head_reg" + bbox_reg = fluid.layers.conv2d( + input=subnet_blob_reg, + num_filters=4, + filter_size=3, + stride=1, + padding=1, + param_attr=ParamAttr( + name=conv_reg_name + "_weights", + initializer=Normal( + loc=0., scale=0.01)), + bias_attr=ParamAttr( + name=conv_reg_name + "_bias", initializer=Constant(value=0)), + name=conv_reg_name) + bbox_reg = bbox_reg * fpn_scale + if self.norm_reg_targets: + bbox_reg = fluid.layers.relu(bbox_reg) + if not is_training: + bbox_reg = bbox_reg * fpn_stride + else: + bbox_reg = fluid.layers.exp(bbox_reg) + + conv_centerness_name = "fcos_head_centerness" + if self.centerness_on_reg: + subnet_blob_ctn = subnet_blob_reg + else: + subnet_blob_ctn = subnet_blob_cls + centerness = fluid.layers.conv2d( + input=subnet_blob_ctn, + num_filters=1, + filter_size=3, + stride=1, + padding=1, + param_attr=ParamAttr( + name=conv_centerness_name + "_weights", + initializer=Normal( + loc=0., scale=0.01)), + bias_attr=ParamAttr( + name=conv_centerness_name + "_bias", + initializer=Constant(value=0)), + name=conv_centerness_name) + return cls_logits, bbox_reg, centerness + + def _get_output(self, body_feats, is_training=False): + """ + Args: + body_feates (list): the list of fpn feature maps + is_training (bool): whether is train or test mode + Return: + cls_logits (Variables): prediction for classification + bboxes_reg (Variables): prediction for bounding box + centerness (Variables): prediction for ceterness + """ + cls_logits = [] + bboxes_reg = [] + centerness = [] + assert len(body_feats) == len(self.fpn_stride), \ + "The size of body_feats is not equal to size of fpn_stride" + for fpn_name, fpn_stride in zip(body_feats, self.fpn_stride): + features = body_feats[fpn_name] + scale = fluid.layers.create_parameter( + shape=[1, ], + dtype="float32", + name="%s_scale_on_reg" % fpn_name, + default_initializer=fluid.initializer.Constant(1.)) + cls_pred, 
bbox_pred, ctn_pred = self._fcos_head( + features, fpn_stride, scale, is_training=is_training) + cls_logits.append(cls_pred) + bboxes_reg.append(bbox_pred) + centerness.append(ctn_pred) + return cls_logits, bboxes_reg, centerness + + def _compute_locations(self, features): + """ + Args: + features (list): List of Variables for FPN feature maps + Return: + Anchor points for each feature map pixel + """ + locations = [] + for lvl, fpn_name in enumerate(features): + feature = features[fpn_name] + shape_fm = fluid.layers.shape(feature) + shape_fm.stop_gradient = True + h = shape_fm[2] + w = shape_fm[3] + fpn_stride = self.fpn_stride[lvl] + shift_x = fluid.layers.range( + 0, w * fpn_stride, fpn_stride, dtype='float32') + shift_y = fluid.layers.range( + 0, h * fpn_stride, fpn_stride, dtype='float32') + shift_x = fluid.layers.unsqueeze(shift_x, axes=[0]) + shift_y = fluid.layers.unsqueeze(shift_y, axes=[1]) + shift_x = fluid.layers.expand_as( + shift_x, target_tensor=feature[0, 0, :, :]) + shift_y = fluid.layers.expand_as( + shift_y, target_tensor=feature[0, 0, :, :]) + shift_x.stop_gradient = True + shift_y.stop_gradient = True + shift_x = fluid.layers.reshape(shift_x, shape=[-1]) + shift_y = fluid.layers.reshape(shift_y, shape=[-1]) + location = fluid.layers.stack( + [shift_x, shift_y], axis=-1) + fpn_stride // 2 + location.stop_gradient = True + locations.append(location) + return locations + + def __merge_hw(self, input, ch_type="channel_first"): + """ + Args: + input (Variables): Feature map whose H and W will be merged into one dimension + ch_type (str): channel_first / channel_last + Return: + new_shape (Variables): The new shape after h and w merged into one dimension + """ + shape_ = fluid.layers.shape(input) + bs = shape_[0] + ch = shape_[1] + hi = shape_[2] + wi = shape_[3] + img_size = hi * wi + img_size.stop_gradient = True + if ch_type == "channel_first": + new_shape = fluid.layers.concat([bs, ch, img_size]) + elif ch_type == "channel_last": + new_shape = 
fluid.layers.concat([bs, img_size, ch]) + else: + raise KeyError("Wrong ch_type %s" % ch_type) + new_shape.stop_gradient = True + return new_shape + + def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn, + im_info): + """ + Args: + locations (Variables): anchor points for current layer + box_cls (Variables): categories prediction + box_reg (Variables): bounding box prediction + box_ctn (Variables): centerness prediction + im_info (Variables): [h, w, scale] for input images + Return: + box_cls_ch_last (Variables): score for each category, in [N, C, M] + C is the number of classes and M is the number of anchor points + box_reg_decoding (Variables): decoded bounding box, in [N, M, 4] + last dimension is [x1, y1, x2, y2] + """ + act_shape_cls = self.__merge_hw(box_cls) + box_cls_ch_last = fluid.layers.reshape( + x=box_cls, + shape=[self.batch_size, self.num_classes, -1], + actual_shape=act_shape_cls) + box_cls_ch_last = fluid.layers.sigmoid(box_cls_ch_last) + act_shape_reg = self.__merge_hw(box_reg, "channel_last") + box_reg_ch_last = fluid.layers.transpose(box_reg, perm=[0, 2, 3, 1]) + box_reg_ch_last = fluid.layers.reshape( + x=box_reg_ch_last, + shape=[self.batch_size, -1, 4], + actual_shape=act_shape_reg) + act_shape_ctn = self.__merge_hw(box_ctn) + box_ctn_ch_last = fluid.layers.reshape( + x=box_ctn, + shape=[self.batch_size, 1, -1], + actual_shape=act_shape_ctn) + box_ctn_ch_last = fluid.layers.sigmoid(box_ctn_ch_last) + + box_reg_decoding = fluid.layers.stack( + [ + locations[:, 0] - box_reg_ch_last[:, :, 0], + locations[:, 1] - box_reg_ch_last[:, :, 1], + locations[:, 0] + box_reg_ch_last[:, :, 2], + locations[:, 1] + box_reg_ch_last[:, :, 3] + ], + axis=1) + box_reg_decoding = fluid.layers.transpose( + box_reg_decoding, perm=[0, 2, 1]) + # recover the location to original image + im_scale = im_info[:, 2] + box_reg_decoding = box_reg_decoding / im_scale + box_cls_ch_last = box_cls_ch_last * box_ctn_ch_last + return box_cls_ch_last, 
box_reg_decoding + + def _post_processing(self, locations, cls_logits, bboxes_reg, centerness, + im_info): + """ + Args: + locations (list): List of Variables composed by center of each anchor point + cls_logits (list): List of Variables for class prediction + bboxes_reg (list): List of Variables for bounding box prediction + centerness (list): List of Variables for centerness prediction + im_info(Variables): [h, w, scale] for input images + Return: + pred (LoDTensor): predicted bounding box after nms, + the shape is n x 6, last dimension is [label, score, xmin, ymin, xmax, ymax] + """ + pred_boxes_ = [] + pred_scores_ = [] + for _, ( + pts, cls, box, ctn + ) in enumerate(zip(locations, cls_logits, bboxes_reg, centerness)): + pred_scores_lvl, pred_boxes_lvl = self._postprocessing_by_level( + pts, cls, box, ctn, im_info) + pred_boxes_.append(pred_boxes_lvl) + pred_scores_.append(pred_scores_lvl) + pred_boxes = fluid.layers.concat(pred_boxes_, axis=1) + pred_scores = fluid.layers.concat(pred_scores_, axis=2) + pred = self.nms(pred_boxes, pred_scores) + return pred + + def get_loss(self, input, tag_labels, tag_bboxes, tag_centerness): + """ + Calculate the loss for FCOS + Args: + input (list): List of Variables for feature maps from FPN layers + tag_labels (Variables): category targets for each anchor point + tag_bboxes (Variables): bounding boxes targets for positive samples + tag_centerness (Variables): centerness targets for positive samples + Return: + loss (dict): loss composed by classification loss, bounding box + regression loss and centerness regression loss + """ + cls_logits, bboxes_reg, centerness = self._get_output( + input, is_training=True) + loss = self.fcos_loss(cls_logits, bboxes_reg, centerness, tag_labels, + tag_bboxes, tag_centerness) + return loss + + def get_prediction(self, input, im_info): + """ + Decode the prediction + Args: + input (list): List of Variables for feature maps from FPN layers + im_info(Variables): [h, w, scale] for input 
images + Return: + the bounding box prediction + """ + cls_logits, bboxes_reg, centerness = self._get_output( + input, is_training=False) + locations = self._compute_locations(input) + pred = self._post_processing(locations, cls_logits, bboxes_reg, + centerness, im_info) + return {"bbox": pred} diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/iou_aware.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/iou_aware.py new file mode 100755 index 000000000..7a85a70a6 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/iou_aware.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid + + +def _split_ioup(output, an_num, num_classes): + """ + Split new output feature map to output, predicted iou + along channel dimension + """ + ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num]) + ioup = fluid.layers.sigmoid(ioup) + + oriout = fluid.layers.slice( + output, axes=[1], starts=[an_num], ends=[an_num * (num_classes + 6)]) + + return (ioup, oriout) + + +def _de_sigmoid(x, eps=1e-7): + x = fluid.layers.clip(x, eps, 1 / eps) + one = fluid.layers.fill_constant( + shape=[1, 1, 1, 1], dtype=x.dtype, value=1.) 
+ x = fluid.layers.clip((one / x - 1.0), eps, 1 / eps) + x = -fluid.layers.log(x) + return x + + +def _postprocess_output(ioup, output, an_num, num_classes, iou_aware_factor): + """ + post process output objectness score + """ + tensors = [] + stride = output.shape[1] // an_num + for m in range(an_num): + tensors.append( + fluid.layers.slice( + output, + axes=[1], + starts=[stride * m + 0], + ends=[stride * m + 4])) + obj = fluid.layers.slice( + output, axes=[1], starts=[stride * m + 4], ends=[stride * m + 5]) + obj = fluid.layers.sigmoid(obj) + ip = fluid.layers.slice(ioup, axes=[1], starts=[m], ends=[m + 1]) + + new_obj = fluid.layers.pow(obj, ( + 1 - iou_aware_factor)) * fluid.layers.pow(ip, iou_aware_factor) + new_obj = _de_sigmoid(new_obj) + + tensors.append(new_obj) + + tensors.append( + fluid.layers.slice( + output, + axes=[1], + starts=[stride * m + 5], + ends=[stride * m + 5 + num_classes])) + + output = fluid.layers.concat(tensors, axis=1) + + return output + + +def get_iou_aware_score(output, an_num, num_classes, iou_aware_factor): + ioup, output = _split_ioup(output, an_num, num_classes) + output = _postprocess_output(ioup, output, an_num, num_classes, + iou_aware_factor) + return output diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/retina_head.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/retina_head.py new file mode 100755 index 000000000..eed81eb4d --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/retina_head.py @@ -0,0 +1,408 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Normal, Constant +from paddle.fluid.regularizer import L2Decay +from ppdet.modeling.ops import (AnchorGenerator, RetinaTargetAssign, + RetinaOutputDecoder) + +from ppdet.core.workspace import register + +__all__ = ['RetinaHead'] + + +@register +class RetinaHead(object): + """ + Retina Head + + Args: + anchor_generator (object): `AnchorGenerator` instance + target_assign (object): `RetinaTargetAssign` instance + output_decoder (object): `RetinaOutputDecoder` instance + num_convs_per_octave (int): Number of convolution layers in each octave + num_chan (int): Number of octave output channels + max_level (int): Highest level of FPN output + min_level (int): Lowest level of FPN output + prior_prob (float): Used to set the bias init for the class prediction layer + base_scale (int): Anchors are generated based on this scale + num_scales_per_octave (int): Number of anchor scales per octave + num_classes (int): Number of classes + gamma (float): The parameter in focal loss + alpha (float): The parameter in focal loss + sigma (float): The parameter in smooth l1 loss + """ + __inject__ = ['anchor_generator', 'target_assign', 'output_decoder'] + __shared__ = ['num_classes'] + + def __init__(self, + anchor_generator=AnchorGenerator().__dict__, + target_assign=RetinaTargetAssign().__dict__, + 
output_decoder=RetinaOutputDecoder().__dict__, + num_convs_per_octave=4, + num_chan=256, + max_level=7, + min_level=3, + prior_prob=0.01, + base_scale=4, + num_scales_per_octave=3, + num_classes=81, + gamma=2.0, + alpha=0.25, + sigma=3.0151134457776365): + self.anchor_generator = anchor_generator + self.target_assign = target_assign + self.output_decoder = output_decoder + self.num_convs_per_octave = num_convs_per_octave + self.num_chan = num_chan + self.max_level = max_level + self.min_level = min_level + self.prior_prob = prior_prob + self.base_scale = base_scale + self.num_scales_per_octave = num_scales_per_octave + self.num_classes = num_classes + self.gamma = gamma + self.alpha = alpha + self.sigma = sigma + if isinstance(anchor_generator, dict): + self.anchor_generator = AnchorGenerator(**anchor_generator) + if isinstance(target_assign, dict): + self.target_assign = RetinaTargetAssign(**target_assign) + if isinstance(output_decoder, dict): + self.output_decoder = RetinaOutputDecoder(**output_decoder) + + def _class_subnet(self, body_feats, spatial_scale): + """ + Get class predictions of all level FPN level. + + Args: + fpn_dict(dict): A dictionary represents the output of FPN with + their name. + spatial_scale(list): A list of multiplicative spatial scale factor. + + Returns: + cls_pred_input(list): Class prediction of all input fpn levels. 
+ """ + assert len(body_feats) == self.max_level - self.min_level + 1 + fpn_name_list = list(body_feats.keys()) + cls_pred_list = [] + for lvl in range(self.min_level, self.max_level + 1): + fpn_name = fpn_name_list[self.max_level - lvl] + subnet_blob = body_feats[fpn_name] + for i in range(self.num_convs_per_octave): + conv_name = 'retnet_cls_conv_n{}_fpn{}'.format(i, lvl) + conv_share_name = 'retnet_cls_conv_n{}_fpn{}'.format( + i, self.min_level) + subnet_blob_in = subnet_blob + subnet_blob = fluid.layers.conv2d( + input=subnet_blob_in, + num_filters=self.num_chan, + filter_size=3, + stride=1, + padding=1, + act='relu', + name=conv_name, + param_attr=ParamAttr( + name=conv_share_name + '_w', + initializer=Normal( + loc=0., scale=0.01)), + bias_attr=ParamAttr( + name=conv_share_name + '_b', + learning_rate=2., + regularizer=L2Decay(0.))) + + # class prediction + cls_name = 'retnet_cls_pred_fpn{}'.format(lvl) + cls_share_name = 'retnet_cls_pred_fpn{}'.format(self.min_level) + num_anchors = self.num_scales_per_octave * len( + self.anchor_generator.aspect_ratios) + cls_dim = num_anchors * (self.num_classes - 1) + # bias initialization: b = -log((1 - pai) / pai) + bias_init = float(-np.log((1 - self.prior_prob) / self.prior_prob)) + out_cls = fluid.layers.conv2d( + input=subnet_blob, + num_filters=cls_dim, + filter_size=3, + stride=1, + padding=1, + act=None, + name=cls_name, + param_attr=ParamAttr( + name=cls_share_name + '_w', + initializer=Normal( + loc=0., scale=0.01)), + bias_attr=ParamAttr( + name=cls_share_name + '_b', + initializer=Constant(value=bias_init), + learning_rate=2., + regularizer=L2Decay(0.))) + cls_pred_list.append(out_cls) + + return cls_pred_list + + def _bbox_subnet(self, body_feats, spatial_scale): + """ + Get bounding box predictions of all level FPN level. + + Args: + fpn_dict(dict): A dictionary represents the output of FPN with + their name. + spatial_scale(list): A list of multiplicative spatial scale factor. 
+ + Returns: + bbox_pred_input(list): Bounding box prediction of all input fpn + levels. + """ + assert len(body_feats) == self.max_level - self.min_level + 1 + fpn_name_list = list(body_feats.keys()) + bbox_pred_list = [] + for lvl in range(self.min_level, self.max_level + 1): + fpn_name = fpn_name_list[self.max_level - lvl] + subnet_blob = body_feats[fpn_name] + for i in range(self.num_convs_per_octave): + conv_name = 'retnet_bbox_conv_n{}_fpn{}'.format(i, lvl) + conv_share_name = 'retnet_bbox_conv_n{}_fpn{}'.format( + i, self.min_level) + subnet_blob_in = subnet_blob + subnet_blob = fluid.layers.conv2d( + input=subnet_blob_in, + num_filters=self.num_chan, + filter_size=3, + stride=1, + padding=1, + act='relu', + name=conv_name, + param_attr=ParamAttr( + name=conv_share_name + '_w', + initializer=Normal( + loc=0., scale=0.01)), + bias_attr=ParamAttr( + name=conv_share_name + '_b', + learning_rate=2., + regularizer=L2Decay(0.))) + + # bbox prediction + bbox_name = 'retnet_bbox_pred_fpn{}'.format(lvl) + bbox_share_name = 'retnet_bbox_pred_fpn{}'.format(self.min_level) + num_anchors = self.num_scales_per_octave * len( + self.anchor_generator.aspect_ratios) + bbox_dim = num_anchors * 4 + out_bbox = fluid.layers.conv2d( + input=subnet_blob, + num_filters=bbox_dim, + filter_size=3, + stride=1, + padding=1, + act=None, + name=bbox_name, + param_attr=ParamAttr( + name=bbox_share_name + '_w', + initializer=Normal( + loc=0., scale=0.01)), + bias_attr=ParamAttr( + name=bbox_share_name + '_b', + learning_rate=2., + regularizer=L2Decay(0.))) + bbox_pred_list.append(out_bbox) + return bbox_pred_list + + def _anchor_generate(self, body_feats, spatial_scale): + """ + Get anchor boxes of all level FPN level. + + Args: + fpn_dict(dict): A dictionary represents the output of FPN with + their name. + spatial_scale(list): A list of multiplicative spatial scale factor. + + Return: + anchor_input(list): Anchors of all input fpn levels with shape of. 
+ anchor_var_input(list): Anchor variance of all input fpn levels with + shape. + """ + assert len(body_feats) == self.max_level - self.min_level + 1 + fpn_name_list = list(body_feats.keys()) + anchor_list = [] + anchor_var_list = [] + for lvl in range(self.min_level, self.max_level + 1): + anchor_sizes = [] + stride = int(1 / spatial_scale[self.max_level - lvl]) + for octave in range(self.num_scales_per_octave): + anchor_size = stride * ( + 2**(float(octave) / + float(self.num_scales_per_octave))) * self.base_scale + anchor_sizes.append(anchor_size) + fpn_name = fpn_name_list[self.max_level - lvl] + anchor, anchor_var = self.anchor_generator( + input=body_feats[fpn_name], + anchor_sizes=anchor_sizes, + aspect_ratios=self.anchor_generator.aspect_ratios, + stride=[stride, stride]) + anchor_list.append(anchor) + anchor_var_list.append(anchor_var) + return anchor_list, anchor_var_list + + def _get_output(self, body_feats, spatial_scale): + """ + Get class, bounding box predictions and anchor boxes of all level FPN level. + + Args: + fpn_dict(dict): A dictionary represents the output of FPN with + their name. + spatial_scale(list): A list of multiplicative spatial scale factor. + + Returns: + cls_pred_input(list): Class prediction of all input fpn levels. + bbox_pred_input(list): Bounding box prediction of all input fpn + levels. + anchor_input(list): Anchors of all input fpn levels with shape of. + anchor_var_input(list): Anchor variance of all input fpn levels with + shape. 
+ """ + assert len(body_feats) == self.max_level - self.min_level + 1 + # class subnet + cls_pred_list = self._class_subnet(body_feats, spatial_scale) + # bbox subnet + bbox_pred_list = self._bbox_subnet(body_feats, spatial_scale) + #generate anchors + anchor_list, anchor_var_list = self._anchor_generate(body_feats, + spatial_scale) + cls_pred_reshape_list = [] + bbox_pred_reshape_list = [] + anchor_reshape_list = [] + anchor_var_reshape_list = [] + for i in range(self.max_level - self.min_level + 1): + cls_pred_transpose = fluid.layers.transpose( + cls_pred_list[i], perm=[0, 2, 3, 1]) + cls_pred_reshape = fluid.layers.reshape( + cls_pred_transpose, shape=(0, -1, self.num_classes - 1)) + bbox_pred_transpose = fluid.layers.transpose( + bbox_pred_list[i], perm=[0, 2, 3, 1]) + bbox_pred_reshape = fluid.layers.reshape( + bbox_pred_transpose, shape=(0, -1, 4)) + anchor_reshape = fluid.layers.reshape(anchor_list[i], shape=(-1, 4)) + anchor_var_reshape = fluid.layers.reshape( + anchor_var_list[i], shape=(-1, 4)) + cls_pred_reshape_list.append(cls_pred_reshape) + bbox_pred_reshape_list.append(bbox_pred_reshape) + anchor_reshape_list.append(anchor_reshape) + anchor_var_reshape_list.append(anchor_var_reshape) + output = {} + output['cls_pred'] = cls_pred_reshape_list + output['bbox_pred'] = bbox_pred_reshape_list + output['anchor'] = anchor_reshape_list + output['anchor_var'] = anchor_var_reshape_list + return output + + def get_prediction(self, body_feats, spatial_scale, im_info): + """ + Get prediction bounding box in test stage. + + Args: + fpn_dict(dict): A dictionary represents the output of FPN with + their name. + spatial_scale(list): A list of multiplicative spatial scale factor. + im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the + number of input images, each element consists of im_height, + im_width, im_scale. + + Returns: + pred_result(Variable): Prediction result with shape [N, 6]. 
Each + row has 6 values: [label, confidence, xmin, ymin, xmax, ymax]. + N is the total number of prediction. + """ + output = self._get_output(body_feats, spatial_scale) + cls_pred_reshape_list = output['cls_pred'] + bbox_pred_reshape_list = output['bbox_pred'] + anchor_reshape_list = output['anchor'] + for i in range(self.max_level - self.min_level + 1): + cls_pred_reshape_list[i] = fluid.layers.sigmoid( + cls_pred_reshape_list[i]) + pred_result = self.output_decoder( + bboxes=bbox_pred_reshape_list, + scores=cls_pred_reshape_list, + anchors=anchor_reshape_list, + im_info=im_info) + return {'bbox': pred_result} + + def get_loss(self, body_feats, spatial_scale, im_info, gt_box, gt_label, + is_crowd): + """ + Calculate the loss of retinanet. + Args: + fpn_dict(dict): A dictionary represents the output of FPN with + their name. + spatial_scale(list): A list of multiplicative spatial scale factor. + im_info(Variable): A 2-D LoDTensor with shape [B, 3]. B is the + number of input images, each element consists of im_height, + im_width, im_scale. + gt_box(Variable): The ground-truth bounding boxes with shape [M, 4]. + M is the number of groundtruth. + gt_label(Variable): The ground-truth labels with shape [M, 1]. + M is the number of groundtruth. + is_crowd(Variable): Indicates groud-truth is crowd or not with + shape [M, 1]. M is the number of groundtruth. + + Returns: + Type: dict + loss_cls(Variable): focal loss. + loss_bbox(Variable): smooth l1 loss. 
+ """ + output = self._get_output(body_feats, spatial_scale) + cls_pred_reshape_list = output['cls_pred'] + bbox_pred_reshape_list = output['bbox_pred'] + anchor_reshape_list = output['anchor'] + anchor_var_reshape_list = output['anchor_var'] + + cls_pred_input = fluid.layers.concat(cls_pred_reshape_list, axis=1) + bbox_pred_input = fluid.layers.concat(bbox_pred_reshape_list, axis=1) + anchor_input = fluid.layers.concat(anchor_reshape_list, axis=0) + anchor_var_input = fluid.layers.concat(anchor_var_reshape_list, axis=0) + score_pred, loc_pred, score_tgt, loc_tgt, bbox_weight, fg_num = \ + self.target_assign( + bbox_pred=bbox_pred_input, + cls_logits=cls_pred_input, + anchor_box=anchor_input, + anchor_var=anchor_var_input, + gt_boxes=gt_box, + gt_labels=gt_label, + is_crowd=is_crowd, + im_info=im_info, + num_classes=self.num_classes - 1) + fg_num = fluid.layers.reduce_sum(fg_num, name='fg_num') + score_tgt = fluid.layers.cast(score_tgt, 'int32') + loss_cls = fluid.layers.sigmoid_focal_loss( + x=score_pred, + label=score_tgt, + fg_num=fg_num, + gamma=self.gamma, + alpha=self.alpha) + loss_cls = fluid.layers.reduce_sum(loss_cls, name='loss_cls') + loss_bbox = fluid.layers.smooth_l1( + x=loc_pred, + y=loc_tgt, + sigma=self.sigma, + inside_weight=bbox_weight, + outside_weight=bbox_weight) + loss_bbox = fluid.layers.reduce_sum(loss_bbox, name='loss_bbox') + loss_bbox = loss_bbox / fluid.layers.cast(fg_num, loss_bbox.dtype) + return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox} diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/rpn_head.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/rpn_head.py new file mode 100755 index 000000000..876aafe36 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/rpn_head.py @@ -0,0 +1,497 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal
from paddle.fluid.regularizer import L2Decay

from ppdet.core.workspace import register
from ppdet.modeling.ops import (AnchorGenerator, RPNTargetAssign,
                                GenerateProposals)

__all__ = ['RPNTargetAssign', 'GenerateProposals', 'RPNHead', 'FPNRPNHead']


@register
class RPNHead(object):
    """
    RPN Head

    Args:
        anchor_generator (object): `AnchorGenerator` instance
        rpn_target_assign (object): `RPNTargetAssign` instance
        train_proposal (object): `GenerateProposals` instance for training
        test_proposal (object): `GenerateProposals` instance for testing
        num_classes (int): number of classes in rpn output
    """
    __inject__ = [
        'anchor_generator', 'rpn_target_assign', 'train_proposal',
        'test_proposal'
    ]

    def __init__(self,
                 anchor_generator=AnchorGenerator().__dict__,
                 rpn_target_assign=RPNTargetAssign().__dict__,
                 train_proposal=GenerateProposals(12000, 2000).__dict__,
                 test_proposal=GenerateProposals().__dict__,
                 num_classes=1):
        super(RPNHead, self).__init__()
        self.anchor_generator = anchor_generator
        self.rpn_target_assign = rpn_target_assign
        self.train_proposal = train_proposal
        self.test_proposal = test_proposal
        self.num_classes = num_classes
        # The config system may inject plain dicts; materialize them into
        # the corresponding op wrappers.
        if isinstance(anchor_generator, dict):
            self.anchor_generator = AnchorGenerator(**anchor_generator)
        if isinstance(rpn_target_assign, dict):
            self.rpn_target_assign = RPNTargetAssign(**rpn_target_assign)
        if isinstance(train_proposal, dict):
            self.train_proposal = GenerateProposals(**train_proposal)
        if isinstance(test_proposal, dict):
            self.test_proposal = GenerateProposals(**test_proposal)

    def _get_output(self, input):
        """
        Get anchor and RPN head output.

        Args:
            input(Variable): feature map from backbone with shape of
                [N, C, H, W]

        Returns:
            rpn_cls_score(Variable): Output of rpn head with shape of
                [N, num_anchors, H, W].
            rpn_bbox_pred(Variable): Output of rpn head with shape of
                [N, num_anchors * 4, H, W].
        """
        dim_out = input.shape[1]
        rpn_conv = fluid.layers.conv2d(
            input=input,
            num_filters=dim_out,
            filter_size=3,
            stride=1,
            padding=1,
            act='relu',
            name='conv_rpn',
            param_attr=ParamAttr(
                name="conv_rpn_w", initializer=Normal(
                    loc=0., scale=0.01)),
            bias_attr=ParamAttr(
                name="conv_rpn_b", learning_rate=2., regularizer=L2Decay(0.)))
        # Generate anchors; stored on self so get_loss can reuse them later.
        self.anchor, self.anchor_var = self.anchor_generator(input=rpn_conv)
        num_anchor = self.anchor.shape[2]
        # Proposal classification scores
        self.rpn_cls_score = fluid.layers.conv2d(
            rpn_conv,
            num_filters=num_anchor * self.num_classes,
            filter_size=1,
            stride=1,
            padding=0,
            act=None,
            name='rpn_cls_score',
            param_attr=ParamAttr(
                name="rpn_cls_logits_w", initializer=Normal(
                    loc=0., scale=0.01)),
            bias_attr=ParamAttr(
                name="rpn_cls_logits_b",
                learning_rate=2.,
                regularizer=L2Decay(0.)))
        # Proposal bbox regression deltas
        self.rpn_bbox_pred = fluid.layers.conv2d(
            rpn_conv,
            num_filters=4 * num_anchor,
            filter_size=1,
            stride=1,
            padding=0,
            act=None,
            name='rpn_bbox_pred',
            param_attr=ParamAttr(
                name="rpn_bbox_pred_w", initializer=Normal(
                    loc=0., scale=0.01)),
            bias_attr=ParamAttr(
                name="rpn_bbox_pred_b",
                learning_rate=2.,
                regularizer=L2Decay(0.)))
        return self.rpn_cls_score, self.rpn_bbox_pred

    def get_proposals(self, body_feats, im_info, mode='train'):
        """
        Get proposals according to the output of backbone.

        Args:
            body_feats (dict): The dictionary of feature maps from backbone.
            im_info(Variable): The information of image with shape [N, 3] with
                shape (height, width, scale).
            mode(str): 'train' selects the training proposal op, anything
                else selects the test proposal op.

        Returns:
            rpn_rois(Variable): Output proposals with shape of (rois_num, 4).
        """

        # In RPN Heads, only the last feature map of backbone is used.
        # And body_feat_names[-1] represents the last level name of backbone.
        body_feat = list(body_feats.values())[-1]
        rpn_cls_score, rpn_bbox_pred = self._get_output(body_feat)

        if self.num_classes == 1:
            rpn_cls_prob = fluid.layers.sigmoid(
                rpn_cls_score, name='rpn_cls_prob')
        else:
            # Multi-class RPN: softmax over classes, drop the background
            # column, then keep the best foreground probability per anchor.
            rpn_cls_score = fluid.layers.transpose(
                rpn_cls_score, perm=[0, 2, 3, 1])
            rpn_cls_score = fluid.layers.reshape(
                rpn_cls_score, shape=(0, 0, 0, -1, self.num_classes))
            rpn_cls_prob_tmp = fluid.layers.softmax(
                rpn_cls_score, use_cudnn=False, name='rpn_cls_prob')
            rpn_cls_prob_slice = fluid.layers.slice(
                rpn_cls_prob_tmp, axes=[4], starts=[1],
                ends=[self.num_classes])
            rpn_cls_prob, _ = fluid.layers.topk(rpn_cls_prob_slice, 1)
            rpn_cls_prob = fluid.layers.reshape(
                rpn_cls_prob, shape=(0, 0, 0, -1))
            rpn_cls_prob = fluid.layers.transpose(
                rpn_cls_prob, perm=[0, 3, 1, 2])
        prop_op = self.train_proposal if mode == 'train' else self.test_proposal
        rpn_rois, rpn_roi_probs = prop_op(
            scores=rpn_cls_prob,
            bbox_deltas=rpn_bbox_pred,
            im_info=im_info,
            anchors=self.anchor,
            variances=self.anchor_var)
        return rpn_rois

    def _transform_input(self, rpn_cls_score, rpn_bbox_pred, anchor,
                         anchor_var):
        # Flatten NCHW predictions and anchors into per-anchor rows for the
        # target-assign op.
        rpn_cls_score = fluid.layers.transpose(rpn_cls_score, perm=[0, 2, 3, 1])
        rpn_bbox_pred = fluid.layers.transpose(rpn_bbox_pred, perm=[0, 2, 3, 1])
        anchor = fluid.layers.reshape(anchor, shape=(-1, 4))
        anchor_var = fluid.layers.reshape(anchor_var, shape=(-1, 4))
        rpn_cls_score = fluid.layers.reshape(
            x=rpn_cls_score, shape=(0, -1, self.num_classes))
        rpn_bbox_pred = fluid.layers.reshape(x=rpn_bbox_pred, shape=(0, -1, 4))
        return rpn_cls_score, rpn_bbox_pred, anchor, anchor_var

    def _get_loss_input(self):
        for attr in ['rpn_cls_score', 'rpn_bbox_pred', 'anchor', 'anchor_var']:
            # NOTE(review): relies on `not <fluid Variable>` being False when
            # the attribute is set — confirm against the fluid version in use.
            if not getattr(self, attr, None):
                # Fixed: the original passed two positional args to
                # ValueError, which rendered the message as a tuple.
                raise ValueError(
                    "self.{} should not be None, "
                    "call RPNHead.get_proposals first".format(attr))
        return self._transform_input(self.rpn_cls_score, self.rpn_bbox_pred,
                                     self.anchor, self.anchor_var)

    def get_loss(self, im_info, gt_box, is_crowd, gt_label=None):
        """
        Sample proposals and calculate rpn loss.

        Args:
            im_info(Variable): The information of image with shape [N, 3] with
                shape (height, width, scale).
            gt_box(Variable): The ground-truth bounding boxes with shape [M, 4].
                M is the number of groundtruth.
            is_crowd(Variable): Indicates ground-truth is crowd or not with
                shape [M, 1]. M is the number of groundtruth.
            gt_label(Variable): Ground-truth labels, only used when
                num_classes > 1.

        Returns:
            Type: dict
                rpn_cls_loss(Variable): RPN classification loss.
                rpn_bbox_loss(Variable): RPN bounding box regression loss.
        """
        rpn_cls, rpn_bbox, anchor, anchor_var = self._get_loss_input()
        if self.num_classes == 1:
            score_pred, loc_pred, score_tgt, loc_tgt, bbox_weight = \
                self.rpn_target_assign(
                    bbox_pred=rpn_bbox,
                    cls_logits=rpn_cls,
                    anchor_box=anchor,
                    anchor_var=anchor_var,
                    gt_boxes=gt_box,
                    is_crowd=is_crowd,
                    im_info=im_info)
            score_tgt = fluid.layers.cast(x=score_tgt, dtype='float32')
            score_tgt.stop_gradient = True
            rpn_cls_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
                x=score_pred, label=score_tgt)
        else:
            score_pred, loc_pred, score_tgt, loc_tgt, bbox_weight = \
                self.rpn_target_assign(
                    bbox_pred=rpn_bbox,
                    cls_logits=rpn_cls,
                    anchor_box=anchor,
                    anchor_var=anchor_var,
                    gt_boxes=gt_box,
                    gt_labels=gt_label,
                    is_crowd=is_crowd,
                    num_classes=self.num_classes,
                    im_info=im_info)
            labels_int64 = fluid.layers.cast(x=score_tgt, dtype='int64')
            labels_int64.stop_gradient = True
            rpn_cls_loss = fluid.layers.softmax_with_cross_entropy(
                logits=score_pred, label=labels_int64,
                numeric_stable_mode=True)

        rpn_cls_loss = fluid.layers.reduce_mean(
            rpn_cls_loss, name='loss_rpn_cls')

        loc_tgt = fluid.layers.cast(x=loc_tgt, dtype='float32')
        loc_tgt.stop_gradient = True
        rpn_reg_loss = fluid.layers.smooth_l1(
            x=loc_pred,
            y=loc_tgt,
            sigma=3.0,
            inside_weight=bbox_weight,
            outside_weight=bbox_weight)
        rpn_reg_loss = fluid.layers.reduce_sum(
            rpn_reg_loss, name='loss_rpn_bbox')
        # Normalize the regression loss by the number of sampled anchors.
        score_shape = fluid.layers.shape(score_tgt)
        score_shape = fluid.layers.cast(x=score_shape, dtype='float32')
        norm = fluid.layers.reduce_prod(score_shape)
        norm.stop_gradient = True
        rpn_reg_loss = rpn_reg_loss / norm

        return {'loss_rpn_cls': rpn_cls_loss, 'loss_rpn_bbox': rpn_reg_loss}
@register
class FPNRPNHead(RPNHead):
    """
    RPN Head that supports FPN input

    Args:
        anchor_generator (object): `AnchorGenerator` instance
        rpn_target_assign (object): `RPNTargetAssign` instance
        train_proposal (object): `GenerateProposals` instance for training
        test_proposal (object): `GenerateProposals` instance for testing
        anchor_start_size (int): size of anchor at the first scale
        num_chan (int): number of FPN output channels
        min_level (int): lowest level of FPN output
        max_level (int): highest level of FPN output
        num_classes (int): number of classes in rpn output
    """

    __inject__ = [
        'anchor_generator', 'rpn_target_assign', 'train_proposal',
        'test_proposal'
    ]

    def __init__(self,
                 anchor_generator=AnchorGenerator().__dict__,
                 rpn_target_assign=RPNTargetAssign().__dict__,
                 train_proposal=GenerateProposals(12000, 2000).__dict__,
                 test_proposal=GenerateProposals().__dict__,
                 anchor_start_size=32,
                 num_chan=256,
                 min_level=2,
                 max_level=6,
                 num_classes=1):
        super(FPNRPNHead, self).__init__(anchor_generator, rpn_target_assign,
                                         train_proposal, test_proposal)
        self.anchor_start_size = anchor_start_size
        self.num_chan = num_chan
        self.min_level = min_level
        self.max_level = max_level
        self.num_classes = num_classes

        # Per-level outputs collected by get_proposals and consumed by
        # _get_loss_input.
        # NOTE(review): these lists are appended to on every get_proposals
        # call and never cleared, so calling get_proposals more than once on
        # the same head accumulates stale entries — confirm intended usage.
        self.fpn_rpn_list = []
        self.anchors_list = []
        self.anchor_var_list = []

    def _get_output(self, input, feat_lvl):
        """
        Get anchor and FPN RPN head output at one level.

        Args:
            input(Variable): Body feature from backbone.
            feat_lvl(int): Indicate the level of rpn output corresponding
                to the level of feature map.

        Return:
            rpn_cls_score(Variable): Output of one level of fpn rpn head with
                shape of [N, num_anchors, H, W].
            rpn_bbox_pred(Variable): Output of one level of fpn rpn head with
                shape of [N, num_anchors * 4, H, W].
        """
        slvl = str(feat_lvl)
        conv_name = 'conv_rpn_fpn' + slvl
        cls_name = 'rpn_cls_logits_fpn' + slvl
        bbox_name = 'rpn_bbox_pred_fpn' + slvl
        # Parameters are shared across levels via the min_level names.
        conv_share_name = 'conv_rpn_fpn' + str(self.min_level)
        cls_share_name = 'rpn_cls_logits_fpn' + str(self.min_level)
        bbox_share_name = 'rpn_bbox_pred_fpn' + str(self.min_level)

        num_anchors = len(self.anchor_generator.aspect_ratios)
        conv_rpn_fpn = fluid.layers.conv2d(
            input=input,
            num_filters=self.num_chan,
            filter_size=3,
            padding=1,
            act='relu',
            name=conv_name,
            param_attr=ParamAttr(
                name=conv_share_name + '_w',
                initializer=Normal(
                    loc=0., scale=0.01)),
            bias_attr=ParamAttr(
                name=conv_share_name + '_b',
                learning_rate=2.,
                regularizer=L2Decay(0.)))

        # Anchor size doubles per level starting from anchor_start_size;
        # stride matches the level's downsampling factor 2**feat_lvl.
        self.anchors, self.anchor_var = self.anchor_generator(
            input=conv_rpn_fpn,
            anchor_sizes=(self.anchor_start_size * 2.
                          **(feat_lvl - self.min_level), ),
            stride=(2.**feat_lvl, 2.**feat_lvl))

        cls_num_filters = num_anchors * self.num_classes
        self.rpn_cls_score = fluid.layers.conv2d(
            input=conv_rpn_fpn,
            num_filters=cls_num_filters,
            filter_size=1,
            act=None,
            name=cls_name,
            param_attr=ParamAttr(
                name=cls_share_name + '_w',
                initializer=Normal(
                    loc=0., scale=0.01)),
            bias_attr=ParamAttr(
                name=cls_share_name + '_b',
                learning_rate=2.,
                regularizer=L2Decay(0.)))
        self.rpn_bbox_pred = fluid.layers.conv2d(
            input=conv_rpn_fpn,
            num_filters=num_anchors * 4,
            filter_size=1,
            act=None,
            name=bbox_name,
            param_attr=ParamAttr(
                name=bbox_share_name + '_w',
                initializer=Normal(
                    loc=0., scale=0.01)),
            bias_attr=ParamAttr(
                name=bbox_share_name + '_b',
                learning_rate=2.,
                regularizer=L2Decay(0.)))
        return self.rpn_cls_score, self.rpn_bbox_pred

    def _get_single_proposals(self, body_feat, im_info, feat_lvl, mode='train'):
        """
        Get proposals in one level according to the output of fpn rpn head

        Args:
            body_feat(Variable): the feature map from backbone.
            im_info(Variable): The information of image with shape [N, 3] with
                format (height, width, scale).
            feat_lvl(int): Indicate the level of proposals corresponding to
                the feature maps.
            mode(str): 'train' selects the training proposal op, anything
                else selects the test proposal op.

        Returns:
            rpn_rois_fpn(Variable): Output proposals with shape of
                (rois_num, 4).
            rpn_roi_probs_fpn(Variable): Scores of proposals with
                shape of (rois_num, 1).
        """

        rpn_cls_score_fpn, rpn_bbox_pred_fpn = self._get_output(body_feat,
                                                                feat_lvl)

        prop_op = self.train_proposal if mode == 'train' else self.test_proposal
        if self.num_classes == 1:
            rpn_cls_prob_fpn = fluid.layers.sigmoid(
                rpn_cls_score_fpn, name='rpn_cls_prob_fpn' + str(feat_lvl))
        else:
            # Multi-class: softmax, drop background column, keep the best
            # foreground probability per anchor.
            rpn_cls_score_fpn = fluid.layers.transpose(
                rpn_cls_score_fpn, perm=[0, 2, 3, 1])
            rpn_cls_score_fpn = fluid.layers.reshape(
                rpn_cls_score_fpn, shape=(0, 0, 0, -1, self.num_classes))
            rpn_cls_prob_fpn = fluid.layers.softmax(
                rpn_cls_score_fpn,
                use_cudnn=False,
                name='rpn_cls_prob_fpn' + str(feat_lvl))
            rpn_cls_prob_fpn = fluid.layers.slice(
                rpn_cls_prob_fpn, axes=[4], starts=[1],
                ends=[self.num_classes])
            rpn_cls_prob_fpn, _ = fluid.layers.topk(rpn_cls_prob_fpn, 1)
            rpn_cls_prob_fpn = fluid.layers.reshape(
                rpn_cls_prob_fpn, shape=(0, 0, 0, -1))
            rpn_cls_prob_fpn = fluid.layers.transpose(
                rpn_cls_prob_fpn, perm=[0, 3, 1, 2])
        rpn_rois_fpn, rpn_roi_prob_fpn = prop_op(
            scores=rpn_cls_prob_fpn,
            bbox_deltas=rpn_bbox_pred_fpn,
            im_info=im_info,
            anchors=self.anchors,
            variances=self.anchor_var)
        return rpn_rois_fpn, rpn_roi_prob_fpn

    def get_proposals(self, fpn_feats, im_info, mode='train'):
        """
        Get proposals in multiple levels according to the output of fpn
        rpn head

        Args:
            fpn_feats(dict): A dictionary represents the output feature map
                of FPN with their name.
            im_info(Variable): The information of image with shape [N, 3] with
                format (height, width, scale).
            mode(str): 'train' or test mode selector.

        Return:
            rois_list(Variable): Output proposals in shape of [rois_num, 4]
        """
        rois_list = []
        roi_probs_list = []
        fpn_feat_names = list(fpn_feats.keys())
        for lvl in range(self.min_level, self.max_level + 1):
            # FPN outputs are keyed highest level first.
            fpn_feat_name = fpn_feat_names[self.max_level - lvl]
            fpn_feat = fpn_feats[fpn_feat_name]
            rois_fpn, roi_probs_fpn = self._get_single_proposals(
                fpn_feat, im_info, lvl, mode)
            # Stash per-level outputs for the subsequent loss computation.
            self.fpn_rpn_list.append((self.rpn_cls_score, self.rpn_bbox_pred))
            rois_list.append(rois_fpn)
            roi_probs_list.append(roi_probs_fpn)
            self.anchors_list.append(self.anchors)
            self.anchor_var_list.append(self.anchor_var)
        prop_op = self.train_proposal if mode == 'train' else self.test_proposal
        post_nms_top_n = prop_op.post_nms_top_n
        rois_collect = fluid.layers.collect_fpn_proposals(
            rois_list,
            roi_probs_list,
            self.min_level,
            self.max_level,
            post_nms_top_n,
            name='collect')
        return rois_collect

    def _get_loss_input(self):
        # Flatten the per-level predictions saved by get_proposals into the
        # single-tensor inputs expected by RPNHead.get_loss.
        rpn_clses = []
        rpn_bboxes = []
        anchors = []
        anchor_vars = []
        for i in range(len(self.fpn_rpn_list)):
            single_input = self._transform_input(
                self.fpn_rpn_list[i][0], self.fpn_rpn_list[i][1],
                self.anchors_list[i], self.anchor_var_list[i])
            rpn_clses.append(single_input[0])
            rpn_bboxes.append(single_input[1])
            anchors.append(single_input[2])
            anchor_vars.append(single_input[3])

        rpn_cls = fluid.layers.concat(rpn_clses, axis=1)
        rpn_bbox = fluid.layers.concat(rpn_bboxes, axis=1)
        anchors = fluid.layers.concat(anchors)
        anchor_var = fluid.layers.concat(anchor_vars)
        return rpn_cls, rpn_bbox, anchors, anchor_var
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +from ppdet.modeling.ops import ConvNorm, DeformConvNorm, MaskMatrixNMS, DropBlock +from ppdet.core.workspace import register + +from six.moves import zip +import numpy as np + +__all__ = ['SOLOv2Head'] + + +@register +class SOLOv2Head(object): + """ + Head block for SOLOv2 network + + Args: + num_classes (int): Number of output classes. + seg_feat_channels (int): Num_filters of kernel & categroy branch convolution operation. + stacked_convs (int): Times of convolution operation. + num_grids (list[int]): List of feature map grids size. + kernel_out_channels (int): Number of output channels in kernel branch. + dcn_v2_stages (list): Which stage use dcn v2 in tower. + segm_strides (list[int]): List of segmentation area stride. + solov2_loss (object): SOLOv2Loss instance. + score_threshold (float): Threshold of categroy score. + mask_nms (object): MaskMatrixNMS instance. + drop_block (bool): Whether use drop_block or not. 
+ """ + __inject__ = ['solov2_loss', 'mask_nms'] + __shared__ = ['num_classes'] + + def __init__(self, + num_classes=80, + seg_feat_channels=256, + stacked_convs=4, + num_grids=[40, 36, 24, 16, 12], + kernel_out_channels=256, + dcn_v2_stages=[], + segm_strides=[8, 8, 16, 32, 32], + solov2_loss=None, + score_threshold=0.1, + mask_threshold=0.5, + mask_nms=MaskMatrixNMS( + update_threshold=0.05, + pre_nms_top_n=500, + post_nms_top_n=100, + kernel='gaussian', + sigma=2.0).__dict__, + drop_block=False): + self.num_classes = num_classes + self.seg_num_grids = num_grids + self.cate_out_channels = self.num_classes - 1 + self.seg_feat_channels = seg_feat_channels + self.stacked_convs = stacked_convs + self.kernel_out_channels = kernel_out_channels + self.dcn_v2_stages = dcn_v2_stages + self.segm_strides = segm_strides + self.solov2_loss = solov2_loss + self.mask_nms = mask_nms + self.score_threshold = score_threshold + self.mask_threshold = mask_threshold + self.drop_block = drop_block + self.conv_type = [ConvNorm, DeformConvNorm] + if isinstance(mask_nms, dict): + self.mask_nms = MaskMatrixNMS(**mask_nms) + + def _conv_pred(self, conv_feat, num_filters, is_test, name, name_feat=None): + for i in range(self.stacked_convs): + if i in self.dcn_v2_stages: + conv_func = self.conv_type[1] + else: + conv_func = self.conv_type[0] + conv_feat = conv_func( + input=conv_feat, + num_filters=self.seg_feat_channels, + filter_size=3, + stride=1, + norm_type='gn', + norm_groups=32, + freeze_norm=False, + act='relu', + initializer=fluid.initializer.NormalInitializer(scale=0.01), + norm_name='{}.{}.gn'.format(name, i), + name='{}.{}'.format(name, i)) + if name_feat == 'bbox_head.solo_cate': + bias_init = float(-np.log((1 - 0.01) / 0.01)) + bias_attr = ParamAttr( + name="{}.bias".format(name_feat), + initializer=fluid.initializer.Constant(value=bias_init)) + else: + bias_attr = ParamAttr(name="{}.bias".format(name_feat)) + + if self.drop_block: + conv_feat = DropBlock( + conv_feat, 
block_size=3, keep_prob=0.9, is_test=is_test) + + conv_feat = fluid.layers.conv2d( + input=conv_feat, + num_filters=num_filters, + filter_size=3, + stride=1, + padding=1, + param_attr=ParamAttr( + name="{}.weight".format(name_feat), + initializer=fluid.initializer.NormalInitializer(scale=0.01)), + bias_attr=bias_attr, + name=name + '_feat_') + return conv_feat + + def _points_nms(self, heat, kernel=2): + hmax = fluid.layers.pool2d( + input=heat, pool_size=kernel, pool_type='max', pool_padding=1) + keep = fluid.layers.cast((hmax[:, :, :-1, :-1] == heat), 'float32') + return heat * keep + + def _split_feats(self, feats): + return (paddle.nn.functional.interpolate( + feats[0], + scale_factor=0.5, + align_corners=False, + align_mode=0, + mode='bilinear'), feats[1], feats[2], feats[3], + paddle.nn.functional.interpolate( + feats[4], + size=fluid.layers.shape(feats[3])[-2:], + mode='bilinear', + align_corners=False, + align_mode=0)) + + def get_outputs(self, input, is_eval=False): + """ + Get SOLOv2 head output + + Args: + input (list): List of Variables, output of backbone or neck stages + is_eval (bool): whether in train or test mode + Returns: + cate_pred_list (list): Variables of each category branch layer + kernel_pred_list (list): Variables of each kernel branch layer + """ + feats = self._split_feats(input) + cate_pred_list = [] + kernel_pred_list = [] + for idx in range(len(self.seg_num_grids)): + cate_pred, kernel_pred = self._get_output_single( + feats[idx], idx, is_eval=is_eval) + cate_pred_list.append(cate_pred) + kernel_pred_list.append(kernel_pred) + + return cate_pred_list, kernel_pred_list + + def _get_output_single(self, input, idx, is_eval=False): + ins_kernel_feat = input + # CoordConv + x_range = paddle.linspace( + -1, 1, fluid.layers.shape(ins_kernel_feat)[-1], dtype='float32') + y_range = paddle.linspace( + -1, 1, fluid.layers.shape(ins_kernel_feat)[-2], dtype='float32') + y, x = paddle.tensor.meshgrid([y_range, x_range]) + x = 
fluid.layers.unsqueeze(x, [0, 1]) + y = fluid.layers.unsqueeze(y, [0, 1]) + y = fluid.layers.expand( + y, expand_times=[fluid.layers.shape(ins_kernel_feat)[0], 1, 1, 1]) + x = fluid.layers.expand( + x, expand_times=[fluid.layers.shape(ins_kernel_feat)[0], 1, 1, 1]) + coord_feat = fluid.layers.concat([x, y], axis=1) + ins_kernel_feat = fluid.layers.concat( + [ins_kernel_feat, coord_feat], axis=1) + + # kernel branch + kernel_feat = ins_kernel_feat + seg_num_grid = self.seg_num_grids[idx] + kernel_feat = paddle.nn.functional.interpolate( + kernel_feat, + size=[seg_num_grid, seg_num_grid], + mode='bilinear', + align_corners=False, + align_mode=0) + cate_feat = kernel_feat[:, :-2, :, :] + + kernel_pred = self._conv_pred( + kernel_feat, + self.kernel_out_channels, + is_eval, + name='bbox_head.kernel_convs', + name_feat='bbox_head.solo_kernel') + + # cate branch + cate_pred = self._conv_pred( + cate_feat, + self.cate_out_channels, + is_eval, + name='bbox_head.cate_convs', + name_feat='bbox_head.solo_cate') + + if is_eval: + cate_pred = self._points_nms( + fluid.layers.sigmoid(cate_pred), kernel=2) + cate_pred = fluid.layers.transpose(cate_pred, [0, 2, 3, 1]) + return cate_pred, kernel_pred + + def get_loss(self, cate_preds, kernel_preds, ins_pred, ins_labels, + cate_labels, grid_order_list, fg_num): + """ + Get loss of network of SOLOv2. + + Args: + cate_preds (list): Variable list of categroy branch output. + kernel_preds (list): Variable list of kernel branch output. + ins_pred (list): Variable list of instance branch output. + ins_labels (list): List of instance labels pre batch. + cate_labels (list): List of categroy labels pre batch. + grid_order_list (list): List of index in pre grid. + fg_num (int): Number of positive samples in a mini-batch. + Returns: + loss_ins (Variable): The instance loss Variable of SOLOv2 network. + loss_cate (Variable): The category loss Variable of SOLOv2 network. 
+ """ + new_kernel_preds = [] + pad_length_list = [] + for kernel_preds_level, grid_orders_level in zip(kernel_preds, + grid_order_list): + reshape_pred = fluid.layers.reshape( + kernel_preds_level, + shape=(fluid.layers.shape(kernel_preds_level)[0], + fluid.layers.shape(kernel_preds_level)[1], -1)) + reshape_pred = fluid.layers.transpose(reshape_pred, [0, 2, 1]) + reshape_pred = fluid.layers.reshape( + reshape_pred, shape=(-1, fluid.layers.shape(reshape_pred)[2])) + gathered_pred = fluid.layers.gather( + reshape_pred, index=grid_orders_level) + gathered_pred = fluid.layers.lod_reset(gathered_pred, + grid_orders_level) + pad_value = fluid.layers.assign(input=np.array( + [0.0], dtype=np.float32)) + pad_pred, pad_length = fluid.layers.sequence_pad( + gathered_pred, pad_value=pad_value) + new_kernel_preds.append(pad_pred) + pad_length_list.append(pad_length) + + # generate masks + ins_pred_list = [] + for kernel_pred, pad_length in zip(new_kernel_preds, pad_length_list): + cur_ins_pred = ins_pred + cur_ins_pred = fluid.layers.reshape( + cur_ins_pred, + shape=(fluid.layers.shape(cur_ins_pred)[0], + fluid.layers.shape(cur_ins_pred)[1], -1)) + ins_pred_conv = paddle.matmul(kernel_pred, cur_ins_pred) + cur_ins_pred = fluid.layers.reshape( + ins_pred_conv, + shape=(fluid.layers.shape(ins_pred_conv)[0], + fluid.layers.shape(ins_pred_conv)[1], + fluid.layers.shape(ins_pred)[-2], + fluid.layers.shape(ins_pred)[-1])) + cur_ins_pred = fluid.layers.sequence_unpad(cur_ins_pred, pad_length) + ins_pred_list.append(cur_ins_pred) + + num_ins = fluid.layers.reduce_sum(fg_num) + cate_preds = [ + fluid.layers.reshape( + fluid.layers.transpose(cate_pred, [0, 2, 3, 1]), + shape=(-1, self.cate_out_channels)) for cate_pred in cate_preds + ] + flatten_cate_preds = fluid.layers.concat(cate_preds) + new_cate_labels = [] + cate_labels = fluid.layers.concat(cate_labels) + cate_labels = fluid.layers.unsqueeze(cate_labels, 1) + loss_ins, loss_cate = self.solov2_loss( + ins_pred_list, ins_labels, 
flatten_cate_preds, cate_labels, num_ins) + + return {'loss_ins': loss_ins, 'loss_cate': loss_cate} + + def get_prediction(self, cate_preds, kernel_preds, seg_pred, im_info): + """ + Get prediction result of SOLOv2 network + + Args: + cate_preds (list): List of Variables, output of categroy branch. + kernel_preds (list): List of Variables, output of kernel branch. + seg_pred (list): List of Variables, output of mask head stages. + im_info(Variables): [h, w, scale] for input images. + Returns: + seg_masks (Variable): The prediction segmentation. + cate_labels (Variable): The prediction categroy label of each segmentation. + seg_masks (Variable): The prediction score of each segmentation. + """ + num_levels = len(cate_preds) + featmap_size = fluid.layers.shape(seg_pred)[-2:] + seg_masks_list = [] + cate_labels_list = [] + cate_scores_list = [] + cate_preds = [cate_pred * 1.0 for cate_pred in cate_preds] + kernel_preds = [kernel_pred * 1.0 for kernel_pred in kernel_preds] + # Currently only supports batch size == 1 + for idx in range(1): + cate_pred_list = [ + fluid.layers.reshape( + cate_preds[i][idx], shape=(-1, self.cate_out_channels)) + for i in range(num_levels) + ] + seg_pred_list = seg_pred + kernel_pred_list = [ + fluid.layers.reshape( + fluid.layers.transpose(kernel_preds[i][idx], [1, 2, 0]), + shape=(-1, self.kernel_out_channels)) + for i in range(num_levels) + ] + cate_pred_list = fluid.layers.concat(cate_pred_list, axis=0) + kernel_pred_list = fluid.layers.concat(kernel_pred_list, axis=0) + + seg_masks, cate_labels, cate_scores = self.get_seg_single( + cate_pred_list, seg_pred_list, kernel_pred_list, featmap_size, + im_info[idx]) + return { + "segm": seg_masks, + 'cate_label': cate_labels, + 'cate_score': cate_scores + } + + def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size, + im_info): + + im_scale = im_info[2] + h = fluid.layers.cast(im_info[0], 'int32') + w = fluid.layers.cast(im_info[1], 'int32') + upsampled_size_out = 
(featmap_size[0] * 4, featmap_size[1] * 4) + + inds = fluid.layers.where(cate_preds > self.score_threshold) + cate_preds = fluid.layers.reshape(cate_preds, shape=[-1]) + # Prevent empty and increase fake data + ind_a = fluid.layers.cast(fluid.layers.shape(kernel_preds)[0], 'int64') + ind_b = fluid.layers.zeros(shape=[1], dtype='int64') + inds_end = fluid.layers.unsqueeze( + fluid.layers.concat([ind_a, ind_b]), 0) + inds = fluid.layers.concat([inds, inds_end]) + kernel_preds_end = fluid.layers.ones( + shape=[1, self.kernel_out_channels], dtype='float32') + kernel_preds = fluid.layers.concat([kernel_preds, kernel_preds_end]) + cate_preds = fluid.layers.concat( + [cate_preds, fluid.layers.zeros( + shape=[1], dtype='float32')]) + + # cate_labels & kernel_preds + cate_labels = inds[:, 1] + kernel_preds = fluid.layers.gather(kernel_preds, index=inds[:, 0]) + cate_score_idx = fluid.layers.elementwise_add(inds[:, 0] * 80, + cate_labels) + cate_scores = fluid.layers.gather(cate_preds, index=cate_score_idx) + + size_trans = np.power(self.seg_num_grids, 2) + strides = [] + for _ind in range(len(self.segm_strides)): + strides.append( + fluid.layers.fill_constant( + shape=[int(size_trans[_ind])], + dtype="int32", + value=self.segm_strides[_ind])) + strides = fluid.layers.concat(strides) + strides = fluid.layers.gather(strides, index=inds[:, 0]) + + # mask encoding. 
+ kernel_preds = fluid.layers.unsqueeze(kernel_preds, [2, 3]) + seg_preds = paddle.nn.functional.conv2d(seg_preds, kernel_preds) + seg_preds = fluid.layers.sigmoid(fluid.layers.squeeze(seg_preds, [0])) + seg_masks = seg_preds > self.mask_threshold + seg_masks = fluid.layers.cast(seg_masks, 'float32') + sum_masks = fluid.layers.reduce_sum(seg_masks, dim=[1, 2]) + + keep = fluid.layers.where(sum_masks > strides) + keep = fluid.layers.squeeze(keep, axes=[1]) + # Prevent empty and increase fake data + keep_other = fluid.layers.concat([ + keep, fluid.layers.cast( + fluid.layers.shape(sum_masks)[0] - 1, 'int64') + ]) + keep_scores = fluid.layers.concat([ + keep, fluid.layers.cast(fluid.layers.shape(sum_masks)[0], 'int64') + ]) + cate_scores_end = fluid.layers.zeros(shape=[1], dtype='float32') + cate_scores = fluid.layers.concat([cate_scores, cate_scores_end]) + + seg_masks = fluid.layers.gather(seg_masks, index=keep_other) + seg_preds = fluid.layers.gather(seg_preds, index=keep_other) + sum_masks = fluid.layers.gather(sum_masks, index=keep_other) + cate_labels = fluid.layers.gather(cate_labels, index=keep_other) + cate_scores = fluid.layers.gather(cate_scores, index=keep_scores) + + # mask scoring. 
+ seg_mul = fluid.layers.cast(seg_preds * seg_masks, 'float32') + seg_scores = fluid.layers.reduce_sum(seg_mul, dim=[1, 2]) / sum_masks + cate_scores *= seg_scores + + # Matrix NMS + seg_preds, cate_scores, cate_labels = self.mask_nms( + seg_preds, seg_masks, cate_labels, cate_scores, sum_masks=sum_masks) + + ori_shape = im_info[:2] / im_scale + 0.5 + ori_shape = fluid.layers.cast(ori_shape, 'int32') + seg_preds = paddle.nn.functional.interpolate( + fluid.layers.unsqueeze(seg_preds, 0), + size=upsampled_size_out, + mode='bilinear', + align_corners=False, + align_mode=0)[:, :, :h, :w] + seg_masks = fluid.layers.squeeze( + paddle.nn.functional.interpolate( + seg_preds, + size=ori_shape[:2], + mode='bilinear', + align_corners=False, + align_mode=0), + axes=[0]) + # TODO: convert uint8 + seg_masks = fluid.layers.cast(seg_masks > self.mask_threshold, 'int32') + return seg_masks, cate_labels, cate_scores diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/ttf_head.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/ttf_head.py new file mode 100755 index 000000000..ba9ec802e --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/ttf_head.py @@ -0,0 +1,386 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Normal, Constant, Uniform, Xavier +from paddle.fluid.regularizer import L2Decay +from ppdet.core.workspace import register +from ppdet.modeling.ops import DeformConv, DropBlock +from ppdet.modeling.losses import GiouLoss + +__all__ = ['TTFHead'] + + +@register +class TTFHead(object): + """ + TTFHead + Args: + head_conv(int): the default channel number of convolution in head. + 128 by default. + num_classes(int): the number of classes, 80 by default. + hm_weight(float): the weight of heatmap branch. 1. by default. + wh_weight(float): the weight of wh branch. 5. by default. + wh_offset_base(flaot): the base offset of width and height. + 16. by default. + planes(tuple): the channel number of convolution in each upsample. + (256, 128, 64) by default. + shortcut_num(tuple): the number of convolution layers in each shortcut. + (1, 2, 3) by default. + wh_head_conv_num(int): the number of convolution layers in wh head. + 2 by default. + hm_head_conv_num(int): the number of convolution layers in wh head. + 2 by default. + wh_conv(int): the channel number of convolution in wh head. + 64 by default. + wh_planes(int): the output channel in wh head. 4 by default. + score_thresh(float): the score threshold to get prediction. + 0.01 by default. + max_per_img(int): the maximum detection per image. 100 by default. + base_down_ratio(int): the base down_ratio, the actual down_ratio is + calculated by base_down_ratio and the number of upsample layers. + 16 by default. + wh_loss(object): `GiouLoss` instance. + dcn_upsample(bool): whether upsample by dcn. True by default. + dcn_head(bool): whether use dcn in head. False by default. + drop_block(bool): whether use dropblock. False by default. 
+ block_size(int): block_size parameter for drop_block. 3 by default. + keep_prob(float): keep_prob parameter for drop_block. 0.9 by default. + """ + + __inject__ = ['wh_loss'] + __shared__ = ['num_classes'] + + def __init__(self, + head_conv=128, + num_classes=80, + hm_weight=1., + wh_weight=5., + wh_offset_base=16., + planes=(256, 128, 64), + shortcut_num=(1, 2, 3), + wh_head_conv_num=2, + hm_head_conv_num=2, + wh_conv=64, + wh_planes=4, + score_thresh=0.01, + max_per_img=100, + base_down_ratio=32, + wh_loss='GiouLoss', + dcn_upsample=True, + dcn_head=False, + drop_block=False, + block_size=3, + keep_prob=0.9): + super(TTFHead, self).__init__() + self.head_conv = head_conv + self.num_classes = num_classes + self.hm_weight = hm_weight + self.wh_weight = wh_weight + self.wh_offset_base = wh_offset_base + self.planes = planes + self.shortcut_num = shortcut_num + self.shortcut_len = len(shortcut_num) + self.wh_head_conv_num = wh_head_conv_num + self.hm_head_conv_num = hm_head_conv_num + self.wh_conv = wh_conv + self.wh_planes = wh_planes + self.score_thresh = score_thresh + self.max_per_img = max_per_img + self.down_ratio = base_down_ratio // 2**len(planes) + self.hm_weight = hm_weight + self.wh_weight = wh_weight + self.wh_loss = wh_loss + self.dcn_upsample = dcn_upsample + self.dcn_head = dcn_head + self.drop_block = drop_block + self.block_size = block_size + self.keep_prob = keep_prob + + def shortcut(self, x, out_c, layer_num, kernel_size=3, padding=1, + name=None): + assert layer_num > 0 + for i in range(layer_num): + act = 'relu' if i < layer_num - 1 else None + fan_out = kernel_size * kernel_size * out_c + std = math.sqrt(2. / fan_out) + param_name = name + '.layers.' 
+ str(i * 2) + x = fluid.layers.conv2d( + x, + out_c, + kernel_size, + padding=padding, + act=act, + param_attr=ParamAttr( + initializer=Normal(0, std), name=param_name + '.weight'), + bias_attr=ParamAttr( + learning_rate=2., + regularizer=L2Decay(0.), + name=param_name + '.bias')) + return x + + def upsample(self, x, out_c, name=None): + fan_in = x.shape[1] * 3 * 3 + stdv = 1. / math.sqrt(fan_in) + if self.dcn_upsample: + conv = DeformConv( + x, + out_c, + 3, + initializer=Uniform(-stdv, stdv), + bias_attr=True, + name=name + '.0') + else: + conv = fluid.layers.conv2d( + x, + out_c, + 3, + padding=1, + param_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), + bias_attr=ParamAttr( + learning_rate=2., regularizer=L2Decay(0.))) + + norm_name = name + '.1' + pattr = ParamAttr(name=norm_name + '.weight', initializer=Constant(1.)) + battr = ParamAttr(name=norm_name + '.bias', initializer=Constant(0.)) + bn = fluid.layers.batch_norm( + input=conv, + act='relu', + param_attr=pattr, + bias_attr=battr, + name=norm_name + '.output.1', + moving_mean_name=norm_name + '.running_mean', + moving_variance_name=norm_name + '.running_var') + up = fluid.layers.resize_bilinear( + bn, scale=2, name=name + '.2.upsample') + return up + + def _head(self, + x, + out_c, + conv_num=1, + head_out_c=None, + name=None, + is_test=False): + head_out_c = self.head_conv if not head_out_c else head_out_c + conv_w_std = 0.01 if '.hm' in name else 0.001 + conv_w_init = Normal(0, conv_w_std) + for i in range(conv_num): + conv_name = '{}.{}.conv'.format(name, i) + if self.dcn_head: + x = DeformConv( + x, + head_out_c, + 3, + initializer=conv_w_init, + name=conv_name + '.dcn') + x = fluid.layers.relu(x) + else: + x = fluid.layers.conv2d( + x, + head_out_c, + 3, + padding=1, + param_attr=ParamAttr( + initializer=conv_w_init, name=conv_name + '.weight'), + bias_attr=ParamAttr( + learning_rate=2., + regularizer=L2Decay(0.), + name=conv_name + '.bias'), + act='relu') + if self.drop_block and '.hm' in name: + 
x = DropBlock( + x, + block_size=self.block_size, + keep_prob=self.keep_prob, + is_test=is_test) + bias_init = float(-np.log((1 - 0.01) / 0.01)) if '.hm' in name else 0. + conv_b_init = Constant(bias_init) + x = fluid.layers.conv2d( + x, + out_c, + 1, + param_attr=ParamAttr( + initializer=conv_w_init, + name='{}.{}.weight'.format(name, conv_num)), + bias_attr=ParamAttr( + learning_rate=2., + regularizer=L2Decay(0.), + name='{}.{}.bias'.format(name, conv_num), + initializer=conv_b_init)) + return x + + def hm_head(self, x, name=None, is_test=False): + hm = self._head( + x, + self.num_classes, + self.hm_head_conv_num, + name=name, + is_test=is_test) + return hm + + def wh_head(self, x, name=None): + planes = self.wh_planes + wh = self._head( + x, planes, self.wh_head_conv_num, self.wh_conv, name=name) + return fluid.layers.relu(wh) + + def get_output(self, input, name=None, is_test=False): + feat = input[-1] + for i, out_c in enumerate(self.planes): + feat = self.upsample( + feat, out_c, name=name + '.deconv_layers.' + str(i)) + if i < self.shortcut_len: + shortcut = self.shortcut( + input[-i - 2], + out_c, + self.shortcut_num[i], + name=name + '.shortcut_layers.' 
+ str(i)) + feat = fluid.layers.elementwise_add(feat, shortcut) + + hm = self.hm_head(feat, name=name + '.hm', is_test=is_test) + wh = self.wh_head(feat, name=name + '.wh') * self.wh_offset_base + + return hm, wh + + def _simple_nms(self, heat, kernel=3): + pad = (kernel - 1) // 2 + hmax = fluid.layers.pool2d(heat, kernel, 'max', pool_padding=pad) + keep = fluid.layers.cast(hmax == heat, 'float32') + return heat * keep + + def _topk(self, scores, k): + cat, height, width = scores.shape[1:] + # batch size is 1 + scores_r = fluid.layers.reshape(scores, [cat, -1]) + topk_scores, topk_inds = fluid.layers.topk(scores_r, k) + topk_ys = topk_inds / width + topk_xs = topk_inds % width + + topk_score_r = fluid.layers.reshape(topk_scores, [-1]) + topk_score, topk_ind = fluid.layers.topk(topk_score_r, k) + topk_clses = fluid.layers.cast(topk_ind / k, 'float32') + + topk_inds = fluid.layers.reshape(topk_inds, [-1]) + topk_ys = fluid.layers.reshape(topk_ys, [-1, 1]) + topk_xs = fluid.layers.reshape(topk_xs, [-1, 1]) + topk_inds = fluid.layers.gather(topk_inds, topk_ind) + topk_ys = fluid.layers.gather(topk_ys, topk_ind) + topk_xs = fluid.layers.gather(topk_xs, topk_ind) + + return topk_score, topk_inds, topk_clses, topk_ys, topk_xs + + def get_bboxes(self, heatmap, wh, scale_factor): + heatmap = fluid.layers.sigmoid(heatmap) + heat = self._simple_nms(heatmap) + scores, inds, clses, ys, xs = self._topk(heat, self.max_per_img) + ys = fluid.layers.cast(ys, 'float32') * self.down_ratio + xs = fluid.layers.cast(xs, 'float32') * self.down_ratio + scores = fluid.layers.unsqueeze(scores, [1]) + clses = fluid.layers.unsqueeze(clses, [1]) + + wh_t = fluid.layers.transpose(wh, [0, 2, 3, 1]) + wh = fluid.layers.reshape(wh_t, [-1, wh_t.shape[-1]]) + wh = fluid.layers.gather(wh, inds) + + x1 = xs - wh[:, 0:1] + y1 = ys - wh[:, 1:2] + x2 = xs + wh[:, 2:3] + y2 = ys + wh[:, 3:4] + bboxes = fluid.layers.concat([x1, y1, x2, y2], axis=1) + bboxes = fluid.layers.elementwise_div(bboxes, 
scale_factor, axis=-1) + results = fluid.layers.concat([clses, scores, bboxes], axis=1) + # hack: append result with cls=-1 and score=1. to avoid all scores + # are less than score_thresh which may cause error in gather. + fill_r = fluid.layers.assign( + np.array( + [[-1, 1., 0, 0, 0, 0]], dtype='float32')) + results = fluid.layers.concat([results, fill_r]) + scores = results[:, 1] + valid_ind = fluid.layers.where(scores > self.score_thresh) + results = fluid.layers.gather(results, valid_ind) + return {'bbox': results} + + def ct_focal_loss(self, pred_hm, target_hm, gamma=2.0): + fg_map = fluid.layers.cast(target_hm == 1, 'float32') + fg_map.stop_gradient = True + bg_map = fluid.layers.cast(target_hm < 1, 'float32') + bg_map.stop_gradient = True + + neg_weights = fluid.layers.pow(1 - target_hm, 4) * bg_map + pos_loss = 0 - fluid.layers.log(pred_hm) * fluid.layers.pow( + 1 - pred_hm, gamma) * fg_map + neg_loss = 0 - fluid.layers.log(1 - pred_hm) * fluid.layers.pow( + pred_hm, gamma) * neg_weights + pos_loss = fluid.layers.reduce_sum(pos_loss) + neg_loss = fluid.layers.reduce_sum(neg_loss) + + fg_num = fluid.layers.reduce_sum(fg_map) + focal_loss = (pos_loss + neg_loss) / ( + fg_num + fluid.layers.cast(fg_num == 0, 'float32')) + return focal_loss + + def filter_box_by_weight(self, pred, target, weight): + index = fluid.layers.where(weight > 0) + index.stop_gradient = True + weight = fluid.layers.gather_nd(weight, index) + pred = fluid.layers.gather_nd(pred, index) + target = fluid.layers.gather_nd(target, index) + return pred, target, weight + + def get_loss(self, pred_hm, pred_wh, target_hm, box_target, target_weight): + try: + pred_hm = paddle.clip(fluid.layers.sigmoid(pred_hm), 1e-4, 1 - 1e-4) + except: + pred_hm = paddle.tensor.clamp( + fluid.layers.sigmoid(pred_hm), 1e-4, 1 - 1e-4) + hm_loss = self.ct_focal_loss(pred_hm, target_hm) * self.hm_weight + shape = fluid.layers.shape(target_hm) + shape.stop_gradient = True + H, W = shape[2], shape[3] + + mask = 
fluid.layers.reshape(target_weight, [-1, H, W]) + avg_factor = fluid.layers.reduce_sum(mask) + 1e-4 + base_step = self.down_ratio + zero = fluid.layers.fill_constant(shape=[1], value=0, dtype='int32') + shifts_x = paddle.arange(zero, W * base_step, base_step, dtype='int32') + shifts_y = paddle.arange(zero, H * base_step, base_step, dtype='int32') + shift_y, shift_x = paddle.tensor.meshgrid([shifts_y, shifts_x]) + base_loc = fluid.layers.stack([shift_x, shift_y], axis=0) + base_loc.stop_gradient = True + + pred_boxes = fluid.layers.concat( + [0 - pred_wh[:, 0:2, :, :] + base_loc, pred_wh[:, 2:4] + base_loc], + axis=1) + pred_boxes = fluid.layers.transpose(pred_boxes, [0, 2, 3, 1]) + boxes = fluid.layers.transpose(box_target, [0, 2, 3, 1]) + boxes.stop_gradient = True + + pred_boxes, boxes, mask = self.filter_box_by_weight(pred_boxes, boxes, + mask) + mask.stop_gradient = True + wh_loss = self.wh_loss( + pred_boxes, boxes, outside_weight=mask, use_transform=False) + wh_loss = wh_loss / avg_factor + + ttf_loss = {'hm_loss': hm_loss, 'wh_loss': wh_loss} + return ttf_loss diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/yolo_head.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/yolo_head.py new file mode 100755 index 000000000..a0c3d2bc4 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/anchor_heads/yolo_head.py @@ -0,0 +1,642 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +from ppdet.modeling.ops import MultiClassNMS, MultiClassSoftNMS, MatrixNMS +from ppdet.modeling.losses.yolo_loss import YOLOv3Loss +from ppdet.core.workspace import register +from ppdet.modeling.ops import DropBlock +from .iou_aware import get_iou_aware_score +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence +from ppdet.utils.check import check_version + +__all__ = ['YOLOv3Head', 'YOLOv4Head'] + + +@register +class YOLOv3Head(object): + """ + Head block for YOLOv3 network + + Args: + conv_block_num (int): number of conv block in each detection block + norm_decay (float): weight decay for normalization layer weights + num_classes (int): number of output classes + anchors (list): anchors + anchor_masks (list): anchor masks + nms (object): an instance of `MultiClassNMS` + """ + __inject__ = ['yolo_loss', 'nms'] + __shared__ = ['num_classes', 'weight_prefix_name'] + + def __init__(self, + conv_block_num=2, + norm_decay=0., + num_classes=80, + anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], + [59, 119], [116, 90], [156, 198], [373, 326]], + anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]], + drop_block=False, + coord_conv=False, + iou_aware=False, + iou_aware_factor=0.4, + block_size=3, + keep_prob=0.9, + yolo_loss="YOLOv3Loss", + spp=False, + nms=MultiClassNMS( + score_threshold=0.01, + nms_top_k=1000, + keep_top_k=100, + nms_threshold=0.45, + background_label=-1).__dict__, + weight_prefix_name='', + downsample=[32, 16, 8], + scale_x_y=1.0, + clip_bbox=True): + check_version("1.8.4") + self.conv_block_num = conv_block_num + self.norm_decay = norm_decay + 
self.num_classes = num_classes + self.anchor_masks = anchor_masks + self._parse_anchors(anchors) + self.yolo_loss = yolo_loss + self.nms = nms + self.prefix_name = weight_prefix_name + self.drop_block = drop_block + self.iou_aware = iou_aware + self.coord_conv = coord_conv + self.iou_aware_factor = iou_aware_factor + self.block_size = block_size + self.keep_prob = keep_prob + self.use_spp = spp + if isinstance(nms, dict): + self.nms = MultiClassNMS(**nms) + self.downsample = downsample + self.scale_x_y = scale_x_y + self.clip_bbox = clip_bbox + + def _create_tensor_from_numpy(self, numpy_array): + paddle_array = fluid.layers.create_global_var( + shape=numpy_array.shape, value=0., dtype=numpy_array.dtype) + fluid.layers.assign(numpy_array, paddle_array) + return paddle_array + + def _add_coord(self, input, is_test=True): + if not self.coord_conv: + return input + + # NOTE: here is used for exporting model for TensorRT inference, + # only support batch_size=1 for input shape should be fixed, + # and we create tensor with fixed shape from numpy array + if is_test and input.shape[2] > 0 and input.shape[3] > 0: + batch_size = 1 + grid_x = int(input.shape[3]) + grid_y = int(input.shape[2]) + idx_i = np.array( + [[i / (grid_x - 1) * 2.0 - 1 for i in range(grid_x)]], + dtype='float32') + gi_np = np.repeat(idx_i, grid_y, axis=0) + gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x]) + gi_np = np.tile(gi_np, reps=[batch_size, 1, 1, 1]) + + x_range = self._create_tensor_from_numpy(gi_np.astype(np.float32)) + x_range.stop_gradient = True + + idx_j = np.array( + [[j / (grid_y - 1) * 2.0 - 1 for j in range(grid_y)]], + dtype='float32') + gj_np = np.repeat(idx_j, grid_x, axis=1) + gj_np = np.reshape(gj_np, newshape=[1, 1, grid_y, grid_x]) + gj_np = np.tile(gi_np, reps=[batch_size, 1, 1, 1]) + y_range = self._create_tensor_from_numpy(gj_np.astype(np.float32)) + y_range.stop_gradient = True + + # NOTE: in training mode, H and W is variable for random shape, + # implement 
add_coord with shape as Variable + else: + input_shape = fluid.layers.shape(input) + b = input_shape[0] + h = input_shape[2] + w = input_shape[3] + + x_range = fluid.layers.range(0, w, 1, 'float32') / ((w - 1.) / 2.) + x_range = x_range - 1. + x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2]) + x_range = fluid.layers.expand(x_range, [b, 1, h, 1]) + x_range.stop_gradient = True + + y_range = fluid.layers.range(0, h, 1, 'float32') / ((h - 1.) / 2.) + y_range = y_range - 1. + y_range = fluid.layers.unsqueeze(y_range, [0, 1, 3]) + y_range = fluid.layers.expand(y_range, [b, 1, 1, w]) + y_range.stop_gradient = True + + return fluid.layers.concat([input, x_range, y_range], axis=1) + + def _conv_bn(self, + input, + ch_out, + filter_size, + stride, + padding, + act='leaky', + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + act=None, + param_attr=ParamAttr(name=name + ".conv.weights"), + bias_attr=False) + + bn_name = name + ".bn" + bn_param_attr = ParamAttr( + regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale') + bn_bias_attr = ParamAttr( + regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset') + out = fluid.layers.batch_norm( + input=conv, + act=None, + param_attr=bn_param_attr, + bias_attr=bn_bias_attr, + moving_mean_name=bn_name + '.mean', + moving_variance_name=bn_name + '.var') + + if act == 'leaky': + out = fluid.layers.leaky_relu(x=out, alpha=0.1) + return out + + def _spp_module(self, input, name=""): + output1 = input + output2 = fluid.layers.pool2d( + input=output1, + pool_size=5, + pool_stride=1, + pool_padding=2, + ceil_mode=False, + pool_type='max') + output3 = fluid.layers.pool2d( + input=output1, + pool_size=9, + pool_stride=1, + pool_padding=4, + ceil_mode=False, + pool_type='max') + output4 = fluid.layers.pool2d( + input=output1, + pool_size=13, + pool_stride=1, + pool_padding=6, + ceil_mode=False, + pool_type='max') + output = 
fluid.layers.concat( + input=[output1, output2, output3, output4], axis=1) + return output + + def _detection_block(self, + input, + channel, + conv_block_num=2, + is_first=False, + is_test=True, + name=None): + assert channel % 2 == 0, \ + "channel {} cannot be divided by 2 in detection block {}" \ + .format(channel, name) + + conv = input + for j in range(conv_block_num): + conv = self._add_coord(conv, is_test=is_test) + conv = self._conv_bn( + conv, + channel, + filter_size=1, + stride=1, + padding=0, + name='{}.{}.0'.format(name, j)) + if self.use_spp and is_first and j == 1: + conv = self._spp_module(conv, name="spp") + conv = self._conv_bn( + conv, + 512, + filter_size=1, + stride=1, + padding=0, + name='{}.{}.spp.conv'.format(name, j)) + conv = self._conv_bn( + conv, + channel * 2, + filter_size=3, + stride=1, + padding=1, + name='{}.{}.1'.format(name, j)) + if self.drop_block and j == 0 and not is_first: + conv = DropBlock( + conv, + block_size=self.block_size, + keep_prob=self.keep_prob, + is_test=is_test) + + if self.use_spp and conv_block_num == 0 and is_first: + conv = self._spp_module(conv, name="spp") + + if self.drop_block and (is_first or conv_block_num == 0): + conv = DropBlock( + conv, + block_size=self.block_size, + keep_prob=self.keep_prob, + is_test=is_test) + conv = self._add_coord(conv, is_test=is_test) + route = self._conv_bn( + conv, + channel, + filter_size=1, + stride=1, + padding=0, + name='{}.2'.format(name)) + new_route = self._add_coord(route, is_test=is_test) + tip = self._conv_bn( + new_route, + channel * 2, + filter_size=3, + stride=1, + padding=1, + name='{}.tip'.format(name)) + return route, tip + + def _upsample(self, input, scale=2, name=None): + out = fluid.layers.resize_nearest( + input=input, scale=float(scale), name=name) + return out + + def _parse_anchors(self, anchors): + """ + Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors + + """ + self.anchors = [] + self.mask_anchors = [] + + assert len(anchors) > 0, 
"ANCHORS not set." + assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set." + + for anchor in anchors: + assert len(anchor) == 2, "anchor {} len should be 2".format(anchor) + self.anchors.extend(anchor) + + anchor_num = len(anchors) + for masks in self.anchor_masks: + self.mask_anchors.append([]) + for mask in masks: + assert mask < anchor_num, "anchor mask index overflow" + self.mask_anchors[-1].extend(anchors[mask]) + + def _get_outputs(self, input, is_train=True): + """ + Get YOLOv3 head output + + Args: + input (list): List of Variables, output of backbone stages + is_train (bool): whether in train or test mode + + Returns: + outputs (list): Variables of each output layer + """ + + outputs = [] + + # get last out_layer_num blocks in reverse order + out_layer_num = len(self.anchor_masks) + blocks = input[-1:-out_layer_num - 1:-1] + + route = None + for i, block in enumerate(blocks): + if i > 0: # perform concat in first 2 detection_block + block = fluid.layers.concat(input=[route, block], axis=1) + route, tip = self._detection_block( + block, + channel=64 * (2**out_layer_num) // (2**i), + is_first=i == 0, + is_test=(not is_train), + conv_block_num=self.conv_block_num, + name=self.prefix_name + "yolo_block.{}".format(i)) + + # out channel number = mask_num * (5 + class_num) + if self.iou_aware: + num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6) + else: + num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5) + with fluid.name_scope('yolo_output'): + block_out = fluid.layers.conv2d( + input=tip, + num_filters=num_filters, + filter_size=1, + stride=1, + padding=0, + act=None, + param_attr=ParamAttr( + name=self.prefix_name + + "yolo_output.{}.conv.weights".format(i)), + bias_attr=ParamAttr( + regularizer=L2Decay(0.), + name=self.prefix_name + + "yolo_output.{}.conv.bias".format(i))) + outputs.append(block_out) + + if i < len(blocks) - 1: + # do not perform upsample in the last detection_block + route = self._conv_bn( + input=route, + 
ch_out=256 // (2**i), + filter_size=1, + stride=1, + padding=0, + name=self.prefix_name + "yolo_transition.{}".format(i)) + # upsample + route = self._upsample(route) + + return outputs + + def get_loss(self, input, gt_box, gt_label, gt_score, targets): + """ + Get final loss of network of YOLOv3. + + Args: + input (list): List of Variables, output of backbone stages + gt_box (Variable): The ground-truth boudding boxes. + gt_label (Variable): The ground-truth class labels. + gt_score (Variable): The ground-truth boudding boxes mixup scores. + targets ([Variables]): List of Variables, the targets for yolo + loss calculatation. + + Returns: + loss (Variable): The loss Variable of YOLOv3 network. + + """ + outputs = self._get_outputs(input, is_train=True) + + return self.yolo_loss(outputs, gt_box, gt_label, gt_score, targets, + self.anchors, self.anchor_masks, + self.mask_anchors, self.num_classes, + self.prefix_name) + + def get_prediction(self, input, im_size, exclude_nms=False): + """ + Get prediction result of YOLOv3 network + + Args: + input (list): List of Variables, output of backbone stages + im_size (Variable): Variable of size([h, w]) of each image + + Returns: + pred (Variable): The prediction result after non-max suppress. 
+ + """ + + outputs = self._get_outputs(input, is_train=False) + + boxes = [] + scores = [] + for i, output in enumerate(outputs): + if self.iou_aware: + output = get_iou_aware_score(output, + len(self.anchor_masks[i]), + self.num_classes, + self.iou_aware_factor) + scale_x_y = self.scale_x_y if not isinstance( + self.scale_x_y, Sequence) else self.scale_x_y[i] + box, score = fluid.layers.yolo_box( + x=output, + img_size=im_size, + anchors=self.mask_anchors[i], + class_num=self.num_classes, + conf_thresh=self.nms.score_threshold, + downsample_ratio=self.downsample[i], + name=self.prefix_name + "yolo_box" + str(i), + clip_bbox=self.clip_bbox, + scale_x_y=scale_x_y) + boxes.append(box) + scores.append(fluid.layers.transpose(score, perm=[0, 2, 1])) + + yolo_boxes = fluid.layers.concat(boxes, axis=1) + yolo_scores = fluid.layers.concat(scores, axis=2) + + # Only for benchmark, postprocess(NMS) is not needed + if exclude_nms: + return {'bbox': yolo_scores} + + if type(self.nms) is MultiClassSoftNMS: + yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1]) + pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores) + return {'bbox': pred} + + +@register +class YOLOv4Head(YOLOv3Head): + """ + Head block for YOLOv4 network + + Args: + anchors (list): anchors + anchor_masks (list): anchor masks + nms (object): an instance of `MultiClassNMS` + spp_stage (int): apply spp on which stage. 
+ num_classes (int): number of output classes + downsample (list): downsample ratio for each yolo_head + scale_x_y (list): scale the center point of bbox at each stage + """ + __inject__ = ['nms', 'yolo_loss'] + __shared__ = ['num_classes', 'weight_prefix_name'] + + def __init__(self, + anchors=[[12, 16], [19, 36], [40, 28], [36, 75], [76, 55], + [72, 146], [142, 110], [192, 243], [459, 401]], + anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]], + nms=MultiClassNMS( + score_threshold=0.01, + nms_top_k=-1, + keep_top_k=-1, + nms_threshold=0.45, + background_label=-1).__dict__, + spp_stage=5, + num_classes=80, + weight_prefix_name='', + downsample=[8, 16, 32], + scale_x_y=1.0, + yolo_loss="YOLOv3Loss", + iou_aware=False, + iou_aware_factor=0.4, + clip_bbox=False): + super(YOLOv4Head, self).__init__( + anchors=anchors, + anchor_masks=anchor_masks, + nms=nms, + num_classes=num_classes, + weight_prefix_name=weight_prefix_name, + downsample=downsample, + scale_x_y=scale_x_y, + yolo_loss=yolo_loss, + iou_aware=iou_aware, + iou_aware_factor=iou_aware_factor, + clip_bbox=clip_bbox) + self.spp_stage = spp_stage + + def _upsample(self, input, scale=2, name=None): + out = fluid.layers.resize_nearest( + input=input, scale=float(scale), name=name) + return out + + def max_pool(self, input, size): + pad = [(size - 1) // 2] * 2 + return fluid.layers.pool2d(input, size, 'max', pool_padding=pad) + + def spp(self, input): + branch_a = self.max_pool(input, 13) + branch_b = self.max_pool(input, 9) + branch_c = self.max_pool(input, 5) + out = fluid.layers.concat([branch_a, branch_b, branch_c, input], axis=1) + return out + + def stack_conv(self, + input, + ch_list=[512, 1024, 512], + filter_list=[1, 3, 1], + stride=1, + name=None): + conv = input + for i, (ch_out, f_size) in enumerate(zip(ch_list, filter_list)): + padding = 1 if f_size == 3 else 0 + conv = self._conv_bn( + conv, + ch_out=ch_out, + filter_size=f_size, + stride=stride, + padding=padding, + name='{}.{}'.format(name, i)) + 
return conv + + def spp_module(self, input, name=None): + conv = self.stack_conv(input, name=name + '.stack_conv.0') + spp_out = self.spp(conv) + conv = self.stack_conv(spp_out, name=name + '.stack_conv.1') + return conv + + def pan_module(self, input, filter_list, name=None): + for i in range(1, len(input)): + ch_out = input[i].shape[1] // 2 + conv_left = self._conv_bn( + input[i], + ch_out=ch_out, + filter_size=1, + stride=1, + padding=0, + name=name + '.{}.left'.format(i)) + ch_out = input[i - 1].shape[1] // 2 + conv_right = self._conv_bn( + input[i - 1], + ch_out=ch_out, + filter_size=1, + stride=1, + padding=0, + name=name + '.{}.right'.format(i)) + conv_right = self._upsample(conv_right) + pan_out = fluid.layers.concat([conv_left, conv_right], axis=1) + ch_list = [pan_out.shape[1] // 2 * k for k in [1, 2, 1, 2, 1]] + input[i] = self.stack_conv( + pan_out, + ch_list=ch_list, + filter_list=filter_list, + name=name + '.stack_conv.{}'.format(i)) + return input + + def _get_outputs(self, input, is_train=True): + outputs = [] + filter_list = [1, 3, 1, 3, 1] + spp_stage = len(input) - self.spp_stage + # get last out_layer_num blocks in reverse order + out_layer_num = len(self.anchor_masks) + blocks = input[-1:-out_layer_num - 1:-1] + blocks[spp_stage] = self.spp_module( + blocks[spp_stage], name=self.prefix_name + "spp_module") + blocks = self.pan_module( + blocks, + filter_list=filter_list, + name=self.prefix_name + 'pan_module') + + # reverse order back to input + blocks = blocks[::-1] + + route = None + for i, block in enumerate(blocks): + if i > 0: # perform concat in first 2 detection_block + route = self._conv_bn( + route, + ch_out=route.shape[1] * 2, + filter_size=3, + stride=2, + padding=1, + name=self.prefix_name + 'yolo_block.route.{}'.format(i)) + block = fluid.layers.concat(input=[route, block], axis=1) + ch_list = [block.shape[1] // 2 * k for k in [1, 2, 1, 2, 1]] + block = self.stack_conv( + block, + ch_list=ch_list, + filter_list=filter_list, + 
name=self.prefix_name + + 'yolo_block.stack_conv.{}'.format(i)) + route = block + + block_out = self._conv_bn( + block, + ch_out=block.shape[1] * 2, + filter_size=3, + stride=1, + padding=1, + name=self.prefix_name + 'yolo_output.{}.conv.0'.format(i)) + + if self.iou_aware: + num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6) + else: + num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5) + block_out = fluid.layers.conv2d( + input=block_out, + num_filters=num_filters, + filter_size=1, + stride=1, + padding=0, + act=None, + param_attr=ParamAttr(name=self.prefix_name + + "yolo_output.{}.conv.1.weights".format(i)), + bias_attr=ParamAttr( + regularizer=L2Decay(0.), + name=self.prefix_name + + "yolo_output.{}.conv.1.bias".format(i))) + outputs.append(block_out) + + return outputs diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/__init__.py new file mode 100755 index 000000000..7693a2c1e --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/__init__.py @@ -0,0 +1,49 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from . import faster_rcnn +from . import mask_rcnn +from . import cascade_rcnn +from . import cascade_mask_rcnn +from . import cascade_rcnn_cls_aware +from . import yolo +from . import ssd +from . 
import retinanet +from . import efficientdet +from . import blazeface +from . import faceboxes +from . import fcos +from . import cornernet_squeeze +from . import ttfnet +from . import htc +from . import solov2 + +from .faster_rcnn import * +from .mask_rcnn import * +from .cascade_rcnn import * +from .cascade_mask_rcnn import * +from .cascade_rcnn_cls_aware import * +from .yolo import * +from .ssd import * +from .retinanet import * +from .efficientdet import * +from .blazeface import * +from .faceboxes import * +from .fcos import * +from .cornernet_squeeze import * +from .ttfnet import * +from .htc import * +from .solov2 import * diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/blazeface.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/blazeface.py new file mode 100755 index 000000000..7508a6b08 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/blazeface.py @@ -0,0 +1,260 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from collections import OrderedDict + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +from ppdet.core.workspace import register +from ppdet.modeling.ops import SSDOutputDecoder +from ppdet.modeling.losses import SSDWithLmkLoss + +__all__ = ['BlazeFace'] + + +@register +class BlazeFace(object): + """ + BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs, + see https://arxiv.org/abs/1907.05047 + + Args: + backbone (object): backbone instance + output_decoder (object): `SSDOutputDecoder` instance + min_sizes (list|None): min sizes of generated prior boxes. + max_sizes (list|None): max sizes of generated prior boxes. Default: None. + steps (list|None): step size of adjacent prior boxes on each feature map. + num_classes (int): number of output classes + use_density_prior_box (bool): whether or not use density_prior_box + instead of prior_box + densities (list|None): the densities of generated density prior boxes, + this attribute should be a list or tuple of integers + """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'output_decoder'] + __shared__ = ['num_classes', 'with_lmk'] + + def __init__(self, + backbone="BlazeNet", + output_decoder=SSDOutputDecoder().__dict__, + min_sizes=[[16., 24.], [32., 48., 64., 80., 96., 128.]], + max_sizes=None, + steps=[8., 16.], + num_classes=2, + use_density_prior_box=False, + densities=[[2, 2], [2, 1, 1, 1, 1, 1]], + with_lmk=False, + lmk_loss=SSDWithLmkLoss().__dict__): + super(BlazeFace, self).__init__() + self.backbone = backbone + self.num_classes = num_classes + self.with_lmk = with_lmk + self.output_decoder = output_decoder + if isinstance(output_decoder, dict): + if self.with_lmk: + output_decoder['return_index'] = True + self.output_decoder = SSDOutputDecoder(**output_decoder) + 
self.min_sizes = min_sizes + self.max_sizes = max_sizes + self.steps = steps + self.use_density_prior_box = use_density_prior_box + self.densities = densities + self.landmark = None + if self.with_lmk and isinstance(lmk_loss, dict): + self.lmk_loss = SSDWithLmkLoss(**lmk_loss) + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + + body_feats = self.backbone(im) + locs, confs, box, box_var = self._multi_box_head( + inputs=body_feats, + image=im, + num_classes=self.num_classes, + use_density_prior_box=self.use_density_prior_box) + + if mode == 'train': + gt_bbox = feed_vars['gt_bbox'] + gt_class = feed_vars['gt_class'] + if self.with_lmk: + lmk_labels = feed_vars['gt_keypoint'] + lmk_ignore_flag = feed_vars["keypoint_ignore"] + loss = self.lmk_loss(locs, confs, gt_bbox, gt_class, + self.landmark, lmk_labels, lmk_ignore_flag, + box, box_var) + else: + loss = fluid.layers.ssd_loss( + locs, + confs, + gt_bbox, + gt_class, + box, + box_var, + overlap_threshold=0.35, + neg_overlap=0.35) + + loss = fluid.layers.reduce_sum(loss) + return {'loss': loss} + else: + if self.with_lmk: + pred, face_index = self.output_decoder(locs, confs, box, + box_var) + return { + 'bbox': pred, + 'face_index': face_index, + 'prior_boxes': box, + 'landmark': self.landmark + } + else: + pred = self.output_decoder(locs, confs, box, box_var) + return {'bbox': pred} + + def _multi_box_head(self, + inputs, + image, + num_classes=2, + use_density_prior_box=False): + def permute_and_reshape(input, last_dim): + trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1]) + compile_shape = [0, -1, last_dim] + return fluid.layers.reshape(trans, shape=compile_shape) + + locs, confs = [], [] + boxes, vars = [], [] + lmk_locs = [] + b_attr = ParamAttr(learning_rate=2., regularizer=L2Decay(0.)) + + for i, input in enumerate(inputs): + min_size = self.min_sizes[i] + + if use_density_prior_box: + densities = self.densities[i] + box, var = fluid.layers.density_prior_box( + input, + image, + 
densities=densities, + fixed_sizes=min_size, + fixed_ratios=[1.], + clip=False, + offset=0.5, + steps=[self.steps[i]] * 2) + else: + box, var = fluid.layers.prior_box( + input, + image, + min_sizes=min_size, + max_sizes=None, + steps=[self.steps[i]] * 2, + aspect_ratios=[1.], + clip=False, + flip=False, + offset=0.5) + + num_boxes = box.shape[2] + + box = fluid.layers.reshape(box, shape=[-1, 4]) + var = fluid.layers.reshape(var, shape=[-1, 4]) + num_loc_output = num_boxes * 4 + num_conf_output = num_boxes * num_classes + # get loc + mbox_loc = fluid.layers.conv2d( + input, num_loc_output, 3, 1, 1, bias_attr=b_attr) + loc = permute_and_reshape(mbox_loc, 4) + # get conf + mbox_conf = fluid.layers.conv2d( + input, num_conf_output, 3, 1, 1, bias_attr=b_attr) + conf = permute_and_reshape(mbox_conf, num_classes) + + if self.with_lmk: + # get landmark + lmk_loc_output = num_boxes * 10 + lmk_box_loc = fluid.layers.conv2d( + input, + lmk_loc_output, + 3, + 1, + 1, + param_attr=ParamAttr(name='lmk' + str(i) + '_weights'), + bias_attr=False) + lmk_loc = permute_and_reshape(lmk_box_loc, 10) + lmk_locs.append(lmk_loc) + + locs.append(loc) + confs.append(conf) + boxes.append(box) + vars.append(var) + + face_mbox_loc = fluid.layers.concat(locs, axis=1) + face_mbox_conf = fluid.layers.concat(confs, axis=1) + prior_boxes = fluid.layers.concat(boxes) + box_vars = fluid.layers.concat(vars) + if self.with_lmk: + self.landmark = fluid.layers.concat(lmk_locs, axis=1) + return face_mbox_loc, face_mbox_conf, prior_boxes, box_vars + + def _inputs_def(self, image_shape): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'im_shape': {'shape': [None, 3], 'dtype': 'int32', 'lod_level': 0}, + 
'gt_keypoint': {'shape': [None, 10], 'dtype': 'float32', 'lod_level': 1}, + 'keypoint_ignore': {'shape': [None, 1], 'dtype': 'float32', 'lod_level': 1}, + } + # yapf: enable + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=['image', 'im_id', 'gt_bbox', 'gt_class'], # for train + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape) + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, 'train') + + def eval(self, feed_vars): + return self.build(feed_vars, 'eval') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, 'test') + + def is_bbox_normalized(self): + return True diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cascade_mask_rcnn.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cascade_mask_rcnn.py new file mode 100755 index 000000000..0fc5bb135 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cascade_mask_rcnn.py @@ -0,0 +1,447 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import copy + +import paddle.fluid as fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register +from ppdet.utils.check import check_version + +from .input_helper import multiscale_def + +__all__ = ['CascadeMaskRCNN'] + + +@register +class CascadeMaskRCNN(object): + """ + Cascade Mask R-CNN architecture, see https://arxiv.org/abs/1712.00726 + + Args: + backbone (object): backbone instance + rpn_head (object): `RPNhead` instance + bbox_assigner (object): `BBoxAssigner` instance + roi_extractor (object): ROI extractor instance + bbox_head (object): `BBoxHead` instance + mask_assigner (object): `MaskAssigner` instance + mask_head (object): `MaskHead` instance + fpn (object): feature pyramid network instance + """ + + __category__ = 'architecture' + __inject__ = [ + 'backbone', 'rpn_head', 'bbox_assigner', 'roi_extractor', 'bbox_head', + 'mask_assigner', 'mask_head', 'fpn' + ] + + def __init__(self, + backbone, + rpn_head, + roi_extractor='FPNRoIAlign', + bbox_head='CascadeBBoxHead', + bbox_assigner='CascadeBBoxAssigner', + mask_assigner='MaskAssigner', + mask_head='MaskHead', + rpn_only=False, + fpn='FPN'): + super(CascadeMaskRCNN, self).__init__() + check_version('2.0.0-rc0') + assert fpn is not None, "cascade RCNN requires FPN" + self.backbone = backbone + self.fpn = fpn + self.rpn_head = rpn_head + self.bbox_assigner = bbox_assigner + self.roi_extractor 
= roi_extractor + self.bbox_head = bbox_head + self.mask_assigner = mask_assigner + self.mask_head = mask_head + self.rpn_only = rpn_only + # Cascade local cfg + self.cls_agnostic_bbox_reg = 2 + (brw0, brw1, brw2) = self.bbox_assigner.bbox_reg_weights + self.cascade_bbox_reg_weights = [ + [1. / brw0, 1. / brw0, 2. / brw0, 2. / brw0], + [1. / brw1, 1. / brw1, 2. / brw1, 2. / brw1], + [1. / brw2, 1. / brw2, 2. / brw2, 2. / brw2] + ] + self.cascade_rcnn_loss_weight = [1.0, 0.5, 0.25] + + def build(self, feed_vars, mode='train'): + if mode == 'train': + required_fields = [ + 'gt_class', 'gt_bbox', 'gt_mask', 'is_crowd', 'im_info' + ] + else: + required_fields = ['im_shape', 'im_info'] + self._input_check(required_fields, feed_vars) + + im = feed_vars['image'] + if mode == 'train': + gt_bbox = feed_vars['gt_bbox'] + is_crowd = feed_vars['is_crowd'] + + im_info = feed_vars['im_info'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + # backbone + body_feats = self.backbone(im) + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32')) + for k, v in body_feats.items()) + + # FPN + if self.fpn is not None: + body_feats, spatial_scale = self.fpn.get_output(body_feats) + + # rpn proposals + rpn_rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode) + + if mode == 'train': + rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd) + else: + if self.rpn_only: + im_scale = fluid.layers.slice( + im_info, [1], starts=[2], ends=[3]) + im_scale = fluid.layers.sequence_expand(im_scale, rpn_rois) + rois = rpn_rois / im_scale + return {'proposal': rois} + + proposal_list = [] + roi_feat_list = [] + rcnn_pred_list = [] + rcnn_target_list = [] + + proposals = None + bbox_pred = None + max_overlap = None + for i in range(3): + if i > 0: + refined_bbox = self._decode_box( + proposals, + 
bbox_pred, + curr_stage=i - 1, ) + else: + refined_bbox = rpn_rois + + if mode == 'train': + outs = self.bbox_assigner( + input_rois=refined_bbox, + feed_vars=feed_vars, + curr_stage=i, + max_overlap=max_overlap) + + proposals = outs[0] + max_overlap = outs[-1] + rcnn_target_list.append(outs[:-1]) + else: + proposals = refined_bbox + proposal_list.append(proposals) + + # extract roi features + roi_feat = self.roi_extractor(body_feats, proposals, spatial_scale) + roi_feat_list.append(roi_feat) + + # bbox head + cls_score, bbox_pred = self.bbox_head.get_output( + roi_feat, + wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i], + name='_' + str(i + 1) if i > 0 else '') + rcnn_pred_list.append((cls_score, bbox_pred)) + + # get mask rois + rois = proposal_list[2] + + if mode == 'train': + loss = self.bbox_head.get_loss(rcnn_pred_list, rcnn_target_list, + self.cascade_rcnn_loss_weight) + loss.update(rpn_loss) + + labels_int32 = rcnn_target_list[2][1] + + mask_rois, roi_has_mask_int32, mask_int32 = self.mask_assigner( + rois=rois, + gt_classes=feed_vars['gt_class'], + is_crowd=feed_vars['is_crowd'], + gt_segms=feed_vars['gt_mask'], + im_info=feed_vars['im_info'], + labels_int32=labels_int32) + + if self.fpn is None: + bbox_head_feat = self.bbox_head.get_head_feat() + feat = fluid.layers.gather(bbox_head_feat, roi_has_mask_int32) + else: + feat = self.roi_extractor( + body_feats, mask_rois, spatial_scale, is_mask=True) + mask_loss = self.mask_head.get_loss(feat, mask_int32) + loss.update(mask_loss) + + total_loss = fluid.layers.sum(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + else: + mask_name = 'mask_pred' + mask_pred, bbox_pred = self.single_scale_eval( + body_feats, spatial_scale, im_info, mask_name, bbox_pred, + roi_feat_list, rcnn_pred_list, proposal_list, + feed_vars['im_shape']) + return {'bbox': bbox_pred, 'mask': mask_pred} + + def build_multi_scale(self, feed_vars, mask_branch=False): + required_fields = ['image', 'im_info'] + 
self._input_check(required_fields, feed_vars) + + result = {} + if not mask_branch: + assert 'im_shape' in feed_vars, \ + "{} has no im_shape field".format(feed_vars) + result.update(feed_vars) + + for i in range(len(self.im_info_names) // 2): + im = feed_vars[self.im_info_names[2 * i]] + im_info = feed_vars[self.im_info_names[2 * i + 1]] + body_feats = self.backbone(im) + + # FPN + if self.fpn is not None: + body_feats, spatial_scale = self.fpn.get_output(body_feats) + rois = self.rpn_head.get_proposals(body_feats, im_info, mode='test') + if not mask_branch: + im_shape = feed_vars['im_shape'] + body_feat_names = list(body_feats.keys()) + proposal_list = [] + roi_feat_list = [] + rcnn_pred_list = [] + + proposals = None + bbox_pred = None + for i in range(3): + if i > 0: + refined_bbox = self._decode_box( + proposals, + bbox_pred, + curr_stage=i - 1, ) + else: + refined_bbox = rois + + proposals = refined_bbox + proposal_list.append(proposals) + + # extract roi features + roi_feat = self.roi_extractor(body_feats, proposals, + spatial_scale) + roi_feat_list.append(roi_feat) + + # bbox head + cls_score, bbox_pred = self.bbox_head.get_output( + roi_feat, + wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i], + name='_' + str(i + 1) if i > 0 else '') + rcnn_pred_list.append((cls_score, bbox_pred)) + + # get mask rois + if self.fpn is None: + body_feat = body_feats[body_feat_names[-1]] + pred = self.bbox_head.get_prediction( + im_info, + im_shape, + roi_feat_list, + rcnn_pred_list, + proposal_list, + self.cascade_bbox_reg_weights, + return_box_score=True) + bbox_name = 'bbox_' + str(i) + score_name = 'score_' + str(i) + if 'flip' in im.name: + bbox_name += '_flip' + score_name += '_flip' + result[bbox_name] = pred['bbox'] + result[score_name] = pred['score'] + else: + mask_name = 'mask_pred_' + str(i) + bbox_pred = feed_vars['bbox'] + if 'flip' in im.name: + mask_name += '_flip' + bbox_pred = feed_vars['bbox_flip'] + mask_pred, bbox_pred = self.single_scale_eval( + 
body_feats, + spatial_scale, + im_info, + mask_name, + bbox_pred=bbox_pred, + use_multi_test=True) + result[mask_name] = mask_pred + return result + + def single_scale_eval(self, + body_feats, + spatial_scale, + im_info, + mask_name, + bbox_pred, + roi_feat_list=None, + rcnn_pred_list=None, + proposal_list=None, + im_shape=None, + use_multi_test=False): + if self.fpn is None: + last_feat = body_feats[list(body_feats.keys())[-1]] + if not use_multi_test: + bbox_pred = self.bbox_head.get_prediction( + im_info, im_shape, roi_feat_list, rcnn_pred_list, proposal_list, + self.cascade_bbox_reg_weights) + bbox_pred = bbox_pred['bbox'] + + # share weight + bbox_shape = fluid.layers.shape(bbox_pred) + bbox_size = fluid.layers.reduce_prod(bbox_shape) + bbox_size = fluid.layers.reshape(bbox_size, [1, 1]) + size = fluid.layers.fill_constant([1, 1], value=6, dtype='int32') + cond = fluid.layers.less_than(x=bbox_size, y=size) + + mask_pred = fluid.layers.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=False, + name=mask_name) + + def noop(): + fluid.layers.assign(input=bbox_pred, output=mask_pred) + + def process_boxes(): + bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6]) + + im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3]) + im_scale = fluid.layers.sequence_expand(im_scale, bbox) + + mask_rois = bbox * im_scale + if self.fpn is None: + mask_feat = self.roi_extractor(last_feat, mask_rois) + mask_feat = self.bbox_head.get_head_feat(mask_feat) + else: + mask_feat = self.roi_extractor( + body_feats, mask_rois, spatial_scale, is_mask=True) + + mask_out = self.mask_head.get_prediction(mask_feat, bbox) + fluid.layers.assign(input=mask_out, output=mask_pred) + + fluid.layers.cond(cond, noop, process_boxes) + return mask_pred, bbox_pred + + def _input_check(self, require_fields, feed_vars): + for var in require_fields: + assert var in feed_vars, \ + "{} has no {} field".format(feed_vars, var) + + def _decode_box(self, 
proposals, bbox_pred, curr_stage): + rcnn_loc_delta_r = fluid.layers.reshape( + bbox_pred, (-1, self.cls_agnostic_bbox_reg, 4)) + # only use fg box delta to decode box + rcnn_loc_delta_s = fluid.layers.slice( + rcnn_loc_delta_r, axes=[1], starts=[1], ends=[2]) + refined_bbox = fluid.layers.box_coder( + prior_box=proposals, + prior_box_var=self.cascade_bbox_reg_weights[curr_stage], + target_box=rcnn_loc_delta_s, + code_type='decode_center_size', + box_normalized=False, + axis=1, ) + refined_bbox = fluid.layers.reshape(refined_bbox, shape=[-1, 4]) + + return refined_bbox + + def _inputs_def(self, image_shape): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'gt_mask': {'shape': [None, 2], 'dtype': 'float32', 'lod_level': 3}, # polygon coordinates + 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + } + # yapf: enable + return inputs_def + + def build_inputs(self, + image_shape=[3, None, None], + fields=[ + 'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', + 'is_crowd', 'gt_mask' + ], + multi_scale=False, + num_scales=-1, + use_flip=None, + use_dataloader=True, + iterable=False, + mask_branch=False): + inputs_def = self._inputs_def(image_shape) + fields = copy.deepcopy(fields) + if multi_scale: + ms_def, ms_fields = multiscale_def(image_shape, num_scales, + use_flip) + inputs_def.update(ms_def) + fields += ms_fields + self.im_info_names = ['image', 'im_info'] + ms_fields + if mask_branch: + box_fields = ['bbox', 'bbox_flip'] if use_flip else 
['bbox'] + for key in box_fields: + inputs_def[key] = { + 'shape': [None, 6], + 'dtype': 'float32', + 'lod_level': 1 + } + fields += box_fields + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + use_dataloader = use_dataloader and not mask_branch + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, 'train') + + def eval(self, feed_vars, multi_scale=None, mask_branch=False): + if multi_scale: + return self.build_multi_scale(feed_vars, mask_branch) + return self.build(feed_vars, 'test') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, 'test') diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn.py new file mode 100755 index 000000000..e018caeae --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn.py @@ -0,0 +1,344 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +from collections import OrderedDict + +import paddle.fluid as fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register +from ppdet.utils.check import check_version +from .input_helper import multiscale_def + +__all__ = ['CascadeRCNN'] + + +@register +class CascadeRCNN(object): + """ + Cascade R-CNN architecture, see https://arxiv.org/abs/1712.00726 + + Args: + backbone (object): backbone instance + rpn_head (object): `RPNhead` instance + bbox_assigner (object): `BBoxAssigner` instance + roi_extractor (object): ROI extractor instance + bbox_head (object): `BBoxHead` instance + fpn (object): feature pyramid network instance + """ + + __category__ = 'architecture' + __inject__ = [ + 'backbone', 'fpn', 'rpn_head', 'bbox_assigner', 'roi_extractor', + 'bbox_head' + ] + + def __init__(self, + backbone, + rpn_head, + roi_extractor='FPNRoIAlign', + bbox_head='CascadeBBoxHead', + bbox_assigner='CascadeBBoxAssigner', + rpn_only=False, + fpn='FPN'): + super(CascadeRCNN, self).__init__() + check_version('2.0.0-rc0') + assert fpn is not None, "cascade RCNN requires FPN" + self.backbone = backbone + self.fpn = fpn + self.rpn_head = rpn_head + self.bbox_assigner = bbox_assigner + self.roi_extractor = roi_extractor + self.bbox_head = bbox_head + self.rpn_only = rpn_only + # Cascade local cfg + self.cls_agnostic_bbox_reg = 2 + (brw0, brw1, brw2) = self.bbox_assigner.bbox_reg_weights + self.cascade_bbox_reg_weights = [ + [1. / brw0, 1. / brw0, 2. / brw0, 2. / brw0], + [1. / brw1, 1. / brw1, 2. / brw1, 2. / brw1], + [1. / brw2, 1. / brw2, 2. / brw2, 2. 
/ brw2] + ] + self.cascade_rcnn_loss_weight = [1.0, 0.5, 0.25] + + def build(self, feed_vars, mode='train'): + if mode == 'train': + required_fields = ['gt_class', 'gt_bbox', 'is_crowd', 'im_info'] + else: + required_fields = ['im_shape', 'im_info'] + self._input_check(required_fields, feed_vars) + + im = feed_vars['image'] + im_info = feed_vars['im_info'] + + if mode == 'train': + gt_bbox = feed_vars['gt_bbox'] + is_crowd = feed_vars['is_crowd'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + # backbone + body_feats = self.backbone(im) + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32')) + for k, v in body_feats.items()) + + # FPN + if self.fpn is not None: + body_feats, spatial_scale = self.fpn.get_output(body_feats) + + # rpn proposals + rpn_rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode) + + if mode == 'train': + #fluid.layers.Print(gt_bbox) + #fluid.layers.Print(is_crowd) + rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd) + else: + if self.rpn_only: + im_scale = fluid.layers.slice( + im_info, [1], starts=[2], ends=[3]) + im_scale = fluid.layers.sequence_expand(im_scale, rpn_rois) + rois = rpn_rois / im_scale + return {'proposal': rois} + + proposal_list = [] + roi_feat_list = [] + rcnn_pred_list = [] + rcnn_target_list = [] + + proposals = None + bbox_pred = None + max_overlap = None + for i in range(3): + if i > 0: + refined_bbox = self._decode_box( + proposals, + bbox_pred, + curr_stage=i - 1, ) + else: + refined_bbox = rpn_rois + + if mode == 'train': + outs = self.bbox_assigner( + input_rois=refined_bbox, + feed_vars=feed_vars, + curr_stage=i, + max_overlap=max_overlap) + + proposals = outs[0] + max_overlap = outs[-1] + rcnn_target_list.append(outs[:-1]) + else: + proposals = refined_bbox + proposal_list.append(proposals) + + # 
extract roi features + roi_feat = self.roi_extractor(body_feats, proposals, spatial_scale) + roi_feat_list.append(roi_feat) + + # bbox head + cls_score, bbox_pred = self.bbox_head.get_output( + roi_feat, + wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i], + name='_' + str(i + 1) if i > 0 else '') + rcnn_pred_list.append((cls_score, bbox_pred)) + + if mode == 'train': + loss = self.bbox_head.get_loss(rcnn_pred_list, rcnn_target_list, + self.cascade_rcnn_loss_weight) + loss.update(rpn_loss) + total_loss = fluid.layers.sum(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + else: + pred = self.bbox_head.get_prediction( + im_info, feed_vars['im_shape'], roi_feat_list, rcnn_pred_list, + proposal_list, self.cascade_bbox_reg_weights, + self.cls_agnostic_bbox_reg) + return pred + + def build_multi_scale(self, feed_vars): + required_fields = ['image', 'im_shape', 'im_info'] + self._input_check(required_fields, feed_vars) + + result = {} + im_shape = feed_vars['im_shape'] + result['im_shape'] = im_shape + + for i in range(len(self.im_info_names) // 2): + im = feed_vars[self.im_info_names[2 * i]] + im_info = feed_vars[self.im_info_names[2 * i + 1]] + + # backbone + body_feats = self.backbone(im) + result.update(body_feats) + + # FPN + if self.fpn is not None: + body_feats, spatial_scale = self.fpn.get_output(body_feats) + + # rpn proposals + rpn_rois = self.rpn_head.get_proposals( + body_feats, im_info, mode='test') + + proposal_list = [] + roi_feat_list = [] + rcnn_pred_list = [] + + proposals = None + bbox_pred = None + for i in range(3): + if i > 0: + refined_bbox = self._decode_box( + proposals, + bbox_pred, + curr_stage=i - 1, ) + else: + refined_bbox = rpn_rois + + proposals = refined_bbox + proposal_list.append(proposals) + + # extract roi features + roi_feat = self.roi_extractor(body_feats, proposals, + spatial_scale) + roi_feat_list.append(roi_feat) + + # bbox head + cls_score, bbox_pred = self.bbox_head.get_output( + roi_feat, + wb_scalar=1.0 / 
self.cascade_rcnn_loss_weight[i], + name='_' + str(i + 1) if i > 0 else '') + rcnn_pred_list.append((cls_score, bbox_pred)) + + # get mask rois + rois = proposal_list[2] + + if self.fpn is None: + last_feat = body_feats[list(body_feats.keys())[-1]] + roi_feat = self.roi_extractor(last_feat, rois) + else: + roi_feat = self.roi_extractor(body_feats, rois, spatial_scale) + + pred = self.bbox_head.get_prediction( + im_info, + im_shape, + roi_feat_list, + rcnn_pred_list, + proposal_list, + self.cascade_bbox_reg_weights, + self.cls_agnostic_bbox_reg, + return_box_score=True) + bbox_name = 'bbox_' + str(i) + score_name = 'score_' + str(i) + if 'flip' in im.name: + bbox_name += '_flip' + score_name += '_flip' + result[bbox_name] = pred['bbox'] + result[score_name] = pred['score'] + return result + + def _input_check(self, require_fields, feed_vars): + for var in require_fields: + assert var in feed_vars, \ + "{} has no {} field".format(feed_vars, var) + + def _decode_box(self, proposals, bbox_pred, curr_stage): + rcnn_loc_delta_r = fluid.layers.reshape( + bbox_pred, (-1, self.cls_agnostic_bbox_reg, 4)) + # only use fg box delta to decode box + rcnn_loc_delta_s = fluid.layers.slice( + rcnn_loc_delta_r, axes=[1], starts=[1], ends=[2]) + refined_bbox = fluid.layers.box_coder( + prior_box=proposals, + prior_box_var=self.cascade_bbox_reg_weights[curr_stage], + target_box=rcnn_loc_delta_s, + code_type='decode_center_size', + box_normalized=False, + axis=1, ) + refined_bbox = fluid.layers.reshape(refined_bbox, shape=[-1, 4]) + + return refined_bbox + + def _inputs_def(self, image_shape): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 
'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + } + # yapf: enable + return inputs_def + + def build_inputs(self, + image_shape=[3, None, None], + fields=[ + 'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', + 'is_crowd' + ], + multi_scale=False, + num_scales=-1, + use_flip=None, + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape) + fields = copy.deepcopy(fields) + if multi_scale: + ms_def, ms_fields = multiscale_def(image_shape, num_scales, + use_flip) + inputs_def.update(ms_def) + fields += ms_fields + self.im_info_names = ['image', 'im_info'] + ms_fields + + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, 'train') + + def eval(self, feed_vars, multi_scale=None): + if multi_scale: + return self.build_multi_scale(feed_vars) + return self.build(feed_vars, 'test') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, 'test') diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn_cls_aware.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn_cls_aware.py new file mode 100755 index 000000000..837d87e97 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn_cls_aware.py @@ -0,0 +1,332 @@ +# Copyright (c) 2019 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import sys + +from collections import OrderedDict +import copy + +import paddle.fluid as fluid +from ppdet.core.workspace import register +from ppdet.utils.check import check_version +from .input_helper import multiscale_def + +__all__ = ['CascadeRCNNClsAware'] + + +@register +class CascadeRCNNClsAware(object): + """ + Cascade R-CNN architecture, see https://arxiv.org/abs/1712.00726 + This is a kind of modification of Cascade R-CNN. 
+ Specifically, it predicts bboxes for all classes with different weights, + while the standard version just predicts bboxes for foreground + Args: + backbone (object): backbone instance + rpn_head (object): `RPNhead` instance + bbox_assigner (object): `BBoxAssigner` instance + roi_extractor (object): ROI extractor instance + bbox_head (object): `BBoxHead` instance + fpn (object): feature pyramid network instance + """ + + __category__ = 'architecture' + __inject__ = [ + 'backbone', 'fpn', 'rpn_head', 'bbox_assigner', 'roi_extractor', + 'bbox_head' + ] + + def __init__( + self, + backbone, + rpn_head, + roi_extractor='FPNRoIAlign', + bbox_head='CascadeBBoxHead', + bbox_assigner='CascadeBBoxAssigner', + fpn='FPN', ): + super(CascadeRCNNClsAware, self).__init__() + check_version('2.0.0-rc0') + assert fpn is not None, "cascade RCNN requires FPN" + self.backbone = backbone + self.fpn = fpn + self.rpn_head = rpn_head + self.bbox_assigner = bbox_assigner + self.roi_extractor = roi_extractor + self.bbox_head = bbox_head + self.bbox_clip = np.log(1000. / 16.) + # Cascade local cfg + (brw0, brw1, brw2) = self.bbox_assigner.bbox_reg_weights + self.cascade_bbox_reg_weights = [ + [1. / brw0, 1. / brw0, 2. / brw0, 2. / brw0], + [1. / brw1, 1. / brw1, 2. / brw1, 2. / brw1], + [1. / brw2, 1. / brw2, 2. / brw2, 2. 
/ brw2] + ] + self.cascade_rcnn_loss_weight = [1.0, 0.5, 0.25] + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + im_info = feed_vars['im_info'] + if mode == 'train': + gt_bbox = feed_vars['gt_bbox'] + is_crowd = feed_vars['is_crowd'] + gt_class = feed_vars['gt_class'] + else: + im_shape = feed_vars['im_shape'] + + # backbone + body_feats = self.backbone(im) + + # FPN + if self.fpn is not None: + body_feats, spatial_scale = self.fpn.get_output(body_feats) + + # rpn proposals + rpn_rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode) + + if mode == 'train': + rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd) + + proposal_list = [] + roi_feat_list = [] + rcnn_pred_list = [] + rcnn_target_list = [] + + bbox_pred = None + + self.cascade_var_v = [] + for stage in range(3): + var_v = np.array( + self.cascade_bbox_reg_weights[stage], dtype="float32") + prior_box_var = fluid.layers.create_tensor(dtype="float32") + fluid.layers.assign(input=var_v, output=prior_box_var) + self.cascade_var_v.append(prior_box_var) + + self.cascade_decoded_box = [] + self.cascade_cls_prob = [] + max_overlap = None + + for stage in range(3): + if stage > 0: + pool_rois = decoded_assign_box + else: + pool_rois = rpn_rois + if mode == "train": + self.cascade_var_v[stage].stop_gradient = True + outs = self.bbox_assigner( + input_rois=pool_rois, + feed_vars=feed_vars, + curr_stage=stage, + max_overlap=max_overlap) + pool_rois = outs[0] + max_overlap = outs[-1] + rcnn_target_list.append(outs[:-1]) + + # extract roi features + roi_feat = self.roi_extractor(body_feats, pool_rois, spatial_scale) + roi_feat_list.append(roi_feat) + + # bbox head + cls_score, bbox_pred = self.bbox_head.get_output( + roi_feat, + cls_agnostic_bbox_reg=self.bbox_head.num_classes, + wb_scalar=1.0 / self.cascade_rcnn_loss_weight[stage], + name='_' + str(stage + 1)) + + cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False) + + decoded_box, decoded_assign_box = 
fluid.layers.box_decoder_and_assign( + pool_rois, self.cascade_var_v[stage], bbox_pred, cls_prob, + self.bbox_clip) + + if mode == "train": + decoded_box.stop_gradient = True + decoded_assign_box.stop_gradient = True + else: + self.cascade_cls_prob.append(cls_prob) + self.cascade_decoded_box.append(decoded_box) + + rcnn_pred_list.append((cls_score, bbox_pred)) + + # out loop + if mode == 'train': + loss = self.bbox_head.get_loss(rcnn_pred_list, rcnn_target_list, + self.cascade_rcnn_loss_weight) + loss.update(rpn_loss) + total_loss = fluid.layers.sum(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + else: + pred = self.bbox_head.get_prediction_cls_aware( + im_info, im_shape, self.cascade_cls_prob, + self.cascade_decoded_box, self.cascade_bbox_reg_weights) + return pred + + def build_multi_scale(self, feed_vars): + required_fields = ['image', 'im_shape', 'im_info'] + self._input_check(required_fields, feed_vars) + + result = {} + im_shape = feed_vars['im_shape'] + result['im_shape'] = im_shape + + for i in range(len(self.im_info_names) // 2): + im = feed_vars[self.im_info_names[2 * i]] + im_info = feed_vars[self.im_info_names[2 * i + 1]] + + # backbone + body_feats = self.backbone(im) + result.update(body_feats) + # FPN + if self.fpn is not None: + body_feats, spatial_scale = self.fpn.get_output(body_feats) + + # rpn proposals + rpn_rois = self.rpn_head.get_proposals( + body_feats, im_info, mode="test") + + proposal_list = [] + roi_feat_list = [] + rcnn_pred_list = [] + rcnn_target_list = [] + + bbox_pred = None + + self.cascade_var_v = [] + for stage in range(3): + var_v = np.array( + self.cascade_bbox_reg_weights[stage], dtype="float32") + prior_box_var = fluid.layers.create_tensor(dtype="float32") + fluid.layers.assign(input=var_v, output=prior_box_var) + self.cascade_var_v.append(prior_box_var) + + self.cascade_decoded_box = [] + self.cascade_cls_prob = [] + + for stage in range(3): + if stage > 0: + pool_rois = decoded_assign_box + else: + 
pool_rois = rpn_rois + + # extract roi features + roi_feat = self.roi_extractor(body_feats, pool_rois, + spatial_scale) + roi_feat_list.append(roi_feat) + + # bbox head + cls_score, bbox_pred = self.bbox_head.get_output( + roi_feat, + cls_agnostic_bbox_reg=self.bbox_head.num_classes, + wb_scalar=1.0 / self.cascade_rcnn_loss_weight[stage], + name='_' + str(stage + 1)) + + cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False) + + decoded_box, decoded_assign_box = fluid.layers.box_decoder_and_assign( + pool_rois, self.cascade_var_v[stage], bbox_pred, cls_prob, + self.bbox_clip) + + self.cascade_cls_prob.append(cls_prob) + self.cascade_decoded_box.append(decoded_box) + + rcnn_pred_list.append((cls_score, bbox_pred)) + + pred = self.bbox_head.get_prediction_cls_aware( + im_info, + im_shape, + self.cascade_cls_prob, + self.cascade_decoded_box, + self.cascade_bbox_reg_weights, + return_box_score=True) + + bbox_name = 'bbox_' + str(i) + score_name = 'score_' + str(i) + if 'flip' in im.name: + bbox_name += '_flip' + score_name += '_flip' + result[bbox_name] = pred['bbox'] + result[score_name] = pred['score'] + + return result + + def _inputs_def(self, image_shape): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + } + # yapf: enable + return inputs_def + + def build_inputs(self, + image_shape=[3, None, None], + fields=[ + 'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', + 'is_crowd', 
'gt_mask' + ], + multi_scale=False, + num_scales=-1, + use_flip=None, + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape) + fields = copy.deepcopy(fields) + if multi_scale: + ms_def, ms_fields = multiscale_def(image_shape, num_scales, + use_flip) + inputs_def.update(ms_def) + fields += ms_fields + self.im_info_names = ['image', 'im_info'] + ms_fields + + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def _input_check(self, require_fields, feed_vars): + for var in require_fields: + assert var in feed_vars, \ + "{} has no {} field".format(feed_vars, var) + + def train(self, feed_vars): + return self.build(feed_vars, 'train') + + def eval(self, feed_vars, multi_scale=None): + if multi_scale: + return self.build_multi_scale(feed_vars) + return self.build(feed_vars, 'test') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, 'test') diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cornernet_squeeze.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cornernet_squeeze.py new file mode 100755 index 000000000..61e17f0ce --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/cornernet_squeeze.py @@ -0,0 +1,142 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid + +from ppdet.core.workspace import register +import numpy as np +from ppdet.utils.check import check_version + +__all__ = ['CornerNetSqueeze'] + + +def rescale_bboxes(bboxes, ratios, borders): + x1, y1, x2, y2 = fluid.layers.split(bboxes, 4) + x1 = x1 / ratios[:, 1] - borders[:, 2] + x2 = x2 / ratios[:, 1] - borders[:, 2] + y1 = y1 / ratios[:, 0] - borders[:, 0] + y2 = y2 / ratios[:, 0] - borders[:, 0] + return fluid.layers.concat([x1, y1, x2, y2], axis=2) + + +@register +class CornerNetSqueeze(object): + """ + """ + __category__ = 'architecture' + __inject__ = ['backbone', 'corner_head', 'fpn'] + __shared__ = ['num_classes'] + + def __init__(self, + backbone, + corner_head='CornerHead', + num_classes=80, + fpn=None): + check_version('1.8.0') + super(CornerNetSqueeze, self).__init__() + self.backbone = backbone + self.corner_head = corner_head + self.num_classes = num_classes + self.fpn = fpn + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + body_feats = self.backbone(im) + if self.fpn is not None: + body_feats, _ = self.fpn.get_output(body_feats) + body_feats = [list(body_feats.values())[-1]] + if mode == 'train': + target_vars = [ + 'tl_heatmaps', 'br_heatmaps', 'tag_masks', 'tl_regrs', + 'br_regrs', 'tl_tags', 'br_tags' + ] + target = {key: feed_vars[key] for key in target_vars} + self.corner_head.get_output(body_feats) + loss = 
self.corner_head.get_loss(target) + return loss + + elif mode == 'test': + ratios = feed_vars['ratios'] + borders = feed_vars['borders'] + bboxes, scores, tl_scores, br_scores, clses = self.corner_head.get_prediction( + body_feats[-1]) + bboxes = rescale_bboxes(bboxes, ratios, borders) + detections = fluid.layers.concat([clses, scores, bboxes], axis=2) + + detections = detections[0] + return {'bbox': detections} + + def _inputs_def(self, image_shape, output_size, max_tag_len): + im_shape = [None] + image_shape + C = self.num_classes + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'ratios': {'shape': [None, 2], 'dtype': 'float32', 'lod_level': 0}, + 'borders': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 0}, + 'tl_heatmaps': {'shape': [None, C, output_size, output_size], 'dtype': 'float32', 'lod_level': 0}, + 'br_heatmaps': {'shape': [None, C, output_size, output_size], 'dtype': 'float32', 'lod_level': 0}, + 'tl_regrs': {'shape': [None, max_tag_len, 2], 'dtype': 'float32', 'lod_level': 0}, + 'br_regrs': {'shape': [None, max_tag_len, 2], 'dtype': 'float32', 'lod_level': 0}, + 'tl_tags': {'shape': [None, max_tag_len], 'dtype': 'int64', 'lod_level': 0}, + 'br_tags': {'shape': [None, max_tag_len], 'dtype': 'int64', 'lod_level': 0}, + 'tag_masks': {'shape': [None, max_tag_len], 'dtype': 'int32', 'lod_level': 0}, + } + # yapf: enable + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=[ + 'image', 'im_id', 'gt_box', 'gt_class', 'tl_heatmaps', + 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags', + 'tag_masks' + ], # for train + output_size=64, + max_tag_len=128, + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape, 
output_size, max_tag_len) + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=64, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, mode='train') + + def eval(self, feed_vars): + return self.build(feed_vars, mode='test') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, mode='test') diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/efficientdet.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/efficientdet.py new file mode 100755 index 000000000..17561b687 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/efficientdet.py @@ -0,0 +1,152 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division + +from collections import OrderedDict + +import paddle.fluid as fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register + +__all__ = ['EfficientDet'] + + +@register +class EfficientDet(object): + """ + EfficientDet architecture, see https://arxiv.org/abs/1911.09070 + + Args: + backbone (object): backbone instance + fpn (object): feature pyramid network instance + retina_head (object): `RetinaHead` instance + """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'fpn', 'efficient_head', 'anchor_grid'] + + def __init__(self, + backbone, + fpn, + efficient_head, + anchor_grid, + box_loss_weight=50.): + super(EfficientDet, self).__init__() + self.backbone = backbone + self.fpn = fpn + self.efficient_head = efficient_head + self.anchor_grid = anchor_grid + self.box_loss_weight = box_loss_weight + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + if mode == 'train': + gt_labels = feed_vars['gt_label'] + gt_targets = feed_vars['gt_target'] + fg_num = feed_vars['fg_num'] + else: + im_info = feed_vars['im_info'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + body_feats = self.backbone(im) + if mixed_precision_enabled: + body_feats = [fluid.layers.cast(f, 'float32') for f in body_feats] + body_feats = self.fpn(body_feats) + + # XXX not used for training, but the parameters are needed when + # exporting inference model + anchors = self.anchor_grid() + + if mode == 'train': + loss = self.efficient_head.get_loss(body_feats, gt_labels, + gt_targets, fg_num) + loss_cls = loss['loss_cls'] + loss_bbox = loss['loss_bbox'] + total_loss = loss_cls + self.box_loss_weight * loss_bbox + loss.update({'loss': total_loss}) + return loss + else: + pred = self.efficient_head.get_prediction(body_feats, anchors, + im_info) + return 
pred + + def _inputs_def(self, image_shape): + im_shape = [None] + image_shape + inputs_def = { + 'image': { + 'shape': im_shape, + 'dtype': 'float32' + }, + 'im_info': { + 'shape': [None, 3], + 'dtype': 'float32' + }, + 'im_id': { + 'shape': [None, 1], + 'dtype': 'int64' + }, + 'im_shape': { + 'shape': [None, 3], + 'dtype': 'float32' + }, + 'fg_num': { + 'shape': [None, 1], + 'dtype': 'int32' + }, + 'gt_label': { + 'shape': [None, None, 1], + 'dtype': 'int32' + }, + 'gt_target': { + 'shape': [None, None, 4], + 'dtype': 'float32' + }, + } + return inputs_def + + def build_inputs(self, + image_shape=[3, None, None], + fields=[ + 'image', 'im_info', 'im_id', 'fg_num', 'gt_label', + 'gt_target' + ], + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape) + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, 'train') + + def eval(self, feed_vars): + return self.build(feed_vars, 'test') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, 'test') diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/faceboxes.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/faceboxes.py new file mode 100755 index 000000000..2d8abe6b9 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/faceboxes.py @@ -0,0 +1,192 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from collections import OrderedDict + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +from ppdet.core.workspace import register +from ppdet.modeling.ops import SSDOutputDecoder + +__all__ = ['FaceBoxes'] + + +@register +class FaceBoxes(object): + """ + FaceBoxes: A CPU Real-time Face Detector with High Accuracy. + see https://arxiv.org/abs/1708.05234 + + Args: + backbone (object): backbone instance + output_decoder (object): `SSDOutputDecoder` instance + densities (list|None): the densities of generated density prior boxes, + this attribute should be a list or tuple of integers. + fixed_sizes (list|None): the fixed sizes of generated density prior boxes, + this attribute should a list or tuple of same length with `densities`. + num_classes (int): number of output classes. + steps (list|None): step size of adjacent prior boxes on each feature map. 
+ """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'output_decoder'] + __shared__ = ['num_classes'] + + def __init__(self, + backbone="FaceBoxNet", + output_decoder=SSDOutputDecoder().__dict__, + densities=[[4, 2, 1], [1], [1]], + fixed_sizes=[[32., 64., 128.], [256.], [512.]], + num_classes=2, + steps=[8., 16., 32.]): + super(FaceBoxes, self).__init__() + self.backbone = backbone + self.num_classes = num_classes + self.output_decoder = output_decoder + if isinstance(output_decoder, dict): + self.output_decoder = SSDOutputDecoder(**output_decoder) + self.densities = densities + self.fixed_sizes = fixed_sizes + self.steps = steps + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + if mode == 'train': + gt_bbox = feed_vars['gt_bbox'] + gt_class = feed_vars['gt_class'] + + body_feats = self.backbone(im) + locs, confs, box, box_var = self._multi_box_head( + inputs=body_feats, image=im, num_classes=self.num_classes) + + if mode == 'train': + loss = fluid.layers.ssd_loss( + locs, + confs, + gt_bbox, + gt_class, + box, + box_var, + overlap_threshold=0.35, + neg_overlap=0.35) + loss = fluid.layers.reduce_sum(loss) + return {'loss': loss} + else: + pred = self.output_decoder(locs, confs, box, box_var) + return {'bbox': pred} + + def _multi_box_head(self, inputs, image, num_classes=2): + def permute_and_reshape(input, last_dim): + trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1]) + compile_shape = [0, -1, last_dim] + return fluid.layers.reshape(trans, shape=compile_shape) + + def _is_list_or_tuple_(data): + return (isinstance(data, list) or isinstance(data, tuple)) + + locs, confs = [], [] + boxes, vars = [], [] + b_attr = ParamAttr(learning_rate=2., regularizer=L2Decay(0.)) + + for i, input in enumerate(inputs): + densities = self.densities[i] + fixed_sizes = self.fixed_sizes[i] + box, var = fluid.layers.density_prior_box( + input, + image, + densities=densities, + fixed_sizes=fixed_sizes, + fixed_ratios=[1.], + clip=False, + 
offset=0.5, + steps=[self.steps[i]] * 2) + + num_boxes = box.shape[2] + + box = fluid.layers.reshape(box, shape=[-1, 4]) + var = fluid.layers.reshape(var, shape=[-1, 4]) + num_loc_output = num_boxes * 4 + num_conf_output = num_boxes * num_classes + # get loc + mbox_loc = fluid.layers.conv2d( + input, num_loc_output, 3, 1, 1, bias_attr=b_attr) + loc = permute_and_reshape(mbox_loc, 4) + # get conf + mbox_conf = fluid.layers.conv2d( + input, num_conf_output, 3, 1, 1, bias_attr=b_attr) + conf = permute_and_reshape(mbox_conf, 2) + + locs.append(loc) + confs.append(conf) + boxes.append(box) + vars.append(var) + + face_mbox_loc = fluid.layers.concat(locs, axis=1) + face_mbox_conf = fluid.layers.concat(confs, axis=1) + prior_boxes = fluid.layers.concat(boxes) + box_vars = fluid.layers.concat(vars) + return face_mbox_loc, face_mbox_conf, prior_boxes, box_vars + + def _inputs_def(self, image_shape): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'im_shape': {'shape': [None, 3], 'dtype': 'int32', 'lod_level': 0}, + } + # yapf: enable + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=['image', 'im_id', 'gt_bbox', 'gt_class'], # for train + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape) + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + 
return self.build(feed_vars, 'train') + + def eval(self, feed_vars): + return self.build(feed_vars, 'eval') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, 'test') + + def is_bbox_normalized(self): + return True diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/faster_rcnn.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/faster_rcnn.py new file mode 100755 index 000000000..13bffd83a --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/faster_rcnn.py @@ -0,0 +1,250 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict
import copy

from paddle import fluid

from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register

from .input_helper import multiscale_def

__all__ = ['FasterRCNN']


@register
class FasterRCNN(object):
    """
    Faster R-CNN architecture, see https://arxiv.org/abs/1506.01497

    Builds a paddle.fluid static graph: backbone -> (optional FPN) ->
    RPN proposals -> RoI feature extraction -> bbox head.

    Args:
        backbone (object): backbone instance
        rpn_head (object): `RPNhead` instance
        bbox_assigner (object): `BBoxAssigner` instance
        roi_extractor (object): ROI extractor instance
        bbox_head (object): `BBoxHead` instance
        fpn (object): feature pyramid network instance
    """

    __category__ = 'architecture'
    __inject__ = [
        'backbone', 'rpn_head', 'bbox_assigner', 'roi_extractor', 'bbox_head',
        'fpn'
    ]

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_extractor,
                 bbox_head='BBoxHead',
                 bbox_assigner='BBoxAssigner',
                 rpn_only=False,
                 fpn=None):
        super(FasterRCNN, self).__init__()
        self.backbone = backbone
        self.rpn_head = rpn_head
        self.bbox_assigner = bbox_assigner
        self.roi_extractor = roi_extractor
        self.bbox_head = bbox_head
        self.fpn = fpn
        # when True, test-mode build returns RPN proposals only (no bbox head)
        self.rpn_only = rpn_only

    def build(self, feed_vars, mode='train'):
        """Build the train or test graph.

        Args:
            feed_vars (dict): feed variables created by `build_inputs`.
            mode (str): 'train' or 'test'.

        Returns:
            dict: in 'train' mode, per-component losses plus their sum under
                key 'loss'; in 'test' mode, bbox head predictions, or
                {'proposal': rois} when `rpn_only` is set.
        """
        if mode == 'train':
            required_fields = ['gt_class', 'gt_bbox', 'is_crowd', 'im_info']
        else:
            required_fields = ['im_shape', 'im_info']
        self._input_check(required_fields, feed_vars)

        im = feed_vars['image']
        im_info = feed_vars['im_info']
        if mode == 'train':
            gt_bbox = feed_vars['gt_bbox']
            is_crowd = feed_vars['is_crowd']
        else:
            im_shape = feed_vars['im_shape']

        mixed_precision_enabled = mixed_precision_global_state() is not None

        # cast inputs to FP16
        if mixed_precision_enabled:
            im = fluid.layers.cast(im, 'float16')

        body_feats = self.backbone(im)
        body_feat_names = list(body_feats.keys())

        # cast features back to FP32
        if mixed_precision_enabled:
            body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32'))
                                     for k, v in body_feats.items())

        if self.fpn is not None:
            body_feats, spatial_scale = self.fpn.get_output(body_feats)

        rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)

        if mode == 'train':
            rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd)
            # sampled rpn proposals
            for var in ['gt_class', 'is_crowd', 'gt_bbox', 'im_info']:
                assert var in feed_vars, "{} has no {}".format(feed_vars, var)
            outs = self.bbox_assigner(
                rpn_rois=rois,
                gt_classes=feed_vars['gt_class'],
                is_crowd=feed_vars['is_crowd'],
                gt_boxes=feed_vars['gt_bbox'],
                im_info=feed_vars['im_info'])

            rois = outs[0]
            labels_int32 = outs[1]
            bbox_targets = outs[2]
            bbox_inside_weights = outs[3]
            bbox_outside_weights = outs[4]
        else:
            if self.rpn_only:
                # rescale proposals back to the original image coordinates
                im_scale = fluid.layers.slice(
                    im_info, [1], starts=[2], ends=[3])
                im_scale = fluid.layers.sequence_expand(im_scale, rois)
                rois = rois / im_scale
                return {'proposal': rois}
        if self.fpn is None:
            # in models without FPN, roi extractor only uses the last level of
            # feature maps. And body_feat_names[-1] represents the name of
            # last feature map.
            body_feat = body_feats[body_feat_names[-1]]
            roi_feat = self.roi_extractor(body_feat, rois)
        else:
            roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)

        if mode == 'train':
            loss = self.bbox_head.get_loss(roi_feat, labels_int32, bbox_targets,
                                           bbox_inside_weights,
                                           bbox_outside_weights)
            loss.update(rpn_loss)
            total_loss = fluid.layers.sum(list(loss.values()))
            loss.update({'loss': total_loss})
            return loss
        else:
            pred = self.bbox_head.get_prediction(roi_feat, rois, im_info,
                                                 im_shape)
            return pred

    def build_multi_scale(self, feed_vars):
        """Build a test graph that runs inference once per input scale.

        Expects `build_inputs(..., multi_scale=True)` to have been called
        first so that `self.im_info_names` pairs each image feed with its
        im_info feed. Returns a dict of 'bbox_<i>'/'score_<i>' predictions
        (suffixed '_flip' for flipped inputs) plus 'im_shape'.
        """
        required_fields = ['image', 'im_info', 'im_shape']
        self._input_check(required_fields, feed_vars)

        result = {}
        im_shape = feed_vars['im_shape']
        result['im_shape'] = im_shape
        for i in range(len(self.im_info_names) // 2):
            im = feed_vars[self.im_info_names[2 * i]]
            im_info = feed_vars[self.im_info_names[2 * i + 1]]
            body_feats = self.backbone(im)
            body_feat_names = list(body_feats.keys())

            if self.fpn is not None:
                body_feats, spatial_scale = self.fpn.get_output(body_feats)

            rois = self.rpn_head.get_proposals(body_feats, im_info, mode='test')

            if self.fpn is None:
                # in models without FPN, roi extractor only uses the last level of
                # feature maps. And body_feat_names[-1] represents the name of
                # last feature map.
                body_feat = body_feats[body_feat_names[-1]]
                roi_feat = self.roi_extractor(body_feat, rois)
            else:
                roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)

            pred = self.bbox_head.get_prediction(
                roi_feat, rois, im_info, im_shape, return_box_score=True)
            bbox_name = 'bbox_' + str(i)
            score_name = 'score_' + str(i)
            if 'flip' in im.name:
                bbox_name += '_flip'
                score_name += '_flip'
            result[bbox_name] = pred['bbox']
            result[score_name] = pred['score']
        return result

    def _input_check(self, require_fields, feed_vars):
        """Assert that every required field is present in `feed_vars`."""
        for var in require_fields:
            assert var in feed_vars, \
                "{} has no {} field".format(feed_vars, var)

    def _inputs_def(self, image_shape):
        """Return the feed-variable definitions (shape/dtype/lod) keyed by name."""
        im_shape = [None] + image_shape
        # yapf: disable
        inputs_def = {
            'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
            'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
            'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
            'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
            'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
            'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
            'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
            'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
        }
        # yapf: enable
        return inputs_def

    def build_inputs(
            self,
            image_shape=[3, None, None],
            fields=[
                'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd'
            ],  # for train
            multi_scale=False,
            num_scales=-1,
            use_flip=None,
            use_dataloader=True,
            iterable=False):
        """Create the feed variables (and optionally a DataLoader) for this model.

        Returns:
            (OrderedDict, DataLoader or None): feed vars keyed by field name,
            and a generator-fed loader when `use_dataloader` is True.
        """
        inputs_def = self._inputs_def(image_shape)
        # deep-copy so the shared default list is never mutated
        fields = copy.deepcopy(fields)
        if multi_scale:
            ms_def, ms_fields = multiscale_def(image_shape, num_scales,
                                               use_flip)
            inputs_def.update(ms_def)
            fields += ms_fields
            # consumed later by build_multi_scale
            self.im_info_names = ['image', 'im_info'] + ms_fields

        feed_vars = OrderedDict([(key, fluid.data(
            name=key,
            shape=inputs_def[key]['shape'],
            dtype=inputs_def[key]['dtype'],
            lod_level=inputs_def[key]['lod_level'])) for key in fields])
        loader = fluid.io.DataLoader.from_generator(
            feed_list=list(feed_vars.values()),
            capacity=16,
            use_double_buffer=True,
            iterable=iterable) if use_dataloader else None
        return feed_vars, loader

    def train(self, feed_vars):
        """Build the training graph (losses)."""
        return self.build(feed_vars, 'train')

    def eval(self, feed_vars, multi_scale=None):
        """Build the evaluation graph; multi-scale when `multi_scale` is truthy."""
        if multi_scale:
            return self.build_multi_scale(feed_vars)
        return self.build(feed_vars, 'test')

    def test(self, feed_vars, exclude_nms=False):
        """Build the inference graph; `exclude_nms` is not supported here."""
        assert not exclude_nms, "exclude_nms for {} is not support currently".format(
            self.__class__.__name__)
        return self.build(feed_vars, 'test')
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +import paddle.fluid as fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register + +__all__ = ['FCOS'] + + +@register +class FCOS(object): + """ + FCOS architecture, see https://arxiv.org/abs/1904.01355 + + Args: + backbone (object): backbone instance + fpn (object): feature pyramid network instance + fcos_head (object): `FCOSHead` instance + """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'fpn', 'fcos_head'] + + def __init__(self, backbone, fpn, fcos_head): + super(FCOS, self).__init__() + self.backbone = backbone + self.fpn = fpn + self.fcos_head = fcos_head + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + im_info = feed_vars['im_info'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + # backbone + body_feats = self.backbone(im) + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32')) + for k, v in body_feats.items()) + + # FPN + body_feats, spatial_scale = self.fpn.get_output(body_feats) + + # fcosnet head + if mode == 'train': + tag_labels = [] + tag_bboxes = [] + tag_centerness = [] + for i in range(len(self.fcos_head.fpn_stride)): + # reg_target, labels, scores, centerness + k_lbl = 'labels{}'.format(i) + if k_lbl in feed_vars: + tag_labels.append(feed_vars[k_lbl]) + k_box = 'reg_target{}'.format(i) + if k_box in feed_vars: + tag_bboxes.append(feed_vars[k_box]) + k_ctn = 'centerness{}'.format(i) + if k_ctn in feed_vars: + tag_centerness.append(feed_vars[k_ctn]) + # tag_labels, tag_bboxes, tag_centerness + loss = self.fcos_head.get_loss(body_feats, tag_labels, tag_bboxes, + tag_centerness) + total_loss = 
fluid.layers.sum(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + else: + pred = self.fcos_head.get_prediction(body_feats, im_info) + return pred + + def _inputs_def(self, image_shape, fields): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'gt_score': {'shape': [None, 1], 'dtype': 'float32', 'lod_level': 1}, + 'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1} + } + # yapf: disable + if 'fcos_target' in fields: + targets_def = { + 'labels0': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0}, + 'reg_target0': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0}, + 'centerness0': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0}, + 'labels1': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0}, + 'reg_target1': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0}, + 'centerness1': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0}, + 'labels2': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0}, + 'reg_target2': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0}, + 'centerness2': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0}, + 'labels3': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0}, + 'reg_target3': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0}, + 'centerness3': {'shape': [None, None, None, 1], 'dtype': 'float32', 
'lod_level': 0}, + 'labels4': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0}, + 'reg_target4': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0}, + 'centerness4': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0}, + } + # yapf: enable + + # downsample = 128 + for k, stride in enumerate(self.fcos_head.fpn_stride): + k_lbl = 'labels{}'.format(k) + k_box = 'reg_target{}'.format(k) + k_ctn = 'centerness{}'.format(k) + grid_y = image_shape[-2] // stride if image_shape[-2] else None + grid_x = image_shape[-1] // stride if image_shape[-1] else None + if grid_x is not None: + num_pts = grid_x * grid_y + num_dim2 = 1 + else: + num_pts = None + num_dim2 = None + targets_def[k_lbl]['shape'][1] = num_pts + targets_def[k_box]['shape'][1] = num_pts + targets_def[k_ctn]['shape'][1] = num_pts + targets_def[k_lbl]['shape'][2] = num_dim2 + targets_def[k_box]['shape'][2] = num_dim2 + targets_def[k_ctn]['shape'][2] = num_dim2 + inputs_def.update(targets_def) + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=['image', 'im_info', 'fcos_target'], # for-train + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape, fields) + if "fcos_target" in fields: + for i in range(len(self.fcos_head.fpn_stride)): + fields.extend( + ['labels%d' % i, 'reg_target%d' % i, 'centerness%d' % i]) + fields.remove('fcos_target') + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, 'train') + + def eval(self, feed_vars): + return self.build(feed_vars, 'test') + + def test(self, feed_vars, 
exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, 'test') diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/htc.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/htc.py new file mode 100755 index 000000000..c20822f85 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/htc.py @@ -0,0 +1,475 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict
import copy
import numpy as np

import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register
from ppdet.utils.check import check_version

from .input_helper import multiscale_def

__all__ = ['HybridTaskCascade']


@register
class HybridTaskCascade(object):
    """
    Hybrid Task Cascade Mask R-CNN architecture, see https://arxiv.org/abs/1901.07518

    Args:
        backbone (object): backbone instance
        rpn_head (object): `RPNhead` instance
        bbox_assigner (object): `BBoxAssigner` instance
        roi_extractor (object): ROI extractor instance
        bbox_head (object): `HTCBBoxHead` instance
        mask_assigner (object): `MaskAssigner` instance
        mask_head (object): `HTCMaskHead` instance
        fpn (object): feature pyramid network instance
        semantic_roi_extractor(object): ROI extractor instance
        fused_semantic_head (object): `FusedSemanticHead` instance
    """

    __category__ = 'architecture'
    __inject__ = [
        'backbone', 'rpn_head', 'bbox_assigner', 'roi_extractor', 'bbox_head',
        'mask_assigner', 'mask_head', 'fpn', 'semantic_roi_extractor',
        'fused_semantic_head'
    ]

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_extractor='FPNRoIAlign',
                 semantic_roi_extractor='RoIAlign',
                 fused_semantic_head='FusedSemanticHead',
                 bbox_head='HTCBBoxHead',
                 bbox_assigner='CascadeBBoxAssigner',
                 mask_assigner='MaskAssigner',
                 mask_head='HTCMaskHead',
                 rpn_only=False,
                 fpn='FPN'):
        super(HybridTaskCascade, self).__init__()
        # HTC here relies on control-flow ops available from paddle 2.0.0-rc0
        check_version('2.0.0-rc0')
        assert fpn is not None, "HTC requires FPN"
        self.backbone = backbone
        self.fpn = fpn
        self.rpn_head = rpn_head
        self.bbox_assigner = bbox_assigner
        self.roi_extractor = roi_extractor
        self.semantic_roi_extractor = semantic_roi_extractor
        self.fused_semantic_head = fused_semantic_head
        self.bbox_head = bbox_head
        self.mask_assigner = mask_assigner
        self.mask_head = mask_head
        self.rpn_only = rpn_only
        # Cascade local cfg
        # background (0) + foreground (1) regression channels
        self.cls_agnostic_bbox_reg = 2
        (brw0, brw1, brw2) = self.bbox_assigner.bbox_reg_weights
        # per-stage box-decoding variances derived from the assigner weights
        self.cascade_bbox_reg_weights = [
            [1. / brw0, 1. / brw0, 2. / brw0, 2. / brw0],
            [1. / brw1, 1. / brw1, 2. / brw1, 2. / brw1],
            [1. / brw2, 1. / brw2, 2. / brw2, 2. / brw2]
        ]
        # later cascade stages contribute less to the total loss
        self.cascade_rcnn_loss_weight = [1.0, 0.5, 0.25]
        self.num_stage = 3
        # HTC feature toggles (fixed here rather than configurable; see paper)
        self.with_mask = True
        self.interleaved = True
        self.mask_info_flow = True
        self.with_semantic = True
        self.use_bias_scalar = True

    def build(self, feed_vars, mode='train'):
        """Build the train or test graph.

        Runs backbone -> FPN -> (semantic branch) -> RPN, then `num_stage`
        cascade iterations of bbox (and, in train mode, mask) heads.

        Returns:
            dict: in 'train' mode, per-component losses plus their sum under
                key 'loss'; in 'test' mode, {'bbox': ..., 'mask': ...} (or
                {'proposal': rois} when `rpn_only` is set).
        """
        if mode == 'train':
            required_fields = [
                'gt_class', 'gt_bbox', 'gt_mask', 'is_crowd', 'im_info',
                'semantic'
            ]
        else:
            required_fields = ['im_shape', 'im_info']
        self._input_check(required_fields, feed_vars)

        im = feed_vars['image']
        if mode == 'train':
            gt_bbox = feed_vars['gt_bbox']
            is_crowd = feed_vars['is_crowd']

        im_info = feed_vars['im_info']

        # backbone
        body_feats = self.backbone(im)

        loss = {}
        # FPN
        if self.fpn is not None:
            body_feats, spatial_scale = self.fpn.get_output(body_feats)

        if self.with_semantic:
            # TODO: use cfg
            semantic_feat, seg_pred = self.fused_semantic_head.get_out(
                body_feats)
            if mode == 'train':
                s_label = feed_vars['semantic']
                # 0.2 is the fixed semantic-loss weight
                semantic_loss = self.fused_semantic_head.get_loss(seg_pred,
                                                                  s_label) * 0.2
                loss.update({"semantic_loss": semantic_loss})
        else:
            semantic_feat = None

        # rpn proposals
        rpn_rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)
        if mode == 'train':
            rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd)
            loss.update(rpn_loss)
        else:
            if self.rpn_only:
                # rescale proposals back to the original image coordinates
                im_scale = fluid.layers.slice(
                    im_info, [1], starts=[2], ends=[3])
                im_scale = fluid.layers.sequence_expand(im_scale, rpn_rois)
                rois = rpn_rois / im_scale
                return {'proposal': rois}

        proposal_list = []
        roi_feat_list = []
        rcnn_pred_list = []
        rcnn_target_list = []
        mask_logits_list = []
        mask_target_list = []
        proposals = None
        bbox_pred = None
        outs = None
        # stage i refines the boxes produced by stage i-1
        refined_bbox = rpn_rois
        max_overlap = None
        for i in range(self.num_stage):
            # BBox Branch
            if mode == 'train':
                outs = self.bbox_assigner(
                    input_rois=refined_bbox,
                    feed_vars=feed_vars,
                    curr_stage=i,
                    max_overlap=max_overlap)
                proposals = outs[0]
                max_overlap = outs[-1]
                rcnn_target_list.append(outs[:-1])
            else:
                proposals = refined_bbox
            proposal_list.append(proposals)

            # extract roi features
            roi_feat = self.roi_extractor(body_feats, proposals, spatial_scale)
            if self.with_semantic:
                # fuse pooled semantic features into the RoI features
                semantic_roi_feat = self.semantic_roi_extractor(semantic_feat,
                                                                proposals)
                if semantic_roi_feat is not None:
                    semantic_roi_feat = fluid.layers.pool2d(
                        semantic_roi_feat,
                        pool_size=2,
                        pool_stride=2,
                        pool_padding='SAME')
                    roi_feat = fluid.layers.sum([roi_feat, semantic_roi_feat])
            roi_feat_list.append(roi_feat)

            # bbox head
            cls_score, bbox_pred = self.bbox_head.get_output(
                roi_feat,
                wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i],
                name='_' + str(i))
            rcnn_pred_list.append((cls_score, bbox_pred))

            # Mask Branch (training only here; test-time masks are produced
            # in single_scale_eval)
            if self.with_mask:
                if mode == 'train':
                    labels_int32 = outs[1]
                    if self.interleaved:
                        # interleaved execution: mask head consumes the boxes
                        # already refined by this stage's bbox head
                        refined_bbox = self._decode_box(
                            proposals, bbox_pred, curr_stage=i)
                        proposals = refined_bbox

                    mask_rois, roi_has_mask_int32, mask_int32 = self.mask_assigner(
                        rois=proposals,
                        gt_classes=feed_vars['gt_class'],
                        is_crowd=feed_vars['is_crowd'],
                        gt_segms=feed_vars['gt_mask'],
                        im_info=feed_vars['im_info'],
                        labels_int32=labels_int32)
                    mask_target_list.append(mask_int32)

                    mask_feat = self.roi_extractor(
                        body_feats, mask_rois, spatial_scale, is_mask=True)

                    if self.with_semantic:
                        semantic_roi_feat = self.semantic_roi_extractor(
                            semantic_feat, mask_rois)
                        if semantic_roi_feat is not None:
                            mask_feat = fluid.layers.sum(
                                [mask_feat, semantic_roi_feat])

                    if self.mask_info_flow:
                        # mask information flow: pass the previous stages'
                        # mask features through this stage's head first
                        last_feat = None
                        for j in range(i):
                            last_feat = self.mask_head.get_output(
                                mask_feat,
                                last_feat,
                                return_logits=False,
                                return_feat=True,
                                wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
                                if self.use_bias_scalar else 1.0,
                                name='_' + str(i) + '_' + str(j))
                        mask_logits = self.mask_head.get_output(
                            mask_feat,
                            last_feat,
                            return_logits=True,
                            return_feat=False,
                            wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
                            if self.use_bias_scalar else 1.0,
                            name='_' + str(i))
                    else:
                        mask_logits = self.mask_head.get_output(
                            mask_feat,
                            return_logits=True,
                            wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
                            if self.use_bias_scalar else 1.0,
                            name='_' + str(i))
                    mask_logits_list.append(mask_logits)

            # refine boxes for the next stage (interleaved training already
            # refined them above)
            if i < self.num_stage - 1 and not self.interleaved:
                refined_bbox = self._decode_box(
                    proposals, bbox_pred, curr_stage=i)
            elif i < self.num_stage - 1 and mode != 'train':
                refined_bbox = self._decode_box(
                    proposals, bbox_pred, curr_stage=i)

        if mode == 'train':
            bbox_loss = self.bbox_head.get_loss(
                rcnn_pred_list, rcnn_target_list, self.cascade_rcnn_loss_weight)
            loss.update(bbox_loss)
            mask_loss = self.mask_head.get_loss(mask_logits_list,
                                                mask_target_list,
                                                self.cascade_rcnn_loss_weight)
            loss.update(mask_loss)
            total_loss = fluid.layers.sum(list(loss.values()))
            loss.update({'loss': total_loss})
            return loss
        else:
            mask_name = 'mask_pred'
            mask_pred, bbox_pred = self.single_scale_eval(
                body_feats,
                spatial_scale,
                im_info,
                mask_name,
                bbox_pred,
                roi_feat_list,
                rcnn_pred_list,
                proposal_list,
                feed_vars['im_shape'],
                semantic_feat=semantic_feat if self.with_semantic else None)
            return {'bbox': bbox_pred, 'mask': mask_pred}

    def single_scale_eval(self,
                          body_feats,
                          spatial_scale,
                          im_info,
                          mask_name,
                          bbox_pred,
                          roi_feat_list=None,
                          rcnn_pred_list=None,
                          proposal_list=None,
                          im_shape=None,
                          use_multi_test=False,
                          semantic_feat=None):
        """Build the test-time bbox + mask prediction sub-graph.

        Uses fluid.layers.cond to skip the mask branch entirely when the
        detector produced no boxes (bbox tensor smaller than one 6-element
        row); in that case `mask_pred` is just a copy of `bbox_pred`.

        Returns:
            (Variable, Variable): mask predictions and bbox predictions.
        """

        if not use_multi_test:
            bbox_pred = self.bbox_head.get_prediction(
                im_info, im_shape, roi_feat_list, rcnn_pred_list, proposal_list,
                self.cascade_bbox_reg_weights)
            bbox_pred = bbox_pred['bbox']

        # share weight
        bbox_shape = fluid.layers.shape(bbox_pred)
        bbox_size = fluid.layers.reduce_prod(bbox_shape)
        bbox_size = fluid.layers.reshape(bbox_size, [1, 1])
        size = fluid.layers.fill_constant([1, 1], value=6, dtype='int32')
        cond = fluid.layers.less_than(x=bbox_size, y=size)

        # output placeholder written by whichever branch of cond runs
        mask_pred = fluid.layers.create_global_var(
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=False,
            name=mask_name)

        def noop():
            # no detections: propagate bbox_pred so mask_pred is well-defined
            fluid.layers.assign(input=bbox_pred, output=mask_pred)

        def process_boxes():
            # bbox_pred rows are [label, score, x1, y1, x2, y2]; keep coords
            bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])

            im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
            im_scale = fluid.layers.sequence_expand(im_scale, bbox)

            bbox = fluid.layers.cast(bbox, dtype='float32')
            im_scale = fluid.layers.cast(im_scale, dtype='float32')
            # map boxes back to the network input scale for RoI extraction
            mask_rois = bbox * im_scale

            mask_feat = self.roi_extractor(
                body_feats, mask_rois, spatial_scale, is_mask=True)

            if self.with_semantic:
                semantic_roi_feat = self.semantic_roi_extractor(semantic_feat,
                                                                mask_rois)
                if semantic_roi_feat is not None:
                    mask_feat = fluid.layers.sum([mask_feat, semantic_roi_feat])

            mask_logits_list = []
            mask_pred_list = []
            for i in range(self.num_stage):
                if self.mask_info_flow:
                    last_feat = None
                    for j in range(i):
                        last_feat = self.mask_head.get_output(
                            mask_feat,
                            last_feat,
                            return_logits=False,
                            return_feat=True,
                            wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
                            if self.use_bias_scalar else 1.0,
                            name='_' + str(i) + '_' + str(j))
                    mask_logits = self.mask_head.get_output(
                        mask_feat,
                        last_feat,
                        return_logits=True,
                        return_feat=False,
                        wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
                        if self.use_bias_scalar else 1.0,
                        name='_' + str(i))
                    mask_logits_list.append(mask_logits)
                else:
                    mask_logits = self.mask_head.get_output(
                        mask_feat,
                        return_logits=True,
                        return_feat=False,
                        name='_' + str(i))
                mask_pred_out = self.mask_head.get_prediction(mask_logits, bbox)
                mask_pred_list.append(mask_pred_out)

            # average the per-stage mask predictions
            mask_pred_out = fluid.layers.sum(mask_pred_list) / float(
                len(mask_pred_list))
            fluid.layers.assign(input=mask_pred_out, output=mask_pred)

        fluid.layers.cond(cond, noop, process_boxes)
        return mask_pred, bbox_pred

    def _input_check(self, require_fields, feed_vars):
        """Assert that every required field is present in `feed_vars`."""
        for var in require_fields:
            assert var in feed_vars, \
                "{} has no {} field".format(feed_vars, var)

    def _decode_box(self, proposals, bbox_pred, curr_stage):
        """Decode stage-`curr_stage` regression deltas into refined boxes."""
        rcnn_loc_delta_r = fluid.layers.reshape(
            bbox_pred, (-1, self.cls_agnostic_bbox_reg, 4))
        # only use fg box delta to decode box
        rcnn_loc_delta_s = fluid.layers.slice(
            rcnn_loc_delta_r, axes=[1], starts=[1], ends=[2])
        refined_bbox = fluid.layers.box_coder(
            prior_box=proposals,
            prior_box_var=self.cascade_bbox_reg_weights[curr_stage],
            target_box=rcnn_loc_delta_s,
            code_type='decode_center_size',
            box_normalized=False,
            axis=1, )
        refined_bbox = fluid.layers.reshape(refined_bbox, shape=[-1, 4])

        return refined_bbox

    def _inputs_def(self, image_shape):
        """Return the feed-variable definitions (shape/dtype/lod) keyed by name."""
        im_shape = [None] + image_shape
        # yapf: disable
        inputs_def = {
            'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
            'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
            'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
            'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
            'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
            'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
            'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
            'gt_mask': {'shape': [None, 2], 'dtype': 'float32', 'lod_level': 3},  # polygon coordinates
            'semantic': {'shape': [None, 1, None, None], 'dtype': 'int32', 'lod_level': 0},
        }
        # yapf: enable
        return inputs_def

    def build_inputs(self,
                     image_shape=[3, None, None],
                     fields=[
                         'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class',
                         'is_crowd', 'gt_mask', 'semantic'
                     ],
                     multi_scale=False,
                     num_scales=-1,
                     use_flip=None,
                     use_dataloader=True,
                     iterable=False,
                     mask_branch=False):
        """Create the feed variables (and optionally a DataLoader) for this model.

        When `mask_branch` is True, adds 'bbox' (and 'bbox_flip') feeds and
        disables the DataLoader, since boxes are then fed externally.
        """
        inputs_def = self._inputs_def(image_shape)
        # deep-copy so the shared default list is never mutated
        fields = copy.deepcopy(fields)
        if multi_scale:
            ms_def, ms_fields = multiscale_def(image_shape, num_scales,
                                               use_flip)
            inputs_def.update(ms_def)
            fields += ms_fields
            # consumed by multi-scale evaluation
            self.im_info_names = ['image', 'im_info'] + ms_fields
            if mask_branch:
                box_fields = ['bbox', 'bbox_flip'] if use_flip else ['bbox']
                for key in box_fields:
                    inputs_def[key] = {
                        'shape': [6],
                        'dtype': 'float32',
                        'lod_level': 1
                    }
                fields += box_fields
        feed_vars = OrderedDict([(key, fluid.data(
            name=key,
            shape=inputs_def[key]['shape'],
            dtype=inputs_def[key]['dtype'],
            lod_level=inputs_def[key]['lod_level'])) for key in fields])
        use_dataloader = use_dataloader and not mask_branch
        loader = fluid.io.DataLoader.from_generator(
            feed_list=list(feed_vars.values()),
            capacity=64,
            use_double_buffer=True,
            iterable=iterable) if use_dataloader else None
        return feed_vars, loader

    def train(self, feed_vars):
        """Build the training graph (losses)."""
        return self.build(feed_vars, 'train')

    def eval(self, feed_vars, multi_scale=None, mask_branch=False):
        """Build the evaluation graph.

        NOTE(review): the multi_scale path calls `self.build_multi_scale`,
        which is not defined on this class in the visible code — confirm it
        is provided elsewhere or that multi_scale eval is unused for HTC.
        """
        if multi_scale:
            return self.build_multi_scale(feed_vars, mask_branch)
        return self.build(feed_vars, 'test')

    def test(self, feed_vars, exclude_nms=False):
        """Build the inference graph; `exclude_nms` is not supported here."""
        assert not exclude_nms, "exclude_nms for {} is not support currently".format(
            self.__class__.__name__)
        return self.build(feed_vars, 'test')
def multiscale_def(image_shape, num_scale, use_flip=True):
    """Create feed-variable definitions for multi-scale test augmentation.

    Args:
        image_shape (list): per-image shape, e.g. [3, None, None].
        num_scale (int): total number of augmented inputs; when `use_flip`
            is True, half of them are flipped counterparts.
        use_flip (bool): also define a flipped copy of the base image.

    Returns:
        (dict, list): definitions (shape/dtype/lod_level) keyed by feed name,
        and the ordered list of the names that were defined.
    """

    def _image_def():
        # fresh dict per feed var so definitions never share state
        return {
            'shape': [None] + image_shape,
            'dtype': 'float32',
            'lod_level': 0
        }

    def _info_def():
        return {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}

    defs = {}
    names = []
    bases = ['image']
    scales = num_scale
    if use_flip:
        # half of the requested inputs are flips of the other half
        scales //= 2
        bases.append('image_flip')
        defs['image_flip'] = _image_def()
        defs['im_info_image_flip'] = _info_def()
        names += ['image_flip', 'im_info_image_flip']
    for base in bases:
        # one extra (image, im_info) pair per additional scale
        for idx in range(scales - 1):
            scale_name = '{}_scale_{}'.format(base, idx)
            info_name = 'im_info_' + scale_name
            defs[scale_name] = _image_def()
            defs[info_name] = _info_def()
            names += [scale_name, info_name]
    return defs, names
100755 index 000000000..0277f64c6 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/mask_rcnn.py @@ -0,0 +1,343 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import copy + +import paddle.fluid as fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register + +from .input_helper import multiscale_def + +__all__ = ['MaskRCNN'] + + +@register +class MaskRCNN(object): + """ + Mask R-CNN architecture, see https://arxiv.org/abs/1703.06870 + Args: + backbone (object): backbone instance + rpn_head (object): `RPNhead` instance + bbox_assigner (object): `BBoxAssigner` instance + roi_extractor (object): ROI extractor instance + bbox_head (object): `BBoxHead` instance + mask_assigner (object): `MaskAssigner` instance + mask_head (object): `MaskHead` instance + fpn (object): feature pyramid network instance + """ + + __category__ = 'architecture' + __inject__ = [ + 'backbone', 'rpn_head', 'bbox_assigner', 'roi_extractor', 'bbox_head', + 'mask_assigner', 'mask_head', 'fpn' + ] + + def __init__(self, + backbone, + rpn_head, + bbox_head='BBoxHead', + bbox_assigner='BBoxAssigner', + roi_extractor='RoIAlign', + mask_assigner='MaskAssigner', + mask_head='MaskHead', + rpn_only=False, + 
fpn=None): + super(MaskRCNN, self).__init__() + self.backbone = backbone + self.rpn_head = rpn_head + self.bbox_assigner = bbox_assigner + self.roi_extractor = roi_extractor + self.bbox_head = bbox_head + self.mask_assigner = mask_assigner + self.mask_head = mask_head + self.rpn_only = rpn_only + self.fpn = fpn + + def build(self, feed_vars, mode='train'): + if mode == 'train': + required_fields = [ + 'gt_class', 'gt_bbox', 'gt_mask', 'is_crowd', 'im_info' + ] + else: + required_fields = ['im_shape', 'im_info'] + self._input_check(required_fields, feed_vars) + im = feed_vars['image'] + im_info = feed_vars['im_info'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + # backbone + body_feats = self.backbone(im) + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32')) + for k, v in body_feats.items()) + + # FPN + spatial_scale = None + if self.fpn is not None: + body_feats, spatial_scale = self.fpn.get_output(body_feats) + + # RPN proposals + rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode) + + if mode == 'train': + rpn_loss = self.rpn_head.get_loss(im_info, feed_vars['gt_bbox'], + feed_vars['is_crowd']) + + outs = self.bbox_assigner( + rpn_rois=rois, + gt_classes=feed_vars['gt_class'], + is_crowd=feed_vars['is_crowd'], + gt_boxes=feed_vars['gt_bbox'], + im_info=feed_vars['im_info']) + rois = outs[0] + labels_int32 = outs[1] + + if self.fpn is None: + last_feat = body_feats[list(body_feats.keys())[-1]] + roi_feat = self.roi_extractor(last_feat, rois) + else: + roi_feat = self.roi_extractor(body_feats, rois, spatial_scale) + + loss = self.bbox_head.get_loss(roi_feat, labels_int32, *outs[2:]) + loss.update(rpn_loss) + + mask_rois, roi_has_mask_int32, mask_int32 = self.mask_assigner( + rois=rois, + gt_classes=feed_vars['gt_class'], + is_crowd=feed_vars['is_crowd'], + 
gt_segms=feed_vars['gt_mask'], + im_info=feed_vars['im_info'], + labels_int32=labels_int32) + if self.fpn is None: + bbox_head_feat = self.bbox_head.get_head_feat() + feat = fluid.layers.gather(bbox_head_feat, roi_has_mask_int32) + else: + feat = self.roi_extractor( + body_feats, mask_rois, spatial_scale, is_mask=True) + + mask_loss = self.mask_head.get_loss(feat, mask_int32) + loss.update(mask_loss) + + total_loss = fluid.layers.sum(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + + else: + if self.rpn_only: + im_scale = fluid.layers.slice( + im_info, [1], starts=[2], ends=[3]) + im_scale = fluid.layers.sequence_expand(im_scale, rois) + rois = rois / im_scale + return {'proposal': rois} + mask_name = 'mask_pred' + mask_pred, bbox_pred = self.single_scale_eval( + body_feats, mask_name, rois, im_info, feed_vars['im_shape'], + spatial_scale) + return {'bbox': bbox_pred, 'mask': mask_pred} + + def build_multi_scale(self, feed_vars, mask_branch=False): + required_fields = ['image', 'im_info'] + self._input_check(required_fields, feed_vars) + + result = {} + if not mask_branch: + assert 'im_shape' in feed_vars, \ + "{} has no im_shape field".format(feed_vars) + result.update(feed_vars) + + for i in range(len(self.im_info_names) // 2): + im = feed_vars[self.im_info_names[2 * i]] + im_info = feed_vars[self.im_info_names[2 * i + 1]] + body_feats = self.backbone(im) + + # FPN + if self.fpn is not None: + body_feats, spatial_scale = self.fpn.get_output(body_feats) + rois = self.rpn_head.get_proposals(body_feats, im_info, mode='test') + if not mask_branch: + im_shape = feed_vars['im_shape'] + body_feat_names = list(body_feats.keys()) + if self.fpn is None: + body_feat = body_feats[body_feat_names[-1]] + roi_feat = self.roi_extractor(body_feat, rois) + else: + roi_feat = self.roi_extractor(body_feats, rois, + spatial_scale) + pred = self.bbox_head.get_prediction( + roi_feat, rois, im_info, im_shape, return_box_score=True) + bbox_name = 'bbox_' + str(i) 
+ score_name = 'score_' + str(i) + if 'flip' in im.name: + bbox_name += '_flip' + score_name += '_flip' + result[bbox_name] = pred['bbox'] + result[score_name] = pred['score'] + else: + mask_name = 'mask_pred_' + str(i) + bbox_pred = feed_vars['bbox'] + #result.update({im.name: im}) + if 'flip' in im.name: + mask_name += '_flip' + bbox_pred = feed_vars['bbox_flip'] + mask_pred, bbox_pred = self.single_scale_eval( + body_feats, mask_name, rois, im_info, feed_vars['im_shape'], + spatial_scale, bbox_pred) + result[mask_name] = mask_pred + return result + + def single_scale_eval(self, + body_feats, + mask_name, + rois, + im_info, + im_shape, + spatial_scale, + bbox_pred=None): + if not bbox_pred: + if self.fpn is None: + last_feat = body_feats[list(body_feats.keys())[-1]] + roi_feat = self.roi_extractor(last_feat, rois) + else: + roi_feat = self.roi_extractor(body_feats, rois, spatial_scale) + bbox_pred = self.bbox_head.get_prediction(roi_feat, rois, im_info, + im_shape) + bbox_pred = bbox_pred['bbox'] + + # share weight + bbox_shape = fluid.layers.shape(bbox_pred) + bbox_size = fluid.layers.reduce_prod(bbox_shape) + bbox_size = fluid.layers.reshape(bbox_size, [1, 1]) + size = fluid.layers.fill_constant([1, 1], value=6, dtype='int32') + cond = fluid.layers.less_than(x=bbox_size, y=size) + + mask_pred = fluid.layers.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=False, + name=mask_name) + + def noop(): + fluid.layers.assign(input=bbox_pred, output=mask_pred) + + def process_boxes(): + bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6]) + + im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3]) + im_scale = fluid.layers.sequence_expand(im_scale, bbox) + + mask_rois = bbox * im_scale + if self.fpn is None: + last_feat = body_feats[list(body_feats.keys())[-1]] + mask_feat = self.roi_extractor(last_feat, mask_rois) + mask_feat = self.bbox_head.get_head_feat(mask_feat) + else: + mask_feat = self.roi_extractor( + 
body_feats, mask_rois, spatial_scale, is_mask=True) + + mask_out = self.mask_head.get_prediction(mask_feat, bbox) + fluid.layers.assign(input=mask_out, output=mask_pred) + + fluid.layers.cond(cond, noop, process_boxes) + return mask_pred, bbox_pred + + def _input_check(self, require_fields, feed_vars): + for var in require_fields: + assert var in feed_vars, \ + "{} has no {} field".format(feed_vars, var) + + def _inputs_def(self, image_shape): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'gt_mask': {'shape': [None, 2], 'dtype': 'float32', 'lod_level': 3}, # polygon coordinates + 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + } + # yapf: enable + return inputs_def + + def build_inputs(self, + image_shape=[3, None, None], + fields=[ + 'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', + 'is_crowd', 'gt_mask' + ], + multi_scale=False, + num_scales=-1, + use_flip=None, + use_dataloader=True, + iterable=False, + mask_branch=False): + inputs_def = self._inputs_def(image_shape) + fields = copy.deepcopy(fields) + if multi_scale: + ms_def, ms_fields = multiscale_def(image_shape, num_scales, + use_flip) + inputs_def.update(ms_def) + fields += ms_fields + self.im_info_names = ['image', 'im_info'] + ms_fields + if mask_branch: + box_fields = ['bbox', 'bbox_flip'] if use_flip else ['bbox'] + for key in box_fields: + inputs_def[key] = { + 'shape': [None, 6], + 'dtype': 'float32', + 'lod_level': 1 + } + fields += box_fields + feed_vars = 
OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + use_dataloader = use_dataloader and not mask_branch + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, 'train') + + def eval(self, feed_vars, multi_scale=None, mask_branch=False): + if multi_scale: + return self.build_multi_scale(feed_vars, mask_branch) + return self.build(feed_vars, 'test') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, 'test') diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/retinanet.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/retinanet.py new file mode 100755 index 000000000..e89bdeae1 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/retinanet.py @@ -0,0 +1,131 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +import paddle.fluid as fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register + +__all__ = ['RetinaNet'] + + +@register +class RetinaNet(object): + """ + RetinaNet architecture, see https://arxiv.org/abs/1708.02002 + + Args: + backbone (object): backbone instance + fpn (object): feature pyramid network instance + retina_head (object): `RetinaHead` instance + """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'fpn', 'retina_head'] + + def __init__(self, backbone, fpn, retina_head): + super(RetinaNet, self).__init__() + self.backbone = backbone + self.fpn = fpn + self.retina_head = retina_head + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + im_info = feed_vars['im_info'] + if mode == 'train': + gt_bbox = feed_vars['gt_bbox'] + gt_class = feed_vars['gt_class'] + is_crowd = feed_vars['is_crowd'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + # backbone + body_feats = self.backbone(im) + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32')) + for k, v in body_feats.items()) + + # FPN + body_feats, spatial_scale = self.fpn.get_output(body_feats) + + # retinanet head + if mode == 'train': + loss = self.retina_head.get_loss(body_feats, spatial_scale, im_info, + gt_bbox, gt_class, is_crowd) + total_loss = fluid.layers.sum(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + else: + pred = self.retina_head.get_prediction(body_feats, spatial_scale, + im_info) + return pred + + def _inputs_def(self, image_shape): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 
'dtype': 'float32', 'lod_level': 0}, + 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + } + # yapf: enable + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=[ + 'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd' + ], # for-train + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape) + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, 'train') + + def eval(self, feed_vars): + return self.build(feed_vars, 'test') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, 'test') diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/solov2.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/solov2.py new file mode 100755 index 000000000..6c590e4a9 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/solov2.py @@ -0,0 +1,175 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register +from ppdet.utils.check import check_version + +__all__ = ['SOLOv2'] + + +@register +class SOLOv2(object): + """ + SOLOv2 network, see https://arxiv.org/abs/2003.10152 + + Args: + backbone (object): an backbone instance + fpn (object): feature pyramid network instance + bbox_head (object): an `SOLOv2Head` instance + mask_head (object): an `SOLOv2MaskHead` instance + """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'fpn', 'bbox_head', 'mask_head'] + + def __init__(self, + backbone, + fpn=None, + bbox_head='SOLOv2Head', + mask_head='SOLOv2MaskHead'): + super(SOLOv2, self).__init__() + check_version('2.0.0-rc0') + self.backbone = backbone + self.fpn = fpn + self.bbox_head = bbox_head + self.mask_head = mask_head + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + body_feats = self.backbone(im) + + if self.fpn is not None: + body_feats, spatial_scale = self.fpn.get_output(body_feats) + + if isinstance(body_feats, OrderedDict): + 
body_feat_names = list(body_feats.keys()) + body_feats = [body_feats[name] for name in body_feat_names] + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats] + + mask_feat_pred = self.mask_head.get_output(body_feats) + + if mode == 'train': + ins_labels = [] + cate_labels = [] + grid_orders = [] + fg_num = feed_vars['fg_num'] + + for i in range(self.num_level): + ins_label = 'ins_label{}'.format(i) + if ins_label in feed_vars: + ins_labels.append(feed_vars[ins_label]) + cate_label = 'cate_label{}'.format(i) + if cate_label in feed_vars: + cate_labels.append(feed_vars[cate_label]) + grid_order = 'grid_order{}'.format(i) + if grid_order in feed_vars: + grid_orders.append(feed_vars[grid_order]) + + cate_preds, kernel_preds = self.bbox_head.get_outputs(body_feats) + + losses = self.bbox_head.get_loss(cate_preds, kernel_preds, + mask_feat_pred, ins_labels, + cate_labels, grid_orders, fg_num) + total_loss = fluid.layers.sum(list(losses.values())) + losses.update({'loss': total_loss}) + return losses + else: + im_info = feed_vars['im_info'] + outs = self.bbox_head.get_outputs(body_feats, is_eval=True) + seg_inputs = outs + (mask_feat_pred, im_info) + return self.bbox_head.get_prediction(*seg_inputs) + + def _inputs_def(self, image_shape, fields): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0}, + } + + if 'gt_segm' in fields: + for i in range(self.num_level): + targets_def = { + 'ins_label%d' % i: {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1}, + 'cate_label%d' % i: {'shape': [None], 'dtype': 'int32', 'lod_level': 1}, + 'grid_order%d' % i: {'shape': [None], 'dtype': 'int32', 
'lod_level': 1}, + } + inputs_def.update(targets_def) + targets_def = { + 'fg_num': {'shape': [None], 'dtype': 'int32', 'lod_level': 0}, + } + # yapf: enable + inputs_def.update(targets_def) + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=['image', 'im_id', 'gt_segm'], # for train + num_level=5, + use_dataloader=True, + iterable=False): + self.num_level = num_level + inputs_def = self._inputs_def(image_shape, fields) + if 'gt_segm' in fields: + fields.remove('gt_segm') + fields.extend(['fg_num']) + for i in range(num_level): + fields.extend([ + 'ins_label%d' % i, 'cate_label%d' % i, 'grid_order%d' % i + ]) + + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, mode='train') + + def eval(self, feed_vars): + return self.build(feed_vars, mode='test') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, mode='test') diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/ssd.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/ssd.py new file mode 100755 index 000000000..8f749bbfe --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/ssd.py @@ -0,0 +1,145 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +import paddle.fluid as fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register +from ppdet.modeling.ops import SSDOutputDecoder + +__all__ = ['SSD'] + + +@register +class SSD(object): + """ + Single Shot MultiBox Detector, see https://arxiv.org/abs/1512.02325 + + Args: + backbone (object): backbone instance + multi_box_head (object): `MultiBoxHead` instance + output_decoder (object): `SSDOutputDecoder` instance + num_classes (int): number of output classes + """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'multi_box_head', 'output_decoder', 'fpn'] + __shared__ = ['num_classes'] + + def __init__(self, + backbone, + fpn=None, + multi_box_head='MultiBoxHead', + output_decoder=SSDOutputDecoder().__dict__, + num_classes=21): + super(SSD, self).__init__() + self.backbone = backbone + self.fpn = fpn + self.multi_box_head = multi_box_head + self.num_classes = num_classes + self.output_decoder = output_decoder + if isinstance(output_decoder, dict): + self.output_decoder = SSDOutputDecoder(**output_decoder) + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + if mode == 'train' or mode == 'eval': + gt_bbox = feed_vars['gt_bbox'] + gt_class = feed_vars['gt_class'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 
'float16') + + # backbone + body_feats = self.backbone(im) + + if self.fpn is not None: + body_feats, spatial_scale = self.fpn.get_output(body_feats) + + if isinstance(body_feats, OrderedDict): + body_feat_names = list(body_feats.keys()) + body_feats = [body_feats[name] for name in body_feat_names] + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats] + + locs, confs, box, box_var = self.multi_box_head( + inputs=body_feats, image=im, num_classes=self.num_classes) + + if mode == 'train': + loss = fluid.layers.ssd_loss(locs, confs, gt_bbox, gt_class, box, + box_var) + loss = fluid.layers.reduce_sum(loss) + return {'loss': loss} + else: + pred = self.output_decoder(locs, confs, box, box_var) + return {'bbox': pred} + + def _inputs_def(self, image_shape): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'im_shape': {'shape': [None, 3], 'dtype': 'int32', 'lod_level': 0}, + 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + } + # yapf: enable + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=['image', 'im_id', 'gt_bbox', 'gt_class'], # for train + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape) + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, 
feed_vars): + return self.build(feed_vars, 'train') + + def eval(self, feed_vars): + return self.build(feed_vars, 'eval') + + def test(self, feed_vars, exclude_nms=False): + assert not exclude_nms, "exclude_nms for {} is not support currently".format( + self.__class__.__name__) + return self.build(feed_vars, 'test') + + def is_bbox_normalized(self): + # SSD use output_decoder in output layers, bbox is normalized + # to range [0, 1], is_bbox_normalized is used in eval.py and infer.py + return True diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/ttfnet.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/ttfnet.py new file mode 100755 index 000000000..39df8cf3d --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/ttfnet.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register + +__all__ = ['TTFNet'] + + +@register +class TTFNet(object): + """ + TTFNet network, see https://arxiv.org/abs/1909.00700 + + Args: + backbone (object): backbone instance + ttf_head (object): `TTFHead` instance + num_classes (int): the number of classes, 80 by default. 
+ """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'ttf_head'] + __shared__ = ['num_classes'] + + def __init__(self, backbone, ttf_head='TTFHead', num_classes=80): + super(TTFNet, self).__init__() + self.backbone = backbone + self.ttf_head = ttf_head + self.num_classes = num_classes + + def build(self, feed_vars, mode='train', exclude_nms=False): + im = feed_vars['image'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + body_feats = self.backbone(im) + + if isinstance(body_feats, OrderedDict): + body_feat_names = list(body_feats.keys()) + body_feats = [body_feats[name] for name in body_feat_names] + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats] + + predict_hm, predict_wh = self.ttf_head.get_output( + body_feats, 'ttf_head', is_test=mode == 'test') + if mode == 'train': + heatmap = feed_vars['ttf_heatmap'] + box_target = feed_vars['ttf_box_target'] + reg_weight = feed_vars['ttf_reg_weight'] + loss = self.ttf_head.get_loss(predict_hm, predict_wh, heatmap, + box_target, reg_weight) + total_loss = fluid.layers.sum(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + else: + results = self.ttf_head.get_bboxes(predict_hm, predict_wh, + feed_vars['scale_factor']) + return results + + def _inputs_def(self, image_shape, downsample): + im_shape = [None] + image_shape + H, W = im_shape[2:] + target_h = None if H is None else H // downsample + target_w = None if W is None else W // downsample + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'scale_factor': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'ttf_heatmap': {'shape': [None, self.num_classes, target_h, target_w], 'dtype': 'float32', 
'lod_level': 0}, + 'ttf_box_target': {'shape': [None, 4, target_h, target_w], 'dtype': 'float32', 'lod_level': 0}, + 'ttf_reg_weight': {'shape': [None, 1, target_h, target_w], 'dtype': 'float32', 'lod_level': 0}, + } + # yapf: enable + + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=[ + 'image', 'ttf_heatmap', 'ttf_box_target', 'ttf_reg_weight' + ], # for train + use_dataloader=True, + iterable=False, + downsample=4): + inputs_def = self._inputs_def(image_shape, downsample) + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, mode='train') + + def eval(self, feed_vars): + return self.build(feed_vars, mode='test') + + def test(self, feed_vars, exclude_nms=False): + return self.build(feed_vars, mode='test', exclude_nms=exclude_nms) diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/yolo.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/yolo.py new file mode 100755 index 000000000..80ac34e42 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/architectures/yolo.py @@ -0,0 +1,189 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register + +__all__ = ['YOLOv3', 'YOLOv4'] + + +@register +class YOLOv3(object): + """ + YOLOv3 network, see https://arxiv.org/abs/1804.02767 + + Args: + backbone (object): an backbone instance + yolo_head (object): an `YOLOv3Head` instance + """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'yolo_head'] + __shared__ = ['use_fine_grained_loss'] + + def __init__(self, + backbone, + yolo_head='YOLOv3Head', + use_fine_grained_loss=False): + super(YOLOv3, self).__init__() + self.backbone = backbone + self.yolo_head = yolo_head + self.use_fine_grained_loss = use_fine_grained_loss + + def build(self, feed_vars, mode='train', exclude_nms=False): + im = feed_vars['image'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + body_feats = self.backbone(im) + + if isinstance(body_feats, OrderedDict): + body_feat_names = list(body_feats.keys()) + body_feats = [body_feats[name] for name in body_feat_names] + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats] + + if mode == 'train': + gt_bbox = feed_vars['gt_bbox'] + gt_class = feed_vars['gt_class'] + gt_score = feed_vars['gt_score'] + + 
# Get targets for splited yolo loss calculation + num_output_layer = len(self.yolo_head.anchor_masks) + targets = [] + for i in range(num_output_layer): + k = 'target{}'.format(i) + if k in feed_vars: + targets.append(feed_vars[k]) + + loss = self.yolo_head.get_loss(body_feats, gt_bbox, gt_class, + gt_score, targets) + total_loss = fluid.layers.sum(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + else: + im_size = feed_vars['im_size'] + # exclude_nms only for benchmark, postprocess(NMS) is not needed + return self.yolo_head.get_prediction( + body_feats, im_size, exclude_nms=exclude_nms) + + def _inputs_def(self, image_shape, num_max_boxes): + im_shape = [None] + image_shape + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_size': {'shape': [None, 2], 'dtype': 'int32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, num_max_boxes, 4], 'dtype': 'float32', 'lod_level': 0}, + 'gt_class': {'shape': [None, num_max_boxes], 'dtype': 'int32', 'lod_level': 0}, + 'gt_score': {'shape': [None, num_max_boxes], 'dtype': 'float32', 'lod_level': 0}, + 'is_difficult': {'shape': [None, num_max_boxes],'dtype': 'int32', 'lod_level': 0}, + } + # yapf: enable + + if self.use_fine_grained_loss: + # yapf: disable + num_output_layer = len(self.yolo_head.anchor_masks) + targets_def = {} + for i in range(num_output_layer): + targets_def['target{}'.format(i)] = {'shape': [None, 3, None, None, None], 'dtype': 'float32', 'lod_level': 0} + # yapf: enable + + downsample = 32 + for k, mask in zip(targets_def.keys(), self.yolo_head.anchor_masks): + targets_def[k]['shape'][1] = len(mask) + targets_def[k]['shape'][2] = 6 + self.yolo_head.num_classes + targets_def[k]['shape'][3] = image_shape[ + -2] // downsample if image_shape[-2] else None + targets_def[k]['shape'][4] = image_shape[ + -1] // downsample if image_shape[-1] else None + downsample //= 2 + + 
inputs_def.update(targets_def) + + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=['image', 'gt_bbox', 'gt_class', 'gt_score'], # for train + num_max_boxes=50, + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape, num_max_boxes) + # if fields has im_size, this is in eval/infer mode, fine grained loss + # will be disabled for YOLOv3 architecture do not calculate loss in + # eval/infer mode. + if 'im_size' not in fields and self.use_fine_grained_loss: + num_output_layer = len(self.yolo_head.anchor_masks) + fields.extend( + ['target{}'.format(i) for i in range(num_output_layer)]) + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, mode='train') + + def eval(self, feed_vars): + return self.build(feed_vars, mode='test') + + def test(self, feed_vars, exclude_nms=False): + return self.build(feed_vars, mode='test', exclude_nms=exclude_nms) + + +@register +class YOLOv4(YOLOv3): + """ + YOLOv4 network, see https://arxiv.org/abs/2004.10934 + + Args: + backbone (object): an backbone instance + yolo_head (object): an `YOLOv4Head` instance + """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'yolo_head'] + __shared__ = ['use_fine_grained_loss'] + + def __init__(self, + backbone, + yolo_head='YOLOv4Head', + use_fine_grained_loss=False): + super(YOLOv4, self).__init__( + backbone=backbone, + yolo_head=yolo_head, + use_fine_grained_loss=use_fine_grained_loss) diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/__init__.py 
b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/__init__.py new file mode 100755 index 000000000..a6d2eb18f --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/__init__.py @@ -0,0 +1,59 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from . import resnet +from . import resnext +from . import darknet +from . import mobilenet +from . import mobilenet_v3 +from . import senet +from . import fpn +from . import vgg +from . import blazenet +from . import faceboxnet +from . import cb_resnet +from . import res2net +from . import hrnet +from . import hrfpn +from . import bfp +from . import hourglass +from . import efficientnet +from . import bifpn +from . import cspdarknet +from . import acfpn +from . 
import ghostnet + +from .resnet import * +from .resnext import * +from .darknet import * +from .mobilenet import * +from .mobilenet_v3 import * +from .senet import * +from .fpn import * +from .vgg import * +from .blazenet import * +from .faceboxnet import * +from .cb_resnet import * +from .res2net import * +from .hrnet import * +from .hrfpn import * +from .bfp import * +from .hourglass import * +from .efficientnet import * +from .bifpn import * +from .cspdarknet import * +from .acfpn import * +from .ghostnet import * diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/acfpn.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/acfpn.py new file mode 100755 index 000000000..852586b2f --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/acfpn.py @@ -0,0 +1,338 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import copy +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Xavier +from paddle.fluid.regularizer import L2Decay + +from ppdet.core.workspace import register +from ppdet.modeling.ops import ConvNorm + +__all__ = ['ACFPN'] + + +@register +class ACFPN(object): + """ + Attention-guided Context Feature Pyramid Network for Object Detection, + see https://arxiv.org/abs/2005.11475 + + Args: + num_chan (int): number of feature channels + min_level (int): lowest level of the backbone feature map to use + max_level (int): highest level of the backbone feature map to use + spatial_scale (list): feature map scaling factor + has_extra_convs (bool): whether has extral convolutions in higher levels + norm_type (str|None): normalization type, 'bn'/'sync_bn'/'affine_channel' + use_c5 (bool): whether to use C5 as the feature map. + norm_groups (int): group number of group norm. + """ + __shared__ = ['norm_type', 'freeze_norm'] + + def __init__(self, + num_chan=256, + min_level=2, + max_level=6, + spatial_scale=[1. / 32., 1. / 16., 1. / 8., 1. 
/ 4.], + has_extra_convs=False, + norm_type=None, + freeze_norm=False, + use_c5=True, + norm_groups=32): + self.freeze_norm = freeze_norm + self.num_chan = num_chan + self.min_level = min_level + self.max_level = max_level + self.spatial_scale = spatial_scale + self.has_extra_convs = has_extra_convs + self.norm_type = norm_type + self.use_c5 = use_c5 + self.norm_groups = norm_groups + + def _add_topdown_lateral(self, body_name, body_input, upper_output): + lateral_name = 'fpn_inner_' + body_name + '_lateral' + topdown_name = 'fpn_topdown_' + body_name + fan = body_input.shape[1] + if self.norm_type: + initializer = Xavier(fan_out=fan) + lateral = ConvNorm( + body_input, + self.num_chan, + 1, + initializer=initializer, + norm_type=self.norm_type, + freeze_norm=self.freeze_norm, + name=lateral_name, + norm_name=lateral_name) + else: + lateral = fluid.layers.conv2d( + body_input, + self.num_chan, + 1, + param_attr=ParamAttr( + name=lateral_name + "_w", initializer=Xavier(fan_out=fan)), + bias_attr=ParamAttr( + name=lateral_name + "_b", + learning_rate=2., + regularizer=L2Decay(0.)), + name=lateral_name) + topdown = fluid.layers.resize_nearest( + upper_output, scale=2., name=topdown_name) + + return lateral + topdown + + def dense_aspp_block(self, input, num_filters1, num_filters2, dilation_rate, + dropout_prob, name): + + conv = ConvNorm( + input, + num_filters=num_filters1, + filter_size=1, + stride=1, + groups=1, + norm_decay=0., + norm_type='gn', + norm_groups=self.norm_groups, + dilation=dilation_rate, + lr_scale=1, + freeze_norm=False, + act="relu", + norm_name=name + "_gn", + initializer=None, + bias_attr=False, + name=name + "_gn") + + conv = fluid.layers.conv2d( + conv, + num_filters2, + filter_size=3, + padding=dilation_rate, + dilation=dilation_rate, + act="relu", + param_attr=ParamAttr(name=name + "_conv_w"), + bias_attr=ParamAttr(name=name + "_conv_b"), ) + + if dropout_prob > 0: + conv = fluid.layers.dropout(conv, dropout_prob=dropout_prob) + + return 
conv + + def dense_aspp(self, input, name=None): + dropout0 = 0.1 + d_feature0 = 512 + d_feature1 = 256 + + aspp3 = self.dense_aspp_block( + input, + num_filters1=d_feature0, + num_filters2=d_feature1, + dropout_prob=dropout0, + name=name + '_aspp3', + dilation_rate=3) + conv = fluid.layers.concat([aspp3, input], axis=1) + + aspp6 = self.dense_aspp_block( + conv, + num_filters1=d_feature0, + num_filters2=d_feature1, + dropout_prob=dropout0, + name=name + '_aspp6', + dilation_rate=6) + conv = fluid.layers.concat([aspp6, conv], axis=1) + + aspp12 = self.dense_aspp_block( + conv, + num_filters1=d_feature0, + num_filters2=d_feature1, + dropout_prob=dropout0, + name=name + '_aspp12', + dilation_rate=12) + conv = fluid.layers.concat([aspp12, conv], axis=1) + + aspp18 = self.dense_aspp_block( + conv, + num_filters1=d_feature0, + num_filters2=d_feature1, + dropout_prob=dropout0, + name=name + '_aspp18', + dilation_rate=18) + conv = fluid.layers.concat([aspp18, conv], axis=1) + + aspp24 = self.dense_aspp_block( + conv, + num_filters1=d_feature0, + num_filters2=d_feature1, + dropout_prob=dropout0, + name=name + '_aspp24', + dilation_rate=24) + + conv = fluid.layers.concat( + [aspp3, aspp6, aspp12, aspp18, aspp24], axis=1) + + conv = ConvNorm( + conv, + num_filters=self.num_chan, + filter_size=1, + stride=1, + groups=1, + norm_decay=0., + norm_type='gn', + norm_groups=self.norm_groups, + dilation=1, + lr_scale=1, + freeze_norm=False, + act="relu", + norm_name=name + "_dense_aspp_reduce_gn", + initializer=None, + bias_attr=False, + name=name + "_dense_aspp_reduce_gn") + + return conv + + def get_output(self, body_dict): + """ + Add FPN onto backbone. + + Args: + body_dict(OrderedDict): Dictionary of variables and each element is the + output of backbone. + + Return: + fpn_dict(OrderedDict): A dictionary represents the output of FPN with + their name. + spatial_scale(list): A list of multiplicative spatial scale factor. 
+ """ + spatial_scale = copy.deepcopy(self.spatial_scale) + body_name_list = list(body_dict.keys())[::-1] + num_backbone_stages = len(body_name_list) + self.fpn_inner_output = [[] for _ in range(num_backbone_stages)] + fpn_inner_name = 'fpn_inner_' + body_name_list[0] + body_input = body_dict[body_name_list[0]] + fan = body_input.shape[1] + if self.norm_type: + initializer = Xavier(fan_out=fan) + self.fpn_inner_output[0] = ConvNorm( + body_input, + self.num_chan, + 1, + initializer=initializer, + norm_type=self.norm_type, + freeze_norm=self.freeze_norm, + name=fpn_inner_name, + norm_name=fpn_inner_name) + else: + self.fpn_inner_output[0] = fluid.layers.conv2d( + body_input, + self.num_chan, + 1, + param_attr=ParamAttr( + name=fpn_inner_name + "_w", + initializer=Xavier(fan_out=fan)), + bias_attr=ParamAttr( + name=fpn_inner_name + "_b", + learning_rate=2., + regularizer=L2Decay(0.)), + name=fpn_inner_name) + + self.fpn_inner_output[0] += self.dense_aspp( + self.fpn_inner_output[0], name="acfpn") + + for i in range(1, num_backbone_stages): + body_name = body_name_list[i] + body_input = body_dict[body_name] + top_output = self.fpn_inner_output[i - 1] + fpn_inner_single = self._add_topdown_lateral(body_name, body_input, + top_output) + self.fpn_inner_output[i] = fpn_inner_single + fpn_dict = {} + fpn_name_list = [] + for i in range(num_backbone_stages): + fpn_name = 'fpn_' + body_name_list[i] + fan = self.fpn_inner_output[i].shape[1] * 3 * 3 + if self.norm_type: + initializer = Xavier(fan_out=fan) + fpn_output = ConvNorm( + self.fpn_inner_output[i], + self.num_chan, + 3, + initializer=initializer, + norm_type=self.norm_type, + freeze_norm=self.freeze_norm, + name=fpn_name, + norm_name=fpn_name) + else: + fpn_output = fluid.layers.conv2d( + self.fpn_inner_output[i], + self.num_chan, + filter_size=3, + padding=1, + param_attr=ParamAttr( + name=fpn_name + "_w", initializer=Xavier(fan_out=fan)), + bias_attr=ParamAttr( + name=fpn_name + "_b", + learning_rate=2., + 
regularizer=L2Decay(0.)), + name=fpn_name) + fpn_dict[fpn_name] = fpn_output + fpn_name_list.append(fpn_name) + if not self.has_extra_convs and self.max_level - self.min_level == len( + spatial_scale): + body_top_name = fpn_name_list[0] + body_top_extension = fluid.layers.pool2d( + fpn_dict[body_top_name], + 1, + 'max', + pool_stride=2, + name=body_top_name + '_subsampled_2x') + fpn_dict[body_top_name + '_subsampled_2x'] = body_top_extension + fpn_name_list.insert(0, body_top_name + '_subsampled_2x') + spatial_scale.insert(0, spatial_scale[0] * 0.5) + # Coarser FPN levels introduced for RetinaNet + highest_backbone_level = self.min_level + len(spatial_scale) - 1 + if self.has_extra_convs and self.max_level > highest_backbone_level: + if self.use_c5: + fpn_blob = body_dict[body_name_list[0]] + else: + fpn_blob = fpn_dict[fpn_name_list[0]] + for i in range(highest_backbone_level + 1, self.max_level + 1): + fpn_blob_in = fpn_blob + fpn_name = 'fpn_' + str(i) + if i > highest_backbone_level + 1: + fpn_blob_in = fluid.layers.relu(fpn_blob) + fan = fpn_blob_in.shape[1] * 3 * 3 + fpn_blob = fluid.layers.conv2d( + input=fpn_blob_in, + num_filters=self.num_chan, + filter_size=3, + stride=2, + padding=1, + param_attr=ParamAttr( + name=fpn_name + "_w", initializer=Xavier(fan_out=fan)), + bias_attr=ParamAttr( + name=fpn_name + "_b", + learning_rate=2., + regularizer=L2Decay(0.)), + name=fpn_name) + fpn_dict[fpn_name] = fpn_blob + fpn_name_list.insert(0, fpn_name) + spatial_scale.insert(0, spatial_scale[0] * 0.5) + res_dict = OrderedDict([(k, fpn_dict[k]) for k in fpn_name_list]) + return res_dict, spatial_scale diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/bfp.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/bfp.py new file mode 100755 index 000000000..6bafc68c9 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/bfp.py @@ -0,0 +1,156 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from collections import OrderedDict + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Xavier +from paddle.fluid.regularizer import L2Decay + +from ppdet.core.workspace import register + +from .nonlocal_helper import add_space_nonlocal +from .fpn import FPN + +__all__ = ['BFP'] + + +@register +class BFP(object): + """ + Libra R-CNN, see https://arxiv.org/abs/1904.02701 + Args: + base_neck (dict): basic neck before balanced feature pyramid (bfp) + refine_level (int): index of integration and refine level of bfp + refine_type (str): refine type, None, conv or nonlocal + nonlocal_reduction (float): channel reduction level if refine_type is nonlocal + with_bias (bool): whether the nonlocal module contains bias + with_scale (bool): whether to scale feature in nonlocal module or not + """ + __inject__ = ['base_neck'] + + def __init__(self, + base_neck=FPN().__dict__, + refine_level=2, + refine_type="nonlocal", + nonlocal_reduction=1, + with_bias=True, + with_scale=False): + if isinstance(base_neck, dict): + self.base_neck = FPN(**base_neck) + self.refine_level = refine_level + self.refine_type = refine_type + self.nonlocal_reduction = nonlocal_reduction + self.with_bias = with_bias + self.with_scale = with_scale + + def get_output(self, 
body_dict): + # top-down order + res_dict, spatial_scale = self.base_neck.get_output(body_dict) + res_dict = self.get_output_bfp(res_dict) + return res_dict, spatial_scale + + def get_output_bfp(self, body_dict): + body_name_list = list(body_dict.keys()) + num_backbone_stages = len(body_name_list) + + self.num_levels = len(body_dict) + + # step 1: gather multi-level features by resize and average + feats = [] + refine_level_name = body_name_list[self.refine_level] + + for i in range(self.num_levels): + curr_fpn_name = body_name_list[i] + pool_stride = 2**(i - self.refine_level) + pool_size = [ + body_dict[refine_level_name].shape[2], + body_dict[refine_level_name].shape[3] + ] + if i > self.refine_level: + gathered = fluid.layers.pool2d( + input=body_dict[curr_fpn_name], + pool_type='max', + pool_size=pool_stride, + pool_stride=pool_stride, + ceil_mode=True, ) + else: + gathered = self._resize_input_tensor( + body_dict[curr_fpn_name], body_dict[refine_level_name], + 1.0 / pool_stride) + feats.append(gathered) + + bsf = sum(feats) / len(feats) + + # step 2: refine gathered features + if self.refine_type == "conv": + bsf = fluid.layers.conv2d( + bsf, + bsf.shape[1], + filter_size=3, + padding=1, + param_attr=ParamAttr(name="bsf_w"), + bias_attr=ParamAttr(name="bsf_b"), + name="bsf") + elif self.refine_type == "nonlocal": + dim_in = bsf.shape[1] + nonlocal_name = "nonlocal_bsf" + bsf = add_space_nonlocal( + bsf, + bsf.shape[1], + bsf.shape[1], + nonlocal_name, + int(bsf.shape[1] / self.nonlocal_reduction), + with_bias=self.with_bias, + with_scale=self.with_scale) + + # step 3: scatter refined features to multi-levels by a residual path + fpn_dict = {} + fpn_name_list = [] + for i in range(self.num_levels): + curr_fpn_name = body_name_list[i] + pool_stride = 2**(self.refine_level - i) + if i >= self.refine_level: + residual = self._resize_input_tensor( + bsf, body_dict[curr_fpn_name], 1.0 / pool_stride) + else: + residual = fluid.layers.pool2d( + input=bsf, + 
pool_type='max', + pool_size=pool_stride, + pool_stride=pool_stride, + ceil_mode=True, ) + + fpn_dict[curr_fpn_name] = residual + body_dict[curr_fpn_name] + fpn_name_list.append(curr_fpn_name) + + res_dict = OrderedDict([(k, fpn_dict[k]) for k in fpn_name_list]) + return res_dict + + def _resize_input_tensor(self, body_input, ref_output, scale): + shape = fluid.layers.shape(ref_output) + shape_hw = fluid.layers.slice(shape, axes=[0], starts=[2], ends=[4]) + out_shape_ = shape_hw + out_shape = fluid.layers.cast(out_shape_, dtype='int32') + out_shape.stop_gradient = True + body_output = fluid.layers.resize_nearest( + body_input, scale=scale, out_shape=out_shape) + return body_output diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/bifpn.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/bifpn.py new file mode 100755 index 000000000..d65517cea --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/bifpn.py @@ -0,0 +1,202 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay +from paddle.fluid.initializer import Constant, Xavier + +from ppdet.core.workspace import register + +__all__ = ['BiFPN'] + + +class FusionConv(object): + def __init__(self, num_chan): + super(FusionConv, self).__init__() + self.num_chan = num_chan + + def __call__(self, inputs, name=''): + x = fluid.layers.swish(inputs) + # depthwise + x = fluid.layers.conv2d( + x, + self.num_chan, + filter_size=3, + padding='SAME', + groups=self.num_chan, + param_attr=ParamAttr( + initializer=Xavier(), name=name + '_dw_w'), + bias_attr=False) + # pointwise + x = fluid.layers.conv2d( + x, + self.num_chan, + filter_size=1, + param_attr=ParamAttr( + initializer=Xavier(), name=name + '_pw_w'), + bias_attr=ParamAttr( + regularizer=L2Decay(0.), name=name + '_pw_b')) + # bn + act + x = fluid.layers.batch_norm( + x, + momentum=0.997, + epsilon=1e-04, + param_attr=ParamAttr( + initializer=Constant(1.0), + regularizer=L2Decay(0.), + name=name + '_bn_w'), + bias_attr=ParamAttr( + regularizer=L2Decay(0.), name=name + '_bn_b')) + return x + + +class BiFPNCell(object): + def __init__(self, num_chan, levels=5): + super(BiFPNCell, self).__init__() + self.levels = levels + self.num_chan = num_chan + num_trigates = levels - 2 + num_bigates = levels + self.trigates = fluid.layers.create_parameter( + shape=[num_trigates, 3], + dtype='float32', + default_initializer=fluid.initializer.Constant(1.)) + self.bigates = fluid.layers.create_parameter( + shape=[num_bigates, 2], + dtype='float32', + default_initializer=fluid.initializer.Constant(1.)) + self.eps = 1e-4 + + def __call__(self, inputs, cell_name=''): + assert len(inputs) == self.levels + + def upsample(feat): + return fluid.layers.resize_nearest(feat, scale=2.) 
+ + def downsample(feat): + return fluid.layers.pool2d( + feat, + pool_type='max', + pool_size=3, + pool_stride=2, + pool_padding='SAME') + + fuse_conv = FusionConv(self.num_chan) + + # normalize weight + trigates = fluid.layers.relu(self.trigates) + bigates = fluid.layers.relu(self.bigates) + trigates /= fluid.layers.reduce_sum( + trigates, dim=1, keep_dim=True) + self.eps + bigates /= fluid.layers.reduce_sum( + bigates, dim=1, keep_dim=True) + self.eps + + feature_maps = list(inputs) # make a copy + # top down path + for l in range(self.levels - 1): + p = self.levels - l - 2 + w1 = fluid.layers.slice( + bigates, axes=[0, 1], starts=[l, 0], ends=[l + 1, 1]) + w2 = fluid.layers.slice( + bigates, axes=[0, 1], starts=[l, 1], ends=[l + 1, 2]) + above = upsample(feature_maps[p + 1]) + feature_maps[p] = fuse_conv( + w1 * above + w2 * inputs[p], + name='{}_tb_{}'.format(cell_name, l)) + # bottom up path + for l in range(1, self.levels): + p = l + name = '{}_bt_{}'.format(cell_name, l) + below = downsample(feature_maps[p - 1]) + if p == self.levels - 1: + # handle P7 + w1 = fluid.layers.slice( + bigates, axes=[0, 1], starts=[p, 0], ends=[p + 1, 1]) + w2 = fluid.layers.slice( + bigates, axes=[0, 1], starts=[p, 1], ends=[p + 1, 2]) + feature_maps[p] = fuse_conv( + w1 * below + w2 * inputs[p], name=name) + else: + w1 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1]) + w2 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2]) + w3 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3]) + feature_maps[p] = fuse_conv( + w1 * feature_maps[p] + w2 * below + w3 * inputs[p], + name=name) + return feature_maps + + +@register +class BiFPN(object): + """ + Bidirectional Feature Pyramid Network, see https://arxiv.org/abs/1911.09070 + + Args: + num_chan (int): number of feature channels + repeat (int): number of repeats of the BiFPN module + level (int): number of FPN levels, default: 5 + """ + + def 
__init__(self, num_chan, repeat=3, levels=5): + super(BiFPN, self).__init__() + self.num_chan = num_chan + self.repeat = repeat + self.levels = levels + + def __call__(self, inputs): + feats = [] + # NOTE add two extra levels + for idx in range(self.levels): + if idx <= len(inputs): + if idx == len(inputs): + feat = inputs[-1] + else: + feat = inputs[idx] + + if feat.shape[1] != self.num_chan: + feat = fluid.layers.conv2d( + feat, + self.num_chan, + filter_size=1, + padding='SAME', + param_attr=ParamAttr(initializer=Xavier()), + bias_attr=ParamAttr(regularizer=L2Decay(0.))) + feat = fluid.layers.batch_norm( + feat, + momentum=0.997, + epsilon=1e-04, + param_attr=ParamAttr( + initializer=Constant(1.0), regularizer=L2Decay(0.)), + bias_attr=ParamAttr(regularizer=L2Decay(0.))) + + if idx >= len(inputs): + feat = fluid.layers.pool2d( + feat, + pool_type='max', + pool_size=3, + pool_stride=2, + pool_padding='SAME') + feats.append(feat) + + biFPN = BiFPNCell(self.num_chan, self.levels) + for r in range(self.repeat): + feats = biFPN(feats, 'bifpn_{}'.format(r)) + return feats diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/blazenet.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/blazenet.py new file mode 100755 index 000000000..d3987521a --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/blazenet.py @@ -0,0 +1,326 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register + +__all__ = ['BlazeNet'] + + +@register +class BlazeNet(object): + """ + BlazeFace, see https://arxiv.org/abs/1907.05047 + + Args: + blaze_filters (list): number of filter for each blaze block + double_blaze_filters (list): number of filter for each double_blaze block + with_extra_blocks (bool): whether or not extra blocks should be added + lite_edition (bool): whether or not is blazeface-lite + use_5x5kernel (bool): whether or not filter size is 5x5 in depth-wise conv + """ + + def __init__( + self, + blaze_filters=[[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]], + double_blaze_filters=[[48, 24, 96, 2], [96, 24, 96], [96, 24, 96], + [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]], + with_extra_blocks=True, + lite_edition=False, + use_5x5kernel=True): + super(BlazeNet, self).__init__() + + self.blaze_filters = blaze_filters + self.double_blaze_filters = double_blaze_filters + self.with_extra_blocks = with_extra_blocks + self.lite_edition = lite_edition + self.use_5x5kernel = use_5x5kernel + + def __call__(self, input): + if not self.lite_edition: + conv1_num_filters = self.blaze_filters[0][0] + conv = self._conv_norm( + input=input, + num_filters=conv1_num_filters, + filter_size=3, + stride=2, + padding=1, + act='relu', + name="conv1") + + for k, v in enumerate(self.blaze_filters): + assert len(v) in [2, 3], \ + "blaze_filters {} not in [2, 3]" + if len(v) == 2: + conv = self.BlazeBlock( + conv, + v[0], + v[1], + use_5x5kernel=self.use_5x5kernel, + name='blaze_{}'.format(k)) + elif len(v) == 3: + conv = self.BlazeBlock( + conv, + v[0], + v[1], + stride=v[2], + 
use_5x5kernel=self.use_5x5kernel, + name='blaze_{}'.format(k)) + + layers = [] + for k, v in enumerate(self.double_blaze_filters): + assert len(v) in [3, 4], \ + "blaze_filters {} not in [3, 4]" + if len(v) == 3: + conv = self.BlazeBlock( + conv, + v[0], + v[1], + double_channels=v[2], + use_5x5kernel=self.use_5x5kernel, + name='double_blaze_{}'.format(k)) + elif len(v) == 4: + layers.append(conv) + conv = self.BlazeBlock( + conv, + v[0], + v[1], + double_channels=v[2], + stride=v[3], + use_5x5kernel=self.use_5x5kernel, + name='double_blaze_{}'.format(k)) + layers.append(conv) + + if not self.with_extra_blocks: + return layers[-1] + return layers[-2], layers[-1] + else: + conv1 = self._conv_norm( + input=input, + num_filters=24, + filter_size=5, + stride=2, + padding=2, + act='relu', + name="conv1") + conv2 = self.Blaze_lite(conv1, 24, 24, 1, 'conv2') + conv3 = self.Blaze_lite(conv2, 24, 28, 1, 'conv3') + conv4 = self.Blaze_lite(conv3, 28, 32, 2, 'conv4') + conv5 = self.Blaze_lite(conv4, 32, 36, 1, 'conv5') + conv6 = self.Blaze_lite(conv5, 36, 42, 1, 'conv6') + conv7 = self.Blaze_lite(conv6, 42, 48, 2, 'conv7') + in_ch = 48 + for i in range(5): + conv7 = self.Blaze_lite(conv7, in_ch, in_ch + 8, 1, + 'conv{}'.format(8 + i)) + in_ch += 8 + assert in_ch == 88 + conv13 = self.Blaze_lite(conv7, 88, 96, 2, 'conv13') + for i in range(4): + conv13 = self.Blaze_lite(conv13, 96, 96, 1, + 'conv{}'.format(14 + i)) + + return conv7, conv13 + + def BlazeBlock(self, + input, + in_channels, + out_channels, + double_channels=None, + stride=1, + use_5x5kernel=True, + name=None): + assert stride in [1, 2] + use_pool = not stride == 1 + use_double_block = double_channels is not None + act = 'relu' if use_double_block else None + mixed_precision_enabled = mixed_precision_global_state() is not None + + if use_5x5kernel: + conv_dw = self._conv_norm( + input=input, + filter_size=5, + num_filters=in_channels, + stride=stride, + padding=2, + num_groups=in_channels, + 
use_cudnn=mixed_precision_enabled, + name=name + "1_dw") + else: + conv_dw_1 = self._conv_norm( + input=input, + filter_size=3, + num_filters=in_channels, + stride=1, + padding=1, + num_groups=in_channels, + use_cudnn=mixed_precision_enabled, + name=name + "1_dw_1") + conv_dw = self._conv_norm( + input=conv_dw_1, + filter_size=3, + num_filters=in_channels, + stride=stride, + padding=1, + num_groups=in_channels, + use_cudnn=mixed_precision_enabled, + name=name + "1_dw_2") + + conv_pw = self._conv_norm( + input=conv_dw, + filter_size=1, + num_filters=out_channels, + stride=1, + padding=0, + act=act, + name=name + "1_sep") + + if use_double_block: + if use_5x5kernel: + conv_dw = self._conv_norm( + input=conv_pw, + filter_size=5, + num_filters=out_channels, + stride=1, + padding=2, + use_cudnn=mixed_precision_enabled, + name=name + "2_dw") + else: + conv_dw_1 = self._conv_norm( + input=conv_pw, + filter_size=3, + num_filters=out_channels, + stride=1, + padding=1, + num_groups=out_channels, + use_cudnn=mixed_precision_enabled, + name=name + "2_dw_1") + conv_dw = self._conv_norm( + input=conv_dw_1, + filter_size=3, + num_filters=out_channels, + stride=1, + padding=1, + num_groups=out_channels, + use_cudnn=mixed_precision_enabled, + name=name + "2_dw_2") + + conv_pw = self._conv_norm( + input=conv_dw, + filter_size=1, + num_filters=double_channels, + stride=1, + padding=0, + name=name + "2_sep") + + # shortcut + if use_pool: + shortcut_channel = double_channels or out_channels + shortcut_pool = self._pooling_block(input, stride, stride) + channel_pad = self._conv_norm( + input=shortcut_pool, + filter_size=1, + num_filters=shortcut_channel, + stride=1, + padding=0, + name="shortcut" + name) + return fluid.layers.elementwise_add( + x=channel_pad, y=conv_pw, act='relu') + return fluid.layers.elementwise_add(x=input, y=conv_pw, act='relu') + + def Blaze_lite(self, input, in_channels, out_channels, stride=1, name=None): + assert stride in [1, 2] + use_pool = not stride == 1 + 
ues_pad = not in_channels == out_channels + conv_dw = self._conv_norm( + input=input, + filter_size=3, + num_filters=in_channels, + stride=stride, + padding=1, + num_groups=in_channels, + name=name + "_dw") + + conv_pw = self._conv_norm( + input=conv_dw, + filter_size=1, + num_filters=out_channels, + stride=1, + padding=0, + name=name + "_sep") + + if use_pool: + shortcut_pool = self._pooling_block(input, stride, stride) + if ues_pad: + conv_pad = shortcut_pool if use_pool else input + channel_pad = self._conv_norm( + input=conv_pad, + filter_size=1, + num_filters=out_channels, + stride=1, + padding=0, + name="shortcut" + name) + return fluid.layers.elementwise_add( + x=channel_pad, y=conv_pw, act='relu') + return fluid.layers.elementwise_add(x=input, y=conv_pw, act='relu') + + def _conv_norm( + self, + input, + filter_size, + num_filters, + stride, + padding, + num_groups=1, + act='relu', # None + use_cudnn=True, + name=None): + parameter_attr = ParamAttr( + learning_rate=0.1, + initializer=fluid.initializer.MSRA(), + name=name + "_weights") + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=parameter_attr, + bias_attr=False) + return fluid.layers.batch_norm(input=conv, act=act) + + def _pooling_block(self, + conv, + pool_size, + pool_stride, + pool_padding=0, + ceil_mode=True): + pool = fluid.layers.pool2d( + input=conv, + pool_size=pool_size, + pool_type='max', + pool_stride=pool_stride, + pool_padding=pool_padding, + ceil_mode=ceil_mode) + return pool diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/cb_resnet.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/cb_resnet.py new file mode 100755 index 000000000..67afdb300 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/cb_resnet.py @@ -0,0 +1,451 @@ +# Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.framework import Variable +from paddle.fluid.regularizer import L2Decay +from paddle.fluid.initializer import Constant + +from ppdet.core.workspace import register, serializable +from numbers import Integral + +from .name_adapter import NameAdapter +from .nonlocal_helper import add_space_nonlocal + +__all__ = ['CBResNet'] + + +@register +@serializable +class CBResNet(object): + """ + CBNet, see https://arxiv.org/abs/1909.03625 + Args: + depth (int): ResNet depth, should be 18, 34, 50, 101, 152. + freeze_at (int): freeze the backbone at which stage + norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel' + freeze_norm (bool): freeze normalization layers + norm_decay (float): weight decay for normalization layer weights + variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently + feature_maps (list): index of stages whose feature maps are returned + dcn_v2_stages (list): index of stages who select deformable conv v2 + nonlocal_stages (list): index of stages who select nonlocal networks + repeat_num (int): number of repeat for backbone + Attention: + 1. Here we set the ResNet as the base backbone. + 2. 
All the pretraned params are copied from corresponding names, + but with different names to avoid name refliction. + """ + + def __init__(self, + depth=50, + freeze_at=2, + norm_type='bn', + freeze_norm=True, + norm_decay=0., + variant='b', + feature_maps=[2, 3, 4, 5], + dcn_v2_stages=[], + nonlocal_stages=[], + repeat_num=2, + lr_mult_list=[1., 1., 1., 1.]): + super(CBResNet, self).__init__() + + if isinstance(feature_maps, Integral): + feature_maps = [feature_maps] + + assert depth in [18, 34, 50, 101, 152, 200], \ + "depth {} not in [18, 34, 50, 101, 152, 200]" + assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant" + assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4" + assert len(feature_maps) > 0, "need one or more feature maps" + assert norm_type in ['bn', 'sync_bn', 'affine_channel'] + assert not (len(nonlocal_stages)>0 and depth<50), \ + "non-local is not supported for resnet18 or resnet34" + + self.depth = depth + self.dcn_v2_stages = dcn_v2_stages + self.freeze_at = freeze_at + self.norm_type = norm_type + self.norm_decay = norm_decay + self.freeze_norm = freeze_norm + self.variant = variant + self._model_type = 'ResNet' + self.feature_maps = feature_maps + self.repeat_num = repeat_num + self.curr_level = 0 + self.depth_cfg = { + 18: ([2, 2, 2, 2], self.basicblock), + 34: ([3, 4, 6, 3], self.basicblock), + 50: ([3, 4, 6, 3], self.bottleneck), + 101: ([3, 4, 23, 3], self.bottleneck), + 152: ([3, 8, 36, 3], self.bottleneck), + 200: ([3, 12, 48, 3], self.bottleneck), + } + + self.nonlocal_stages = nonlocal_stages + self.nonlocal_mod_cfg = { + 50: 2, + 101: 5, + 152: 8, + 200: 12, + } + + self.lr_mult_list = lr_mult_list + self.stage_num = -1 + + self.stage_filters = [64, 128, 256, 512] + self._c1_out_chan_num = 64 + self.na = NameAdapter(self) + + def _conv_offset(self, + input, + filter_size, + stride, + padding, + act=None, + name=None): + out_channel = filter_size * filter_size * 3 + out = fluid.layers.conv2d( + input, + 
num_filters=out_channel, + filter_size=filter_size, + stride=stride, + padding=padding, + param_attr=ParamAttr( + initializer=Constant(0.0), name=name + ".w_0"), + bias_attr=ParamAttr( + initializer=Constant(0.0), name=name + ".b_0"), + act=act, + name=name) + return out + + def _conv_norm(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None, + dcn=False): + + # need fine lr for distilled model, default as 1.0 + lr_mult = 1.0 + mult_idx = max(self.stage_num - 2, 0) + mult_idx = min(self.stage_num - 2, 3) + lr_mult = self.lr_mult_list[mult_idx] + + if not dcn: + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr( + name=name + "_weights_" + str(self.curr_level), + learning_rate=lr_mult), + bias_attr=False) + else: + offset_mask = self._conv_offset( + input=input, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + act=None, + name=name + "_conv_offset_" + str(self.curr_level)) + offset_channel = filter_size**2 * 2 + mask_channel = filter_size**2 + offset, mask = fluid.layers.split( + input=offset_mask, + num_or_sections=[offset_channel, mask_channel], + dim=1) + mask = fluid.layers.sigmoid(mask) + conv = fluid.layers.deformable_conv( + input=input, + offset=offset, + mask=mask, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + deformable_groups=1, + im2col_step=1, + param_attr=ParamAttr( + name=name + "_weights_" + str(self.curr_level), + learning_rate=lr_mult), + bias_attr=False) + + bn_name = self.na.fix_conv_norm_name(name) + + norm_lr = 0. 
if self.freeze_norm else lr_mult + norm_decay = self.norm_decay + pattr = ParamAttr( + name=bn_name + '_scale_' + str(self.curr_level), + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay)) + battr = ParamAttr( + name=bn_name + '_offset_' + str(self.curr_level), + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay)) + + if self.norm_type in ['bn', 'sync_bn']: + global_stats = True if self.freeze_norm else False + out = fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name + '.output.1_' + str(self.curr_level), + param_attr=pattr, + bias_attr=battr, + moving_mean_name=bn_name + '_mean_' + str(self.curr_level), + moving_variance_name=bn_name + '_variance_' + + str(self.curr_level), + use_global_stats=global_stats) + scale = fluid.framework._get_var(pattr.name) + bias = fluid.framework._get_var(battr.name) + elif self.norm_type == 'affine_channel': + assert False, "deprecated!!!" + if self.freeze_norm: + scale.stop_gradient = True + bias.stop_gradient = True + return out + + def _shortcut(self, input, ch_out, stride, is_first, name): + max_pooling_in_short_cut = self.variant == 'd' + ch_in = input.shape[1] + # the naming rule is same as pretrained weight + name = self.na.fix_shortcut_name(name) + if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first): + if max_pooling_in_short_cut and not is_first: + input = fluid.layers.pool2d( + input=input, + pool_size=2, + pool_stride=2, + pool_padding=0, + ceil_mode=True, + pool_type='avg') + return self._conv_norm(input, ch_out, 1, 1, name=name) + return self._conv_norm(input, ch_out, 1, stride, name=name) + else: + return input + + def bottleneck(self, input, num_filters, stride, is_first, name, dcn=False): + if self.variant == 'a': + stride1, stride2 = stride, 1 + else: + stride1, stride2 = 1, stride + + # ResNeXt + groups = getattr(self, 'groups', 1) + group_width = getattr(self, 'group_width', -1) + if groups == 1: + expand = 4 + elif (groups * group_width) == 256: + expand = 1 + else: # 
FIXME hard code for now, handles 32x4d, 64x4d and 32x8d + num_filters = num_filters // 2 + expand = 2 + + conv_name1, conv_name2, conv_name3, \ + shortcut_name = self.na.fix_bottleneck_name(name) + + conv_def = [[num_filters, 1, stride1, 'relu', 1, conv_name1], + [num_filters, 3, stride2, 'relu', groups, conv_name2], + [num_filters * expand, 1, 1, None, 1, conv_name3]] + + residual = input + for i, (c, k, s, act, g, _name) in enumerate(conv_def): + residual = self._conv_norm( + input=residual, + num_filters=c, + filter_size=k, + stride=s, + act=act, + groups=g, + name=_name, + dcn=(i == 1 and dcn)) + short = self._shortcut( + input, + num_filters * expand, + stride, + is_first=is_first, + name=shortcut_name) + # Squeeze-and-Excitation + if callable(getattr(self, '_squeeze_excitation', None)): + residual = self._squeeze_excitation( + input=residual, num_channels=num_filters, name='fc' + name) + return fluid.layers.elementwise_add(x=short, y=residual, act='relu') + + def basicblock(self, input, num_filters, stride, is_first, name, dcn=False): + assert dcn is False, "Not implemented yet." + conv0 = self._conv_norm( + input=input, + num_filters=num_filters, + filter_size=3, + act='relu', + stride=stride, + name=name + "_branch2a") + conv1 = self._conv_norm( + input=conv0, + num_filters=num_filters, + filter_size=3, + act=None, + name=name + "_branch2b") + short = self._shortcut( + input, num_filters, stride, is_first, name=name + "_branch1") + return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') + + def layer_warp(self, input, stage_num): + """ + Args: + input (Variable): input variable. + stage_num (int): the stage number, should be 2, 3, 4, 5 + + Returns: + The last variable in endpoint-th stage. 
+ """ + assert stage_num in [2, 3, 4, 5] + + self.stage_num = stage_num + + stages, block_func = self.depth_cfg[self.depth] + count = stages[stage_num - 2] + + ch_out = self.stage_filters[stage_num - 2] + is_first = False if stage_num != 2 else True + dcn = True if stage_num in self.dcn_v2_stages else False + + nonlocal_mod = 1000 + if stage_num in self.nonlocal_stages: + nonlocal_mod = self.nonlocal_mod_cfg[ + self.depth] if stage_num == 4 else 2 + + # Make the layer name and parameter name consistent + # with ImageNet pre-trained model + conv = input + for i in range(count): + conv_name = self.na.fix_layer_warp_name(stage_num, count, i) + if self.depth < 50: + is_first = True if i == 0 and stage_num == 2 else False + conv = block_func( + input=conv, + num_filters=ch_out, + stride=2 if i == 0 and stage_num != 2 else 1, + is_first=is_first, + name=conv_name, + dcn=dcn) + + # add non local model + dim_in = conv.shape[1] + nonlocal_name = "nonlocal_conv{}_lvl{}".format(stage_num, + self.curr_level) + if i % nonlocal_mod == nonlocal_mod - 1: + conv = add_space_nonlocal(conv, dim_in, dim_in, + nonlocal_name + '_{}'.format(i), + int(dim_in / 2)) + + return conv + + def c1_stage(self, input): + out_chan = self._c1_out_chan_num + + conv1_name = self.na.fix_c1_stage_name() + + if self.variant in ['c', 'd']: + conv1_1_name = "conv1_1" + conv1_2_name = "conv1_2" + conv1_3_name = "conv1_3" + conv_def = [ + [out_chan // 2, 3, 2, conv1_1_name], + [out_chan // 2, 3, 1, conv1_2_name], + [out_chan, 3, 1, conv1_3_name], + ] + else: + conv_def = [[out_chan, 7, 2, conv1_name]] + + for (c, k, s, _name) in conv_def: + input = self._conv_norm( + input=input, + num_filters=c, + filter_size=k, + stride=s, + act='relu', + name=_name) + + output = fluid.layers.pool2d( + input=input, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + return output + + def connect(self, left, right, name): + ch_right = right.shape[1] + conv = self._conv_norm( + left, + 
num_filters=ch_right, + filter_size=1, + stride=1, + act="relu", + name=name + "_connect") + shape = fluid.layers.shape(right) + shape_hw = fluid.layers.slice(shape, axes=[0], starts=[2], ends=[4]) + out_shape_ = shape_hw + out_shape = fluid.layers.cast(out_shape_, dtype='int32') + out_shape.stop_gradient = True + conv = fluid.layers.resize_nearest(conv, scale=2., out_shape=out_shape) + + output = fluid.layers.elementwise_add(x=right, y=conv) + return output + + def __call__(self, input): + assert isinstance(input, Variable) + assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \ + "feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps) + + res_endpoints = [] + + self.curr_level = 0 + res = self.c1_stage(input) + feature_maps = range(2, max(self.feature_maps) + 1) + for i in feature_maps: + res = self.layer_warp(res, i) + if i in self.feature_maps: + res_endpoints.append(res) + + for num in range(1, self.repeat_num): + self.stage_num = -1 + self.curr_level = num + res = self.c1_stage(input) + for i in range(len(res_endpoints)): + res = self.connect(res_endpoints[i], res, "test_c" + str(i + 1)) + res = self.layer_warp(res, i + 2) + res_endpoints[i] = res + if self.freeze_at >= i + 2: + res.stop_gradient = True + + return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat) + for idx, feat in enumerate(res_endpoints)]) diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/cspdarknet.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/cspdarknet.py new file mode 100755 index 000000000..c789e2229 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/cspdarknet.py @@ -0,0 +1,212 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay

from ppdet.core.workspace import register

__all__ = ['CSPDarkNet']


@register
class CSPDarkNet(object):
    """
    CSPDarkNet, see https://arxiv.org/abs/1911.11929
    Args:
        depth (int): network depth, currently only cspdarknet 53 is supported
        norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
        norm_decay (float): weight decay for normalization layer weights
    """
    __shared__ = ['norm_type', 'weight_prefix_name']

    def __init__(self,
                 depth=53,
                 norm_type='bn',
                 norm_decay=0.,
                 weight_prefix_name=''):
        assert depth in [53], "unsupported depth value"
        self.depth = depth
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        # Per-stage block counts for the 5 stages, plus the block builder.
        self.depth_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)}
        # Prefix for all parameter names (lets several backbones coexist).
        self.prefix_name = weight_prefix_name

    def _softplus(self, input):
        """Numerically-safe softplus: clip before exp to avoid overflow."""
        expf = fluid.layers.exp(fluid.layers.clip(input, -200, 50))
        return fluid.layers.log(1 + expf)

    def _mish(self, input):
        """Mish activation: x * tanh(softplus(x))."""
        return input * fluid.layers.tanh(self._softplus(input))

    def _conv_norm(self,
                   input,
                   ch_out,
                   filter_size,
                   stride,
                   padding,
                   act='mish',
                   name=None):
        """Conv + batch-norm + optional mish.

        Parameter names (``<name>.conv.weights``, ``<name>.bn.*``) must match
        the pretrained checkpoint naming — do not change them.
        """
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            act=None,
            param_attr=ParamAttr(name=name + ".conv.weights"),
            bias_attr=False)

        bn_name = name + ".bn"
        bn_param_attr = ParamAttr(
            regularizer=L2Decay(float(self.norm_decay)),
            name=bn_name + '.scale')
        bn_bias_attr = ParamAttr(
            regularizer=L2Decay(float(self.norm_decay)),
            name=bn_name + '.offset')

        out = fluid.layers.batch_norm(
            input=conv,
            act=None,
            param_attr=bn_param_attr,
            bias_attr=bn_bias_attr,
            moving_mean_name=bn_name + '.mean',
            moving_variance_name=bn_name + '.var')

        if act == 'mish':
            out = self._mish(out)

        return out

    def _downsample(self,
                    input,
                    ch_out,
                    filter_size=3,
                    stride=2,
                    padding=1,
                    name=None):
        """3x3 stride-2 conv-norm that halves the spatial resolution."""
        return self._conv_norm(
            input,
            ch_out=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            name=name)

    def conv_layer(self,
                   input,
                   ch_out,
                   filter_size=1,
                   stride=1,
                   padding=0,
                   name=None):
        """1x1 conv-norm used for the CSP split/merge projections."""
        return self._conv_norm(
            input,
            ch_out=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            name=name)

    def basicblock(self, input, ch_out, scale_first=False, name=None):
        """Residual block: 1x1 then 3x3 conv-norm with identity skip."""
        conv1 = self._conv_norm(
            input,
            # First stage halves the bottleneck width when scale_first is set.
            ch_out=ch_out // 2 if scale_first else ch_out,
            filter_size=1,
            stride=1,
            padding=0,
            name=name + ".0")
        conv2 = self._conv_norm(
            conv1,
            ch_out=ch_out,
            filter_size=3,
            stride=1,
            padding=1,
            name=name + ".1")
        out = fluid.layers.elementwise_add(x=input, y=conv2, act=None)
        return out

    def layer_warp(self,
                   block_func,
                   input,
                   ch_out,
                   count,
                   keep_ch=False,
                   scale_first=False,
                   name=None):
        """Build one CSP stage: split into a shortcut ("right") branch and a
        residual-block ("left") branch, then concat and fuse with a 1x1 conv.
        """
        if scale_first:
            ch_out = ch_out * 2
        right = self.conv_layer(
            input, ch_out, name='{}.route_in.right'.format(name))
        neck = self.conv_layer(input, ch_out, name='{}.neck'.format(name))
        out = block_func(
            neck,
            ch_out=ch_out,
            scale_first=scale_first,
            name='{}.0'.format(name))
        for j in six.moves.xrange(1, count):
            out = block_func(out, ch_out=ch_out, name='{}.{}'.format(name, j))
        left = self.conv_layer(
            out, ch_out, name='{}.route_in.left'.format(name))
        route = fluid.layers.concat([left, right], axis=1)
        out = self.conv_layer(
            route,
            ch_out=ch_out if keep_ch else ch_out * 2,
            name='{}.conv_layer'.format(name))
        return out

    def __call__(self, input):
        """
        Get the backbone of CSPDarkNet, that is output for the 5 stages.

        Args:
            input (Variable): input variable.

        Returns:
            The last variables of each stage.
        """
        stages, block_func = self.depth_cfg[self.depth]
        stages = stages[0:5]
        conv = self._conv_norm(
            input=input,
            ch_out=32,
            filter_size=3,
            stride=1,
            padding=1,
            act='mish',
            name=self.prefix_name + "conv")
        blocks = []
        for i, stage in enumerate(stages):
            # First stage consumes the stem; later stages consume the
            # previous stage's output.
            input = conv if i == 0 else block
            downsample_ = self._downsample(
                input=input,
                ch_out=input.shape[1] * 2,
                name=self.prefix_name + "stage.{}.downsample".format(i))
            block = self.layer_warp(
                block_func=block_func,
                input=downsample_,
                ch_out=32 * 2**i,
                count=stage,
                keep_ch=(i == 0),
                scale_first=i == 0,
                name=self.prefix_name + "stage.{}".format(i))
            blocks.append(block)
        return blocks
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay

from ppdet.core.workspace import register

__all__ = ['DarkNet']


@register
class DarkNet(object):
    """
    DarkNet, see https://pjreddie.com/darknet/yolo/
    Args:
        depth (int): network depth, currently only darknet 53 is supported
        norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
        norm_decay (float): weight decay for normalization layer weights
    """
    __shared__ = ['norm_type', 'weight_prefix_name']

    def __init__(self,
                 depth=53,
                 norm_type='bn',
                 norm_decay=0.,
                 weight_prefix_name='',
                 freeze_at=-1):
        assert depth in [53], "unsupported depth value"
        self.depth = depth
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        # Per-stage residual-block counts and the block builder function.
        self.depth_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)}
        # Prefix for all parameter names (lets several backbones coexist).
        self.prefix_name = weight_prefix_name
        # Stages with index < freeze_at get stop_gradient; -1 freezes none.
        self.freeze_at = freeze_at

    def _conv_norm(self,
                   input,
                   ch_out,
                   filter_size,
                   stride,
                   padding,
                   act='leaky',
                   name=None):
        """Conv + batch-norm + optional leaky-relu.

        Parameter names (``<name>.conv.weights``, ``<name>.bn.*``) must match
        the pretrained checkpoint naming — do not change them.
        """
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            act=None,
            param_attr=ParamAttr(name=name + ".conv.weights"),
            bias_attr=False)

        bn_name = name + ".bn"
        bn_param_attr = ParamAttr(
            regularizer=L2Decay(float(self.norm_decay)),
            name=bn_name + '.scale')
        bn_bias_attr = ParamAttr(
            regularizer=L2Decay(float(self.norm_decay)),
            name=bn_name + '.offset')

        out = fluid.layers.batch_norm(
            input=conv,
            act=None,
            param_attr=bn_param_attr,
            bias_attr=bn_bias_attr,
            moving_mean_name=bn_name + '.mean',
            moving_variance_name=bn_name + '.var')

        # leaky relu here has `alpha` as 0.1, can not be set by
        # `act` param in fluid.layers.batch_norm above.
        if act == 'leaky':
            out = fluid.layers.leaky_relu(x=out, alpha=0.1)

        return out

    def _downsample(self,
                    input,
                    ch_out,
                    filter_size=3,
                    stride=2,
                    padding=1,
                    name=None):
        """3x3 stride-2 conv-norm that halves the spatial resolution."""
        return self._conv_norm(
            input,
            ch_out=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            name=name)

    def basicblock(self, input, ch_out, name=None):
        """Residual block: 1x1 (ch_out) then 3x3 (ch_out*2) with identity
        skip; the skip requires input channels == ch_out * 2."""
        conv1 = self._conv_norm(
            input,
            ch_out=ch_out,
            filter_size=1,
            stride=1,
            padding=0,
            name=name + ".0")
        conv2 = self._conv_norm(
            conv1,
            ch_out=ch_out * 2,
            filter_size=3,
            stride=1,
            padding=1,
            name=name + ".1")
        out = fluid.layers.elementwise_add(x=input, y=conv2, act=None)
        return out

    def layer_warp(self, block_func, input, ch_out, count, name=None):
        """Stack `count` residual blocks to form one stage."""
        out = block_func(input, ch_out=ch_out, name='{}.0'.format(name))
        for j in six.moves.xrange(1, count):
            out = block_func(out, ch_out=ch_out, name='{}.{}'.format(name, j))
        return out

    def __call__(self, input):
        """
        Get the backbone of DarkNet, that is output for the 5 stages.

        Args:
            input (Variable): input variable.

        Returns:
            The last variables of each stage.
        """
        stages, block_func = self.depth_cfg[self.depth]
        stages = stages[0:5]
        conv = self._conv_norm(
            input=input,
            ch_out=32,
            filter_size=3,
            stride=1,
            padding=1,
            name=self.prefix_name + "yolo_input")
        downsample_ = self._downsample(
            input=conv,
            ch_out=conv.shape[1] * 2,
            name=self.prefix_name + "yolo_input.downsample")
        blocks = []
        for i, stage in enumerate(stages):
            block = self.layer_warp(
                block_func=block_func,
                input=downsample_,
                ch_out=32 * 2**i,
                count=stage,
                name=self.prefix_name + "stage.{}".format(i))
            if i < self.freeze_at:
                block.stop_gradient = True
            blocks.append(block)
            if i < len(stages) - 1:  # do not downsample in the last stage
                downsample_ = self._downsample(
                    input=block,
                    ch_out=block.shape[1] * 2,
                    name=self.prefix_name + "stage.{}.downsample".format(i))
        return blocks
from __future__ import absolute_import
from __future__ import division

import collections
import math
import re

from paddle import fluid
from paddle.fluid.regularizer import L2Decay

from ppdet.core.workspace import register

__all__ = ['EfficientNet']

# Network-wide scaling hyper-parameters (see EfficientNet paper).
GlobalParams = collections.namedtuple('GlobalParams', [
    'batch_norm_momentum', 'batch_norm_epsilon', 'width_coefficient',
    'depth_coefficient', 'depth_divisor'
])

# Configuration of one MBConv block family.
BlockArgs = collections.namedtuple('BlockArgs', [
    'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
    'expand_ratio', 'stride', 'se_ratio'
])

GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)


def _decode_block_string(block_string):
    """Parse a compact block spec like 'r1_k3_s11_e1_i32_o16_se0.25'
    into a BlockArgs tuple (r=repeats, k=kernel, s=stride, e=expand,
    i=in filters, o=out filters, se=squeeze-excite ratio)."""
    assert isinstance(block_string, str)

    ops = block_string.split('_')
    options = {}
    for op in ops:
        # Split each token at the first digit: key prefix + numeric value.
        splits = re.split(r'(\d.*)', op)
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value

    # Stride is encoded once ('s2') or twice ('s22'); both digits must agree.
    assert (('s' in options and len(options['s']) == 1) or
            (len(options['s']) == 2 and options['s'][0] == options['s'][1]))

    return BlockArgs(
        kernel_size=int(options['k']),
        num_repeat=int(options['r']),
        input_filters=int(options['i']),
        output_filters=int(options['o']),
        expand_ratio=int(options['e']),
        se_ratio=float(options['se']) if 'se' in options else None,
        stride=int(options['s'][0]))


def get_model_params(scale):
    """Return (block_args, global_params) for a compound scale 'b0'-'b7'."""
    block_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    block_args = []
    for block_string in block_strings:
        block_args.append(_decode_block_string(block_string))

    params_dict = {
        # width, depth
        'b0': (1.0, 1.0),
        'b1': (1.0, 1.1),
        'b2': (1.1, 1.2),
        'b3': (1.2, 1.4),
        'b4': (1.4, 1.8),
        'b5': (1.6, 2.2),
        'b6': (1.8, 2.6),
        'b7': (2.0, 3.1),
    }

    w, d = params_dict[scale]

    global_params = GlobalParams(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        width_coefficient=w,
        depth_coefficient=d,
        depth_divisor=8)

    return block_args, global_params


def round_filters(filters, global_params):
    """Scale a channel count by the width coefficient, rounding to the
    nearest multiple of depth_divisor (never dropping below 90%)."""
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters
    divisor = global_params.depth_divisor
    filters *= multiplier
    min_depth = divisor
    new_filters = max(min_depth,
                      int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, global_params):
    """Scale a repeat count by the depth coefficient (ceil)."""
    multiplier = global_params.depth_coefficient
    if not multiplier:
        return repeats
    return int(math.ceil(multiplier * repeats))


def conv2d(inputs,
           num_filters,
           filter_size,
           stride=1,
           padding='SAME',
           groups=1,
           use_bias=False,
           name='conv2d'):
    """Conv wrapper with checkpoint-compatible parameter names; the bias
    (if any) is exempt from weight decay via L2Decay(0.)."""
    param_attr = fluid.ParamAttr(name=name + '_weights')
    bias_attr = False
    if use_bias:
        bias_attr = fluid.ParamAttr(
            name=name + '_offset', regularizer=L2Decay(0.))
    feats = fluid.layers.conv2d(
        inputs,
        num_filters,
        filter_size,
        groups=groups,
        name=name,
        stride=stride,
        padding=padding,
        param_attr=param_attr,
        bias_attr=bias_attr)
    return feats


def batch_norm(inputs, momentum, eps, name=None):
    """Batch-norm wrapper; scale/offset are exempt from weight decay."""
    param_attr = fluid.ParamAttr(name=name + '_scale', regularizer=L2Decay(0.))
    bias_attr = fluid.ParamAttr(name=name + '_offset', regularizer=L2Decay(0.))
    return fluid.layers.batch_norm(
        input=inputs,
        momentum=momentum,
        epsilon=eps,
        name=name,
        moving_mean_name=name + '_mean',
        moving_variance_name=name + '_variance',
        param_attr=param_attr,
        bias_attr=bias_attr)


def mb_conv_block(inputs,
                  input_filters,
                  output_filters,
                  expand_ratio,
                  kernel_size,
                  stride,
                  momentum,
                  eps,
                  se_ratio=None,
                  name=None):
    """Mobile inverted bottleneck (MBConv): optional 1x1 expansion,
    depthwise conv, optional squeeze-and-excite, 1x1 projection, and a
    residual add when shape-compatible. Uses swish activations."""
    feats = inputs
    num_filters = input_filters * expand_ratio

    if expand_ratio != 1:
        feats = conv2d(feats, num_filters, 1, name=name + '_expand_conv')
        feats = batch_norm(feats, momentum, eps, name=name + '_bn0')
        feats = fluid.layers.swish(feats)

    # groups == channels makes this a depthwise convolution.
    feats = conv2d(
        feats,
        num_filters,
        kernel_size,
        stride,
        groups=num_filters,
        name=name + '_depthwise_conv')
    feats = batch_norm(feats, momentum, eps, name=name + '_bn1')
    feats = fluid.layers.swish(feats)

    if se_ratio is not None:
        # Squeeze-and-excite: global pool -> reduce -> expand -> sigmoid gate.
        filter_squeezed = max(1, int(input_filters * se_ratio))
        squeezed = fluid.layers.pool2d(
            feats, pool_type='avg', global_pooling=True)
        squeezed = conv2d(
            squeezed,
            filter_squeezed,
            1,
            use_bias=True,
            name=name + '_se_reduce')
        squeezed = fluid.layers.swish(squeezed)
        squeezed = conv2d(
            squeezed, num_filters, 1, use_bias=True, name=name + '_se_expand')
        feats = feats * fluid.layers.sigmoid(squeezed)

    feats = conv2d(feats, output_filters, 1, name=name + '_project_conv')
    feats = batch_norm(feats, momentum, eps, name=name + '_bn2')

    # Residual connection only when spatial size and channels are preserved.
    if stride == 1 and input_filters == output_filters:
        feats = fluid.layers.elementwise_add(feats, inputs)

    return feats


@register
class EfficientNet(object):
    """
    EfficientNet, see https://arxiv.org/abs/1905.11946

    Args:
        scale (str): compounding scale factor, 'b0' - 'b7'.
        use_se (bool): use squeeze and excite module.
        norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
    """
    __shared__ = ['norm_type']

    def __init__(self, scale='b0', use_se=True, norm_type='bn'):
        assert scale in ['b' + str(i) for i in range(8)], \
            "valid scales are b0 - b7"
        assert norm_type in ['bn', 'sync_bn'], \
            "only 'bn' and 'sync_bn' are supported"

        super(EfficientNet, self).__init__()
        self.norm_type = norm_type
        self.scale = scale
        self.use_se = use_se

    def __call__(self, inputs):
        """Build the network; returns the feature maps after block
        families 2, 4 and 6 (strides 8, 16, 32)."""
        blocks_args, global_params = get_model_params(self.scale)
        momentum = global_params.batch_norm_momentum
        eps = global_params.batch_norm_epsilon

        # Stem: 3x3 stride-2 conv.
        num_filters = round_filters(32, global_params)
        feats = conv2d(
            inputs,
            num_filters=num_filters,
            filter_size=3,
            stride=2,
            name='_conv_stem')
        feats = batch_norm(feats, momentum=momentum, eps=eps, name='_bn0')
        feats = fluid.layers.swish(feats)

        layer_count = 0
        feature_maps = []

        for b, block_arg in enumerate(blocks_args):
            for r in range(block_arg.num_repeat):
                input_filters = round_filters(block_arg.input_filters,
                                              global_params)
                output_filters = round_filters(block_arg.output_filters,
                                               global_params)
                kernel_size = block_arg.kernel_size
                stride = block_arg.stride
                se_ratio = None
                if self.use_se:
                    se_ratio = block_arg.se_ratio

                # Repeats after the first keep channels and stride 1.
                if r > 0:
                    input_filters = output_filters
                    stride = 1

                feats = mb_conv_block(
                    feats,
                    input_filters,
                    output_filters,
                    block_arg.expand_ratio,
                    kernel_size,
                    stride,
                    momentum,
                    eps,
                    se_ratio=se_ratio,
                    name='_blocks.{}.'.format(layer_count))

                layer_count += 1

            feature_maps.append(feats)

        return list(feature_maps[i] for i in [2, 4, 6])
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr

from ppdet.core.workspace import register

__all__ = ['FaceBoxNet']


@register
class FaceBoxNet(object):
    """
    FaceBoxes, see https://arxiv.org/abs/1708.05234

    Args:
        with_extra_blocks (bool): whether or not extra blocks should be added
        lite_edition (bool): whether or not is FaceBoxes-lite
    """

    def __init__(self, with_extra_blocks=True, lite_edition=False):
        super(FaceBoxNet, self).__init__()

        self.with_extra_blocks = with_extra_blocks
        self.lite_edition = lite_edition

    def __call__(self, input):
        """Dispatch to the lite or the original network variant."""
        if self.lite_edition:
            return self._simplified_edition(input)
        else:
            return self._original_edition(input)

    def _simplified_edition(self, input):
        """FaceBoxes-lite: CReLU stem, 3 inception blocks, one extra
        downsampling stage. Returns the last (or last two) feature maps."""
        conv_1_1 = self._conv_norm_crelu(
            input=input,
            num_filters=8,
            filter_size=3,
            stride=2,
            padding=1,
            act='relu',
            name="conv_1_1")

        conv_1_2 = self._conv_norm_crelu(
            input=conv_1_1,
            num_filters=24,
            filter_size=3,
            stride=2,
            padding=1,
            act='relu',
            name="conv_1_2")

        pool1 = fluid.layers.pool2d(
            input=conv_1_2,
            pool_size=3,
            pool_padding=1,
            pool_type='avg',
            name="pool_1")

        conv_2_1 = self._conv_norm(
            input=pool1,
            num_filters=48,
            filter_size=3,
            stride=2,
            padding=1,
            act='relu',
            name="conv_2_1")

        conv_2_2 = self._conv_norm(
            input=conv_2_1,
            num_filters=64,
            filter_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_2_2")

        conv_inception = conv_2_2

        for i in range(3):
            conv_inception = self._inceptionA(conv_inception, i)

        layers = []
        layers.append(conv_inception)

        conv_3_1 = self._conv_norm(
            input=conv_inception,
            num_filters=128,
            filter_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_3_1")

        conv_3_2 = self._conv_norm(
            input=conv_3_1,
            num_filters=256,
            filter_size=3,
            stride=2,
            padding=1,
            act='relu',
            name="conv_3_2")

        layers.append(conv_3_2)

        if not self.with_extra_blocks:
            return layers[-1]
        return layers[-2], layers[-1]

    def _original_edition(self, input):
        """Original FaceBoxes: rapidly-digested conv layers (CReLU), 3
        inception blocks, two extra downsampling stages. Returns the last
        feature map, or the last three when with_extra_blocks is set."""
        conv_1 = self._conv_norm_crelu(
            input=input,
            num_filters=24,
            filter_size=7,
            stride=4,
            padding=3,
            act='relu',
            name="conv_1")

        pool_1 = fluid.layers.pool2d(
            input=conv_1,
            pool_size=3,
            pool_stride=2,
            pool_padding=1,
            pool_type='max',
            name="pool_1")

        conv_2 = self._conv_norm_crelu(
            input=pool_1,
            num_filters=64,
            filter_size=5,
            stride=2,
            padding=2,
            act='relu',
            name="conv_2")

        # BUG FIX: this pool previously took `conv_1` as input, which both
        # skipped the stage-2 features and left `conv_2` entirely unused.
        pool_2 = fluid.layers.pool2d(
            input=conv_2,
            pool_size=3,
            pool_stride=2,
            pool_padding=1,
            pool_type='max',
            name="pool_2")

        conv_inception = pool_2

        for i in range(3):
            conv_inception = self._inceptionA(conv_inception, i)

        layers = []
        layers.append(conv_inception)

        conv_3_1 = self._conv_norm(
            input=conv_inception,
            num_filters=128,
            filter_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_3_1")

        conv_3_2 = self._conv_norm(
            input=conv_3_1,
            num_filters=256,
            filter_size=3,
            stride=2,
            padding=1,
            act='relu',
            name="conv_3_2")

        layers.append(conv_3_2)

        conv_4_1 = self._conv_norm(
            input=conv_3_2,
            num_filters=128,
            filter_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_4_1")

        conv_4_2 = self._conv_norm(
            input=conv_4_1,
            num_filters=256,
            filter_size=3,
            stride=2,
            padding=1,
            act='relu',
            name="conv_4_2")

        layers.append(conv_4_2)

        if not self.with_extra_blocks:
            return layers[-1]

        return layers[-3], layers[-2], layers[-1]

    def _conv_norm(self,
                   input,
                   filter_size,
                   num_filters,
                   stride,
                   padding,
                   num_groups=1,
                   act='relu',
                   use_cudnn=True,
                   name=None):
        """Conv (no bias, MSRA init, 0.1 lr) followed by batch-norm + act."""
        parameter_attr = ParamAttr(
            learning_rate=0.1,
            initializer=fluid.initializer.MSRA(),
            name=name + "_weights")
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            act=None,
            use_cudnn=use_cudnn,
            param_attr=parameter_attr,
            bias_attr=False)
        return fluid.layers.batch_norm(input=conv, act=act)

    def _conv_norm_crelu(self,
                         input,
                         filter_size,
                         num_filters,
                         stride,
                         padding,
                         num_groups=1,
                         act='relu',
                         use_cudnn=True,
                         name=None):
        """Conv + batch-norm with CReLU: concat(act(x), act(-x)) along the
        channel axis, doubling the output channel count."""
        parameter_attr = ParamAttr(
            learning_rate=0.1,
            initializer=fluid.initializer.MSRA(),
            name=name + "_weights")
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            act=None,
            use_cudnn=use_cudnn,
            param_attr=parameter_attr,
            bias_attr=False)

        conv_a = fluid.layers.batch_norm(input=conv, act=act)
        conv_b = fluid.layers.scale(conv_a, -1)

        concat = fluid.layers.concat([conv_a, conv_b], axis=1)

        return concat

    def _pooling_block(self,
                       conv,
                       pool_size,
                       pool_stride,
                       pool_padding=0,
                       ceil_mode=True):
        """Max-pooling helper."""
        pool = fluid.layers.pool2d(
            input=conv,
            pool_size=pool_size,
            pool_type='max',
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            ceil_mode=ceil_mode)
        return pool

    def _inceptionA(self, data, idx):
        """Inception-style block with four parallel branches (pool+1x1,
        1x1, 1x1+3x3, 1x1+3x3+3x3) concatenated along channels."""
        idx = str(idx)

        pool1 = fluid.layers.pool2d(
            input=data,
            pool_size=3,
            pool_padding=1,
            pool_type='avg',
            name='inceptionA_' + idx + '_pool1')
        conv1 = self._conv_norm(
            input=pool1,
            filter_size=1,
            num_filters=32,
            stride=1,
            padding=0,
            act='relu',
            name='inceptionA_' + idx + '_conv1')

        conv2 = self._conv_norm(
            input=data,
            filter_size=1,
            num_filters=32,
            stride=1,
            padding=0,
            act='relu',
            name='inceptionA_' + idx + '_conv2')

        conv3 = self._conv_norm(
            input=data,
            filter_size=1,
            num_filters=24,
            stride=1,
            padding=0,
            act='relu',
            name='inceptionA_' + idx + '_conv3_1')
        conv3 = self._conv_norm(
            input=conv3,
            filter_size=3,
            num_filters=32,
            stride=1,
            padding=1,
            act='relu',
            name='inceptionA_' + idx + '_conv3_2')

        conv4 = self._conv_norm(
            input=data,
            filter_size=1,
            num_filters=24,
            stride=1,
            padding=0,
            act='relu',
            name='inceptionA_' + idx + '_conv4_1')
        conv4 = self._conv_norm(
            input=conv4,
            filter_size=3,
            num_filters=32,
            stride=1,
            padding=1,
            act='relu',
            name='inceptionA_' + idx + '_conv4_2')
        conv4 = self._conv_norm(
            input=conv4,
            filter_size=3,
            num_filters=32,
            stride=1,
            padding=1,
            act='relu',
            name='inceptionA_' + idx + '_conv4_3')

        concat = fluid.layers.concat([conv1, conv2, conv3, conv4], axis=1)

        return concat
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import copy +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Xavier +from paddle.fluid.regularizer import L2Decay + +from ppdet.core.workspace import register +from ppdet.modeling.ops import ConvNorm + +__all__ = ['FPN'] + + +@register +class FPN(object): + """ + Feature Pyramid Network, see https://arxiv.org/abs/1612.03144 + + Args: + num_chan (int): number of feature channels + min_level (int): lowest level of the backbone feature map to use + max_level (int): highest level of the backbone feature map to use + spatial_scale (list): feature map scaling factor + has_extra_convs (bool): whether has extra convolutions in higher levels + norm_type (str|None): normalization type, 'bn'/'sync_bn'/'affine_channel' + norm_decay (float): weight decay for normalization layer weights. + reverse_out (bool): whether to flip the output. + """ + __shared__ = ['norm_type', 'freeze_norm'] + + def __init__(self, + num_chan=256, + min_level=2, + max_level=6, + spatial_scale=[1. / 32., 1. / 16., 1. / 8., 1. 
/ 4.], + has_extra_convs=False, + norm_type=None, + norm_decay=0., + freeze_norm=False, + use_c5=True, + reverse_out=False): + self.freeze_norm = freeze_norm + self.num_chan = num_chan + self.min_level = min_level + self.max_level = max_level + self.spatial_scale = spatial_scale + self.has_extra_convs = has_extra_convs + self.norm_type = norm_type + self.norm_decay = norm_decay + self.use_c5 = use_c5 + self.reverse_out = reverse_out + + def _add_topdown_lateral(self, body_name, body_input, upper_output): + lateral_name = 'fpn_inner_' + body_name + '_lateral' + topdown_name = 'fpn_topdown_' + body_name + fan = body_input.shape[1] + if self.norm_type: + initializer = Xavier(fan_out=fan) + lateral = ConvNorm( + body_input, + self.num_chan, + 1, + initializer=initializer, + norm_type=self.norm_type, + norm_decay=self.norm_decay, + freeze_norm=self.freeze_norm, + name=lateral_name, + norm_name=lateral_name) + else: + lateral = fluid.layers.conv2d( + body_input, + self.num_chan, + 1, + param_attr=ParamAttr( + name=lateral_name + "_w", initializer=Xavier(fan_out=fan)), + bias_attr=ParamAttr( + name=lateral_name + "_b", + learning_rate=2., + regularizer=L2Decay(0.)), + name=lateral_name) + if body_input.shape[2] == -1 and body_input.shape[3] == -1: + topdown = fluid.layers.resize_nearest( + upper_output, scale=2., name=topdown_name) + else: + topdown = fluid.layers.resize_nearest( + upper_output, + out_shape=[body_input.shape[2], body_input.shape[3]], + name=topdown_name) + + return lateral + topdown + + def get_output(self, body_dict): + """ + Add FPN onto backbone. + + Args: + body_dict(OrderedDict): Dictionary of variables and each element is the + output of backbone. + + Return: + fpn_dict(OrderedDict): A dictionary represents the output of FPN with + their name. + spatial_scale(list): A list of multiplicative spatial scale factor. 
+ """ + spatial_scale = copy.deepcopy(self.spatial_scale) + body_name_list = list(body_dict.keys())[::-1] + num_backbone_stages = len(body_name_list) + self.fpn_inner_output = [[] for _ in range(num_backbone_stages)] + fpn_inner_name = 'fpn_inner_' + body_name_list[0] + body_input = body_dict[body_name_list[0]] + fan = body_input.shape[1] + if self.norm_type: + initializer = Xavier(fan_out=fan) + self.fpn_inner_output[0] = ConvNorm( + body_input, + self.num_chan, + 1, + initializer=initializer, + norm_type=self.norm_type, + norm_decay=self.norm_decay, + freeze_norm=self.freeze_norm, + name=fpn_inner_name, + norm_name=fpn_inner_name) + else: + self.fpn_inner_output[0] = fluid.layers.conv2d( + body_input, + self.num_chan, + 1, + param_attr=ParamAttr( + name=fpn_inner_name + "_w", + initializer=Xavier(fan_out=fan)), + bias_attr=ParamAttr( + name=fpn_inner_name + "_b", + learning_rate=2., + regularizer=L2Decay(0.)), + name=fpn_inner_name) + for i in range(1, num_backbone_stages): + body_name = body_name_list[i] + body_input = body_dict[body_name] + top_output = self.fpn_inner_output[i - 1] + fpn_inner_single = self._add_topdown_lateral(body_name, body_input, + top_output) + self.fpn_inner_output[i] = fpn_inner_single + fpn_dict = {} + fpn_name_list = [] + for i in range(num_backbone_stages): + fpn_name = 'fpn_' + body_name_list[i] + fan = self.fpn_inner_output[i].shape[1] * 3 * 3 + if self.norm_type: + initializer = Xavier(fan_out=fan) + fpn_output = ConvNorm( + self.fpn_inner_output[i], + self.num_chan, + 3, + initializer=initializer, + norm_type=self.norm_type, + norm_decay=self.norm_decay, + freeze_norm=self.freeze_norm, + name=fpn_name, + norm_name=fpn_name) + else: + fpn_output = fluid.layers.conv2d( + self.fpn_inner_output[i], + self.num_chan, + filter_size=3, + padding=1, + param_attr=ParamAttr( + name=fpn_name + "_w", initializer=Xavier(fan_out=fan)), + bias_attr=ParamAttr( + name=fpn_name + "_b", + learning_rate=2., + regularizer=L2Decay(0.)), + name=fpn_name) 
+ fpn_dict[fpn_name] = fpn_output + fpn_name_list.append(fpn_name) + if not self.has_extra_convs and self.max_level - self.min_level == len( + spatial_scale): + body_top_name = fpn_name_list[0] + body_top_extension = fluid.layers.pool2d( + fpn_dict[body_top_name], + 1, + 'max', + pool_stride=2, + name=body_top_name + '_subsampled_2x') + fpn_dict[body_top_name + '_subsampled_2x'] = body_top_extension + fpn_name_list.insert(0, body_top_name + '_subsampled_2x') + spatial_scale.insert(0, spatial_scale[0] * 0.5) + # Coarser FPN levels introduced for RetinaNet + highest_backbone_level = self.min_level + len(spatial_scale) - 1 + if self.has_extra_convs and self.max_level > highest_backbone_level: + if self.use_c5: + fpn_blob = body_dict[body_name_list[0]] + else: + fpn_blob = fpn_dict[fpn_name_list[0]] + for i in range(highest_backbone_level + 1, self.max_level + 1): + fpn_blob_in = fpn_blob + fpn_name = 'fpn_' + str(i) + if i > highest_backbone_level + 1: + fpn_blob_in = fluid.layers.relu(fpn_blob) + fan = fpn_blob_in.shape[1] * 3 * 3 + fpn_blob = fluid.layers.conv2d( + input=fpn_blob_in, + num_filters=self.num_chan, + filter_size=3, + stride=2, + padding=1, + param_attr=ParamAttr( + name=fpn_name + "_w", initializer=Xavier(fan_out=fan)), + bias_attr=ParamAttr( + name=fpn_name + "_b", + learning_rate=2., + regularizer=L2Decay(0.)), + name=fpn_name) + fpn_dict[fpn_name] = fpn_blob + fpn_name_list.insert(0, fpn_name) + spatial_scale.insert(0, spatial_scale[0] * 0.5) + + if self.reverse_out: + fpn_name_list = fpn_name_list[::-1] + res_dict = OrderedDict([(k, fpn_dict[k]) for k in fpn_name_list]) + return res_dict, spatial_scale diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/gc_block.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/gc_block.py new file mode 100755 index 000000000..fbd374223 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/gc_block.py @@ -0,0 +1,124 @@ +# Copyright (c) 2019 PaddlePaddle 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import paddle +import paddle.fluid as fluid +from paddle.fluid import ParamAttr +from paddle.fluid.initializer import ConstantInitializer + + +def spatial_pool(x, pooling_type, name): + _, channel, height, width = x.shape + if pooling_type == 'att': + input_x = x + # [N, 1, C, H * W] + input_x = fluid.layers.reshape(input_x, shape=(0, 1, channel, -1)) + context_mask = fluid.layers.conv2d( + input=x, + num_filters=1, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=ParamAttr(name=name + "_bias")) + # [N, 1, H * W] + context_mask = fluid.layers.reshape(context_mask, shape=(0, 0, -1)) + # [N, 1, H * W] + context_mask = fluid.layers.softmax(context_mask, axis=2) + # [N, 1, H * W, 1] + context_mask = fluid.layers.reshape(context_mask, shape=(0, 0, -1, 1)) + # [N, 1, C, 1] + context = fluid.layers.matmul(input_x, context_mask) + # [N, C, 1, 1] + context = fluid.layers.reshape(context, shape=(0, channel, 1, 1)) + else: + # [N, C, 1, 1] + context = fluid.layers.pool2d( + input=x, pool_type='avg', global_pooling=True) + return context + + +def channel_conv(input, inner_ch, out_ch, name): + conv = fluid.layers.conv2d( + input=input, + num_filters=inner_ch, + filter_size=1, + stride=1, + 
padding=0, + param_attr=ParamAttr(name=name + "_conv1_weights"), + bias_attr=ParamAttr(name=name + "_conv1_bias"), + name=name + "_conv1", ) + conv = fluid.layers.layer_norm( + conv, + begin_norm_axis=1, + param_attr=ParamAttr(name=name + "_ln_weights"), + bias_attr=ParamAttr(name=name + "_ln_bias"), + act="relu", + name=name + "_ln") + + conv = fluid.layers.conv2d( + input=conv, + num_filters=out_ch, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr( + name=name + "_conv2_weights", + initializer=ConstantInitializer(value=0.0), ), + bias_attr=ParamAttr( + name=name + "_conv2_bias", + initializer=ConstantInitializer(value=0.0), ), + name=name + "_conv2") + return conv + + +def add_gc_block(x, + ratio=1.0 / 16, + pooling_type='att', + fusion_types=['channel_add'], + name=None): + ''' + GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond, see https://arxiv.org/abs/1904.11492 + Args: + ratio (float): channel reduction ratio + pooling_type (str): pooling type, support att and avg + fusion_types (list): fusion types, support channel_add and channel_mul + name (str): prefix name of gc block + ''' + assert pooling_type in ['avg', 'att'] + assert isinstance(fusion_types, (list, tuple)) + valid_fusion_types = ['channel_add', 'channel_mul'] + assert all([f in valid_fusion_types for f in fusion_types]) + assert len(fusion_types) > 0, 'at least one fusion should be used' + + inner_ch = int(ratio * x.shape[1]) + out_ch = x.shape[1] + context = spatial_pool(x, pooling_type, name + "_spatial_pool") + out = x + if 'channel_mul' in fusion_types: + inner_out = channel_conv(context, inner_ch, out_ch, name + "_mul") + channel_mul_term = fluid.layers.sigmoid(inner_out) + out = out * channel_mul_term + + if 'channel_add' in fusion_types: + channel_add_term = channel_conv(context, inner_ch, out_ch, + name + "_add") + out = out + channel_add_term + + return out diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/ghostnet.py 
b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/ghostnet.py new file mode 100755 index 000000000..b40ca84e3 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/ghostnet.py @@ -0,0 +1,361 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +from collections import OrderedDict + +from ppdet.core.workspace import register + +__all__ = ["GhostNet"] + + +@register +class GhostNet(object): + """ + scale (float): scaling factor for convolution groups proportion of GhostNet. + feature_maps (list): index of stages whose feature maps are returned. + conv_decay (float): weight decay for convolution layer weights. + extra_block_filters (list): number of filter for each extra block. + lr_mult_list (list): learning rate ratio of different blocks, lower learning rate ratio + is need for pretrained model got using distillation(default as + [1.0, 1.0, 1.0, 1.0, 1.0]). 
+ """ + + def __init__( + self, + scale, + feature_maps=[5, 6, 7, 8, 9, 10], + conv_decay=0.00001, + extra_block_filters=[[256, 512], [128, 256], [128, 256], [64, 128]], + lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0], + freeze_norm=False): + self.scale = scale + self.feature_maps = feature_maps + self.extra_block_filters = extra_block_filters + self.end_points = [] + self.block_stride = 0 + self.conv_decay = conv_decay + self.lr_mult_list = lr_mult_list + self.freeze_norm = freeze_norm + self.curr_stage = 0 + + self.cfgs = [ + # k, t, c, se, s + [3, 16, 16, 0, 1], + [3, 48, 24, 0, 2], + [3, 72, 24, 0, 1], + [5, 72, 40, 1, 2], + [5, 120, 40, 1, 1], + [3, 240, 80, 0, 2], + [3, 200, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 480, 112, 1, 1], + [3, 672, 112, 1, 1], + [5, 672, 160, 1, 2], + [5, 960, 160, 0, 1], + [5, 960, 160, 1, 1], + [5, 960, 160, 0, 1], + [5, 960, 160, 1, 1] + ] + + def _conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + lr_idx = self.curr_stage // 3 + lr_idx = min(lr_idx, len(self.lr_mult_list) - 1) + lr_mult = self.lr_mult_list[lr_idx] + norm_lr = 0. 
if self.freeze_norm else lr_mult + + x = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr( + regularizer=L2Decay(self.conv_decay), + learning_rate=lr_mult, + initializer=fluid.initializer.MSRA(), + name=name + "_weights"), + bias_attr=False) + bn_name = name + "_bn" + x = fluid.layers.batch_norm( + input=x, + act=act, + param_attr=ParamAttr( + name=bn_name + "_scale", + learning_rate=norm_lr, + regularizer=L2Decay(0.0)), + bias_attr=ParamAttr( + name=bn_name + "_offset", + learning_rate=norm_lr, + regularizer=L2Decay(0.0)), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance") + return x + + def se_block(self, input, num_channels, reduction_ratio=4, name=None): + lr_idx = self.curr_stage // 3 + lr_idx = min(lr_idx, len(self.lr_mult_list) - 1) + lr_mult = self.lr_mult_list[lr_idx] + pool = fluid.layers.pool2d( + input=input, pool_type='avg', global_pooling=True, use_cudnn=False) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + squeeze = fluid.layers.fc( + input=pool, + size=num_channels // reduction_ratio, + act='relu', + param_attr=ParamAttr( + learning_rate=lr_mult, + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_1_weights'), + bias_attr=ParamAttr( + name=name + '_1_offset', learning_rate=lr_mult)) + stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) + excitation = fluid.layers.fc( + input=squeeze, + size=num_channels, + act=None, + param_attr=ParamAttr( + learning_rate=lr_mult, + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_2_weights'), + bias_attr=ParamAttr( + name=name + '_2_offset', learning_rate=lr_mult)) + excitation = fluid.layers.clip(x=excitation, min=0, max=1) + se_scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) + return se_scale + + def depthwise_conv(self, + input, + output, + kernel_size, + stride=1, + relu=False, + 
name=None): + return self._conv_bn_layer( + input=input, + num_filters=output, + filter_size=kernel_size, + stride=stride, + groups=input.shape[1], + act="relu" if relu else None, + name=name + "_depthwise") + + def ghost_module(self, + input, + output, + kernel_size=1, + ratio=2, + dw_size=3, + stride=1, + relu=True, + name=None): + self.output = output + init_channels = int(math.ceil(output / ratio)) + new_channels = int(init_channels * (ratio - 1)) + primary_conv = self._conv_bn_layer( + input=input, + num_filters=init_channels, + filter_size=kernel_size, + stride=stride, + groups=1, + act="relu" if relu else None, + name=name + "_primary_conv") + cheap_operation = self._conv_bn_layer( + input=primary_conv, + num_filters=new_channels, + filter_size=dw_size, + stride=1, + groups=init_channels, + act="relu" if relu else None, + name=name + "_cheap_operation") + out = fluid.layers.concat([primary_conv, cheap_operation], axis=1) + return out + + def ghost_bottleneck(self, + input, + hidden_dim, + output, + kernel_size, + stride, + use_se, + name=None): + inp_channels = input.shape[1] + x = self.ghost_module( + input=input, + output=hidden_dim, + kernel_size=1, + stride=1, + relu=True, + name=name + "_ghost_module_1") + + if self.block_stride == 4 and stride == 2: + self.block_stride += 1 + if self.block_stride in self.feature_maps: + self.end_points.append(x) + + if stride == 2: + x = self.depthwise_conv( + input=x, + output=hidden_dim, + kernel_size=kernel_size, + stride=stride, + relu=False, + name=name + "_depthwise") + if use_se: + x = self.se_block( + input=x, num_channels=hidden_dim, name=name + "_se") + x = self.ghost_module( + input=x, + output=output, + kernel_size=1, + relu=False, + name=name + "_ghost_module_2") + if stride == 1 and inp_channels == output: + shortcut = input + else: + shortcut = self.depthwise_conv( + input=input, + output=inp_channels, + kernel_size=kernel_size, + stride=stride, + relu=False, + name=name + "_shortcut_depthwise") + 
shortcut = self._conv_bn_layer( + input=shortcut, + num_filters=output, + filter_size=1, + stride=1, + groups=1, + act=None, + name=name + "_shortcut_conv") + return fluid.layers.elementwise_add(x=x, y=shortcut, axis=-1) + + def _extra_block_dw(self, + input, + num_filters1, + num_filters2, + stride, + name=None): + pointwise_conv = self._conv_bn_layer( + input=input, + filter_size=1, + num_filters=int(num_filters1), + stride=1, + act='relu6', + name=name + "_extra1") + depthwise_conv = self._conv_bn_layer( + input=pointwise_conv, + filter_size=3, + num_filters=int(num_filters2), + stride=stride, + groups=int(num_filters1), + act='relu6', + name=name + "_extra2_dw") + normal_conv = self._conv_bn_layer( + input=depthwise_conv, + filter_size=1, + num_filters=int(num_filters2), + stride=1, + act='relu6', + name=name + "_extra2_sep") + return normal_conv + + def _make_divisible(self, v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + def __call__(self, input): + # build first layer + output_channel = int(self._make_divisible(16 * self.scale, 4)) + x = self._conv_bn_layer( + input=input, + num_filters=output_channel, + filter_size=3, + stride=2, + groups=1, + act="relu", + name="conv1") + # build inverted residual blocks + idx = 0 + for k, exp_size, c, use_se, s in self.cfgs: + if s == 2: + self.block_stride += 1 + if self.block_stride in self.feature_maps: + self.end_points.append(x) + output_channel = int(self._make_divisible(c * self.scale, 4)) + hidden_channel = int(self._make_divisible(exp_size * self.scale, 4)) + x = self.ghost_bottleneck( + input=x, + hidden_dim=hidden_channel, + output=output_channel, + kernel_size=k, + stride=s, + use_se=use_se, + name="_ghostbottleneck_" + str(idx)) + idx += 1 + self.curr_stage += 1 + self.block_stride += 1 + if self.block_stride in self.feature_maps: + 
self.end_points.append(x) + + # extra block + # check whether conv_extra is needed + if self.block_stride < max(self.feature_maps): + conv_extra = self._conv_bn_layer( + x, + num_filters=self._make_divisible(self.scale * self.cfgs[-1][1]), + filter_size=1, + stride=1, + groups=1, + act='relu6', + name='conv' + str(idx + 2)) + self.block_stride += 1 + if self.block_stride in self.feature_maps: + self.end_points.append(conv_extra) + idx += 1 + for block_filter in self.extra_block_filters: + conv_extra = self._extra_block_dw(conv_extra, block_filter[0], + block_filter[1], 2, + 'conv' + str(idx + 2)) + self.block_stride += 1 + if self.block_stride in self.feature_maps: + self.end_points.append(conv_extra) + idx += 1 + + return OrderedDict([('ghost_{}'.format(idx), feat) + for idx, feat in enumerate(self.end_points)]) + diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/hourglass.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/hourglass.py new file mode 100755 index 000000000..b38f79bb4 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/hourglass.py @@ -0,0 +1,275 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Uniform + +import functools +from ppdet.core.workspace import register +from .resnet import ResNet +import math + +__all__ = ['Hourglass'] + + +def kaiming_init(input, filter_size): + fan_in = input.shape[1] + std = (1.0 / (fan_in * filter_size * filter_size))**0.5 + return Uniform(0. - std, std) + + +def _conv_norm(x, + k, + out_dim, + stride=1, + pad=0, + groups=None, + with_bn=True, + bn_act=None, + ind=None, + name=None): + conv_name = "_conv" if ind is None else "_conv" + str(ind) + bn_name = "_bn" if ind is None else "_bn" + str(ind) + + conv = fluid.layers.conv2d( + input=x, + filter_size=k, + num_filters=out_dim, + stride=stride, + padding=pad, + groups=groups, + param_attr=ParamAttr( + name=name + conv_name + "_weight", initializer=kaiming_init(x, k)), + bias_attr=ParamAttr( + name=name + conv_name + "_bias", initializer=kaiming_init(x, k)) + if not with_bn else False, + name=name + '_output') + if with_bn: + pattr = ParamAttr(name=name + bn_name + '_weight') + battr = ParamAttr(name=name + bn_name + '_bias') + out = fluid.layers.batch_norm( + input=conv, + act=bn_act, + name=name + '_bn_output', + param_attr=pattr, + bias_attr=battr, + moving_mean_name=name + bn_name + '_running_mean', + moving_variance_name=name + bn_name + + '_running_var') if with_bn else conv + else: + out = fluid.layers.relu(conv) + return out + + +def residual_block(x, out_dim, k=3, stride=1, name=None): + p = (k - 1) // 2 + conv1 = _conv_norm( + x, k, out_dim, pad=p, stride=stride, bn_act='relu', ind=1, name=name) + conv2 = _conv_norm(conv1, k, out_dim, pad=p, ind=2, name=name) + + skip = _conv_norm( + x, 1, out_dim, stride=stride, + name=name + '_skip') if stride != 1 or x.shape[1] != out_dim else x + return fluid.layers.elementwise_add( + x=skip, y=conv2, 
act='relu', name=name + "_add") + + +def fire_block(x, out_dim, sr=2, stride=1, name=None): + conv1 = _conv_norm(x, 1, out_dim // sr, ind=1, name=name) + conv_1x1 = fluid.layers.conv2d( + conv1, + filter_size=1, + num_filters=out_dim // 2, + stride=stride, + param_attr=ParamAttr( + name=name + "_conv_1x1_weight", initializer=kaiming_init(conv1, 1)), + bias_attr=False, + name=name + '_conv_1x1') + conv_3x3 = fluid.layers.conv2d( + conv1, + filter_size=3, + num_filters=out_dim // 2, + stride=stride, + padding=1, + groups=out_dim // sr, + param_attr=ParamAttr( + name=name + "_conv_3x3_weight", initializer=kaiming_init(conv1, 3)), + bias_attr=False, + name=name + '_conv_3x3', + use_cudnn=False) + conv2 = fluid.layers.concat( + [conv_1x1, conv_3x3], axis=1, name=name + '_conv2') + pattr = ParamAttr(name=name + '_bn2_weight') + battr = ParamAttr(name=name + '_bn2_bias') + + bn2 = fluid.layers.batch_norm( + input=conv2, + name=name + '_bn2', + param_attr=pattr, + bias_attr=battr, + moving_mean_name=name + '_bn2_running_mean', + moving_variance_name=name + '_bn2_running_var') + + if stride == 1 and x.shape[1] == out_dim: + return fluid.layers.elementwise_add( + x=bn2, y=x, act='relu', name=name + "_add_relu") + else: + return fluid.layers.relu(bn2, name="_relu") + + +def make_layer(x, in_dim, out_dim, modules, block, name=None): + layers = block(x, out_dim, name=name + '_0') + for i in range(1, modules): + layers = block(layers, out_dim, name=name + '_' + str(i)) + return layers + + +def make_hg_layer(x, in_dim, out_dim, modules, block, name=None): + layers = block(x, out_dim, stride=2, name=name + '_0') + for i in range(1, modules): + layers = block(layers, out_dim, name=name + '_' + str(i)) + return layers + + +def make_layer_revr(x, in_dim, out_dim, modules, block, name=None): + for i in range(modules - 1): + x = block(x, in_dim, name=name + '_' + str(i)) + layers = block(x, out_dim, name=name + '_' + str(modules - 1)) + return layers + + +def make_unpool_layer(x, dim, 
name=None): + pattr = ParamAttr(name=name + '_weight', initializer=kaiming_init(x, 4)) + battr = ParamAttr(name=name + '_bias', initializer=kaiming_init(x, 4)) + layer = fluid.layers.conv2d_transpose( + input=x, + num_filters=dim, + filter_size=4, + stride=2, + padding=1, + param_attr=pattr, + bias_attr=battr) + return layer + + +@register +class Hourglass(object): + """ + Hourglass Network, see https://arxiv.org/abs/1603.06937 + Args: + stack (int): stack of hourglass, 2 by default + dims (list): dims of each level in hg_module + modules (list): num of modules in each level + """ + __shared__ = ['stack'] + + def __init__(self, + stack=2, + dims=[256, 256, 384, 384, 512], + modules=[2, 2, 2, 2, 4], + block_name='fire'): + super(Hourglass, self).__init__() + self.stack = stack + assert len(dims) == len(modules), \ + "Expected len of dims equal to len of modules, Receiced len of "\ + "dims: {}, len of modules: {}".format(len(dims), len(modules)) + self.dims = dims + self.modules = modules + self.num_level = len(dims) - 1 + block_dict = {'fire': fire_block} + self.block = block_dict[block_name] + + def __call__(self, input, name='hg'): + inter = self.pre(input, name + '_pre') + cnvs = [] + for ind in range(self.stack): + hg = self.hg_module( + inter, + self.num_level, + self.dims, + self.modules, + name=name + '_hgs_' + str(ind)) + cnv = _conv_norm( + hg, + 3, + 256, + bn_act='relu', + pad=1, + name=name + '_cnvs_' + str(ind)) + cnvs.append(cnv) + + if ind < self.stack - 1: + inter = _conv_norm( + inter, 1, 256, name=name + '_inters__' + + str(ind)) + _conv_norm( + cnv, 1, 256, name=name + '_cnvs__' + str(ind)) + inter = fluid.layers.relu(inter) + inter = residual_block( + inter, 256, name=name + '_inters_' + str(ind)) + return cnvs + + def pre(self, x, name=None): + conv = _conv_norm( + x, 7, 128, stride=2, pad=3, bn_act='relu', name=name + '_0') + res1 = residual_block(conv, 256, stride=2, name=name + '_1') + res2 = residual_block(res1, 256, stride=2, name=name + 
'_2') + return res2 + + def hg_module(self, + x, + n=4, + dims=[256, 256, 384, 384, 512], + modules=[2, 2, 2, 2, 4], + make_up_layer=make_layer, + make_hg_layer=make_hg_layer, + make_low_layer=make_layer, + make_hg_layer_revr=make_layer_revr, + make_unpool_layer=make_unpool_layer, + name=None): + curr_mod = modules[0] + next_mod = modules[1] + curr_dim = dims[0] + next_dim = dims[1] + up1 = make_up_layer( + x, curr_dim, curr_dim, curr_mod, self.block, name=name + '_up1') + max1 = x + low1 = make_hg_layer( + max1, curr_dim, next_dim, curr_mod, self.block, name=name + '_low1') + low2 = self.hg_module( + low1, + n - 1, + dims[1:], + modules[1:], + make_up_layer=make_up_layer, + make_hg_layer=make_hg_layer, + make_low_layer=make_low_layer, + make_hg_layer_revr=make_hg_layer_revr, + make_unpool_layer=make_unpool_layer, + name=name + '_low2') if n > 1 else make_low_layer( + low1, + next_dim, + next_dim, + next_mod, + self.block, + name=name + '_low2') + low3 = make_hg_layer_revr( + low2, next_dim, curr_dim, curr_mod, self.block, name=name + '_low3') + up2 = make_unpool_layer(low3, curr_dim, name=name + '_up2') + merg = fluid.layers.elementwise_add(x=up1, y=up2, name=name + '_merg') + return merg diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/hrfpn.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/hrfpn.py new file mode 100755 index 000000000..174c6d10d --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/hrfpn.py @@ -0,0 +1,132 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Xavier +from paddle.fluid.regularizer import L2Decay + +from ppdet.core.workspace import register + +__all__ = ['HRFPN'] + + +@register +class HRFPN(object): + """ + HRNet, see https://arxiv.org/abs/1908.07919 + + Args: + num_chan (int): number of feature channels + pooling_type (str): pooling type of downsampling + share_conv (bool): whether to share conv for different layers' reduction + spatial_scale (list): feature map scaling factor + """ + + def __init__( + self, + num_chan=256, + pooling_type="avg", + share_conv=False, + spatial_scale=[1. / 64, 1. / 32, 1. / 16, 1. / 8, 1. 
/ 4], ): + self.num_chan = num_chan + self.pooling_type = pooling_type + self.share_conv = share_conv + self.spatial_scale = spatial_scale + return + + def get_output(self, body_dict): + num_out = len(self.spatial_scale) + body_name_list = list(body_dict.keys()) + + num_backbone_stages = len(body_name_list) + + outs = [] + outs.append(body_dict[body_name_list[0]]) + + # resize + for i in range(1, len(body_dict)): + resized = self.resize_input_tensor(body_dict[body_name_list[i]], + outs[0], 2**i) + outs.append(resized) + + # concat + out = fluid.layers.concat(outs, axis=1) + + # reduction + out = fluid.layers.conv2d( + input=out, + num_filters=self.num_chan, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr(name='hrfpn_reduction_weights'), + bias_attr=False) + + # conv + outs = [out] + for i in range(1, num_out): + outs.append( + self.pooling( + out, size=2**i, stride=2**i, + pooling_type=self.pooling_type)) + outputs = [] + + for i in range(num_out): + conv_name = "shared_fpn_conv" if self.share_conv else "shared_fpn_conv_" + str( + i) + conv = fluid.layers.conv2d( + input=outs[i], + num_filters=self.num_chan, + filter_size=3, + stride=1, + padding=1, + param_attr=ParamAttr(name=conv_name + "_weights"), + bias_attr=False) + outputs.append(conv) + + for idx in range(0, num_out - len(body_name_list)): + body_name_list.append("fpn_res5_sum_subsampled_{}x".format(2**(idx + + 1))) + + outputs = outputs[::-1] + body_name_list = body_name_list[::-1] + + res_dict = OrderedDict([(body_name_list[k], outputs[k]) + for k in range(len(body_name_list))]) + return res_dict, self.spatial_scale + + def resize_input_tensor(self, body_input, ref_output, scale): + shape = fluid.layers.shape(ref_output) + shape_hw = fluid.layers.slice(shape, axes=[0], starts=[2], ends=[4]) + out_shape_ = shape_hw + out_shape = fluid.layers.cast(out_shape_, dtype='int32') + out_shape.stop_gradient = True + body_output = fluid.layers.resize_bilinear( + body_input, scale=scale, 
out_shape=out_shape) + return body_output + + def pooling(self, input, size, stride, pooling_type): + pool = fluid.layers.pool2d( + input=input, + pool_size=size, + pool_stride=stride, + pool_type=pooling_type) + return pool diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/hrnet.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/hrnet.py new file mode 100755 index 000000000..2849d942a --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/hrnet.py @@ -0,0 +1,431 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.framework import Variable +from paddle.fluid.regularizer import L2Decay + +from ppdet.core.workspace import register, serializable +from numbers import Integral +from paddle.fluid.initializer import MSRA +import math + +from .name_adapter import NameAdapter + +__all__ = ['HRNet'] + + +@register +@serializable +class HRNet(object): + """ + HRNet, see https://arxiv.org/abs/1908.07919 + Args: + depth (int): ResNet depth, should be 18, 34, 50, 101, 152. 
+ freeze_at (int): freeze the backbone at which stage + norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel' + freeze_norm (bool): freeze normalization layers + norm_decay (float): weight decay for normalization layer weights + variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently + feature_maps (list): index of stages whose feature maps are returned + """ + + def __init__(self, + width=40, + has_se=False, + freeze_at=2, + norm_type='bn', + freeze_norm=True, + norm_decay=0., + feature_maps=[2, 3, 4, 5]): + super(HRNet, self).__init__() + + if isinstance(feature_maps, Integral): + feature_maps = [feature_maps] + + assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4" + assert len(feature_maps) > 0, "need one or more feature maps" + assert norm_type in ['bn', 'sync_bn'] + + self.width = width + self.has_se = has_se + self.channels = { + 18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]], + 30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]], + 32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]], + 40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]], + 44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]], + 48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]], + 60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]], + 64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]], + } + + self.freeze_at = freeze_at + self.norm_type = norm_type + self.norm_decay = norm_decay + self.freeze_norm = freeze_norm + self._model_type = 'HRNet' + self.feature_maps = feature_maps + self.end_points = [] + return + + def net(self, input, class_dim=1000): + width = self.width + channels_2, channels_3, channels_4 = self.channels[width] + num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3 + + x = self.conv_bn_layer( + input=input, + filter_size=3, + num_filters=64, + stride=2, + if_act=True, + name='layer1_1') + x = self.conv_bn_layer( + input=x, + filter_size=3, + num_filters=64, + stride=2, + if_act=True, + name='layer1_2') + + la1 = self.layer1(x, 
name='layer2') + tr1 = self.transition_layer([la1], [256], channels_2, name='tr1') + st2 = self.stage(tr1, num_modules_2, channels_2, name='st2') + tr2 = self.transition_layer(st2, channels_2, channels_3, name='tr2') + st3 = self.stage(tr2, num_modules_3, channels_3, name='st3') + tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3') + st4 = self.stage(tr3, num_modules_4, channels_4, name='st4') + + self.end_points = st4 + return st4[-1] + + def layer1(self, input, name=None): + conv = input + for i in range(4): + conv = self.bottleneck_block( + conv, + num_filters=64, + downsample=True if i == 0 else False, + name=name + '_' + str(i + 1)) + return conv + + def transition_layer(self, x, in_channels, out_channels, name=None): + num_in = len(in_channels) + num_out = len(out_channels) + out = [] + for i in range(num_out): + if i < num_in: + if in_channels[i] != out_channels[i]: + residual = self.conv_bn_layer( + x[i], + filter_size=3, + num_filters=out_channels[i], + name=name + '_layer_' + str(i + 1)) + out.append(residual) + else: + out.append(x[i]) + else: + residual = self.conv_bn_layer( + x[-1], + filter_size=3, + num_filters=out_channels[i], + stride=2, + name=name + '_layer_' + str(i + 1)) + out.append(residual) + return out + + def branches(self, x, block_num, channels, name=None): + out = [] + for i in range(len(channels)): + residual = x[i] + for j in range(block_num): + residual = self.basic_block( + residual, + channels[i], + name=name + '_branch_layer_' + str(i + 1) + '_' + + str(j + 1)) + out.append(residual) + return out + + def fuse_layers(self, x, channels, multi_scale_output=True, name=None): + out = [] + for i in range(len(channels) if multi_scale_output else 1): + residual = x[i] + for j in range(len(channels)): + if j > i: + y = self.conv_bn_layer( + x[j], + filter_size=1, + num_filters=channels[i], + if_act=False, + name=name + '_layer_' + str(i + 1) + '_' + str(j + 1)) + y = fluid.layers.resize_nearest(input=y, scale=2**(j - i)) + 
residual = fluid.layers.elementwise_add( + x=residual, y=y, act=None) + elif j < i: + y = x[j] + for k in range(i - j): + if k == i - j - 1: + y = self.conv_bn_layer( + y, + filter_size=3, + num_filters=channels[i], + stride=2, + if_act=False, + name=name + '_layer_' + str(i + 1) + '_' + + str(j + 1) + '_' + str(k + 1)) + else: + y = self.conv_bn_layer( + y, + filter_size=3, + num_filters=channels[j], + stride=2, + name=name + '_layer_' + str(i + 1) + '_' + + str(j + 1) + '_' + str(k + 1)) + residual = fluid.layers.elementwise_add( + x=residual, y=y, act=None) + + residual = fluid.layers.relu(residual) + out.append(residual) + return out + + def high_resolution_module(self, + x, + channels, + multi_scale_output=True, + name=None): + residual = self.branches(x, 4, channels, name=name) + out = self.fuse_layers( + residual, + channels, + multi_scale_output=multi_scale_output, + name=name) + return out + + def stage(self, + x, + num_modules, + channels, + multi_scale_output=True, + name=None): + out = x + for i in range(num_modules): + if i == num_modules - 1 and multi_scale_output == False: + out = self.high_resolution_module( + out, + channels, + multi_scale_output=False, + name=name + '_' + str(i + 1)) + else: + out = self.high_resolution_module( + out, channels, name=name + '_' + str(i + 1)) + + return out + + def last_cls_out(self, x, name=None): + out = [] + num_filters_list = [128, 256, 512, 1024] + for i in range(len(x)): + out.append( + self.conv_bn_layer( + input=x[i], + filter_size=1, + num_filters=num_filters_list[i], + name=name + 'conv_' + str(i + 1))) + return out + + def basic_block(self, + input, + num_filters, + stride=1, + downsample=False, + name=None): + residual = input + conv = self.conv_bn_layer( + input=input, + filter_size=3, + num_filters=num_filters, + stride=stride, + name=name + '_conv1') + conv = self.conv_bn_layer( + input=conv, + filter_size=3, + num_filters=num_filters, + if_act=False, + name=name + '_conv2') + if downsample: + 
residual = self.conv_bn_layer( + input=input, + filter_size=1, + num_filters=num_filters, + if_act=False, + name=name + '_downsample') + if self.has_se: + conv = self.squeeze_excitation( + input=conv, + num_channels=num_filters, + reduction_ratio=16, + name='fc' + name) + return fluid.layers.elementwise_add(x=residual, y=conv, act='relu') + + def bottleneck_block(self, + input, + num_filters, + stride=1, + downsample=False, + name=None): + residual = input + conv = self.conv_bn_layer( + input=input, + filter_size=1, + num_filters=num_filters, + name=name + '_conv1') + conv = self.conv_bn_layer( + input=conv, + filter_size=3, + num_filters=num_filters, + stride=stride, + name=name + '_conv2') + conv = self.conv_bn_layer( + input=conv, + filter_size=1, + num_filters=num_filters * 4, + if_act=False, + name=name + '_conv3') + if downsample: + residual = self.conv_bn_layer( + input=input, + filter_size=1, + num_filters=num_filters * 4, + if_act=False, + name=name + '_downsample') + if self.has_se: + conv = self.squeeze_excitation( + input=conv, + num_channels=num_filters * 4, + reduction_ratio=16, + name='fc' + name) + return fluid.layers.elementwise_add(x=residual, y=conv, act='relu') + + def squeeze_excitation(self, + input, + num_channels, + reduction_ratio, + name=None): + pool = fluid.layers.pool2d( + input=input, pool_size=0, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + squeeze = fluid.layers.fc( + input=pool, + size=num_channels / reduction_ratio, + act='relu', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_sqz_weights'), + bias_attr=ParamAttr(name=name + '_sqz_offset')) + stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) + excitation = fluid.layers.fc( + input=squeeze, + size=num_channels, + act='sigmoid', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_exc_weights'), + bias_attr=ParamAttr(name=name 
+ '_exc_offset')) + scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) + return scale + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride=1, + padding=1, + num_groups=1, + if_act=True, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=num_groups, + act=None, + param_attr=ParamAttr( + initializer=MSRA(), name=name + '_weights'), + bias_attr=False) + bn_name = name + '_bn' + bn = self._bn(input=conv, bn_name=bn_name) + if if_act: + bn = fluid.layers.relu(bn) + return bn + + def _bn(self, input, act=None, bn_name=None): + norm_lr = 0. if self.freeze_norm else 1. + norm_decay = self.norm_decay + pattr = ParamAttr( + name=bn_name + '_scale', + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay)) + battr = ParamAttr( + name=bn_name + '_offset', + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay)) + + global_stats = True if self.freeze_norm else False + out = fluid.layers.batch_norm( + input=input, + act=act, + name=bn_name + '.output.1', + param_attr=pattr, + bias_attr=battr, + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + use_global_stats=global_stats) + scale = fluid.framework._get_var(pattr.name) + bias = fluid.framework._get_var(battr.name) + if self.freeze_norm: + scale.stop_gradient = True + bias.stop_gradient = True + return out + + def __call__(self, input): + assert isinstance(input, Variable) + assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \ + "feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps) + + res_endpoints = [] + + res = input + feature_maps = self.feature_maps + self.net(input) + + for i in feature_maps: + res = self.end_points[i - 2] + if i in self.feature_maps: + res_endpoints.append(res) + if self.freeze_at >= i: + res.stop_gradient = True + + return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat) + for 
idx, feat in enumerate(res_endpoints)]) diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/mobilenet.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/mobilenet.py new file mode 100755 index 000000000..b2a0eb45f --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/mobilenet.py @@ -0,0 +1,218 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register + +__all__ = ['MobileNet'] + + +@register +class MobileNet(object): + """ + MobileNet v1, see https://arxiv.org/abs/1704.04861 + + Args: + norm_type (str): normalization type, 'bn' and 'sync_bn' are supported + norm_decay (float): weight decay for normalization layer weights + conv_decay (float): weight decay for convolution layer weights. 
+ conv_group_scale (int): scaling factor for convolution groups + with_extra_blocks (bool): if extra blocks should be added + extra_block_filters (list): number of filter for each extra block + """ + __shared__ = ['norm_type', 'weight_prefix_name'] + + def __init__(self, + norm_type='bn', + norm_decay=0., + conv_decay=0., + conv_group_scale=1, + conv_learning_rate=1.0, + with_extra_blocks=False, + extra_block_filters=[[256, 512], [128, 256], [128, 256], + [64, 128]], + weight_prefix_name=''): + self.norm_type = norm_type + self.norm_decay = norm_decay + self.conv_decay = conv_decay + self.conv_group_scale = conv_group_scale + self.conv_learning_rate = conv_learning_rate + self.with_extra_blocks = with_extra_blocks + self.extra_block_filters = extra_block_filters + self.prefix_name = weight_prefix_name + + def _conv_norm(self, + input, + filter_size, + num_filters, + stride, + padding, + num_groups=1, + act='relu', + use_cudnn=True, + name=None): + parameter_attr = ParamAttr( + learning_rate=self.conv_learning_rate, + initializer=fluid.initializer.MSRA(), + regularizer=L2Decay(self.conv_decay), + name=name + "_weights") + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=parameter_attr, + bias_attr=False) + + bn_name = name + "_bn" + norm_decay = self.norm_decay + bn_param_attr = ParamAttr( + regularizer=L2Decay(norm_decay), name=bn_name + '_scale') + bn_bias_attr = ParamAttr( + regularizer=L2Decay(norm_decay), name=bn_name + '_offset') + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=bn_param_attr, + bias_attr=bn_bias_attr, + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def depthwise_separable(self, + input, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + name=None): + mixed_precision_enabled = mixed_precision_global_state() is not 
None + depthwise_conv = self._conv_norm( + input=input, + filter_size=3, + num_filters=int(num_filters1 * scale), + stride=stride, + padding=1, + num_groups=int(num_groups * scale), + use_cudnn=mixed_precision_enabled, + name=name + "_dw") + + pointwise_conv = self._conv_norm( + input=depthwise_conv, + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0, + name=name + "_sep") + return pointwise_conv + + def _extra_block(self, + input, + num_filters1, + num_filters2, + num_groups, + stride, + name=None): + pointwise_conv = self._conv_norm( + input=input, + filter_size=1, + num_filters=int(num_filters1), + stride=1, + num_groups=int(num_groups), + padding=0, + act='relu6', + name=name + "_extra1") + normal_conv = self._conv_norm( + input=pointwise_conv, + filter_size=3, + num_filters=int(num_filters2), + stride=2, + num_groups=int(num_groups), + padding=1, + act='relu6', + name=name + "_extra2") + return normal_conv + + def __call__(self, input): + scale = self.conv_group_scale + + blocks = [] + # input 1/1 + out = self._conv_norm( + input, 3, int(32 * scale), 2, 1, name=self.prefix_name + "conv1") + # 1/2 + out = self.depthwise_separable( + out, 32, 64, 32, 1, scale, name=self.prefix_name + "conv2_1") + out = self.depthwise_separable( + out, 64, 128, 64, 2, scale, name=self.prefix_name + "conv2_2") + # 1/4 + out = self.depthwise_separable( + out, 128, 128, 128, 1, scale, name=self.prefix_name + "conv3_1") + out = self.depthwise_separable( + out, 128, 256, 128, 2, scale, name=self.prefix_name + "conv3_2") + # 1/8 + blocks.append(out) + out = self.depthwise_separable( + out, 256, 256, 256, 1, scale, name=self.prefix_name + "conv4_1") + out = self.depthwise_separable( + out, 256, 512, 256, 2, scale, name=self.prefix_name + "conv4_2") + # 1/16 + blocks.append(out) + for i in range(5): + out = self.depthwise_separable( + out, + 512, + 512, + 512, + 1, + scale, + name=self.prefix_name + "conv5_" + str(i + 1)) + module11 = out + + out = 
self.depthwise_separable( + out, 512, 1024, 512, 2, scale, name=self.prefix_name + "conv5_6") + # 1/32 + out = self.depthwise_separable( + out, 1024, 1024, 1024, 1, scale, name=self.prefix_name + "conv6") + module13 = out + blocks.append(out) + if not self.with_extra_blocks: + return blocks + + num_filters = self.extra_block_filters + module14 = self._extra_block(module13, num_filters[0][0], + num_filters[0][1], 1, 2, + self.prefix_name + "conv7_1") + module15 = self._extra_block(module14, num_filters[1][0], + num_filters[1][1], 1, 2, + self.prefix_name + "conv7_2") + module16 = self._extra_block(module15, num_filters[2][0], + num_filters[2][1], 1, 2, + self.prefix_name + "conv7_3") + module17 = self._extra_block(module16, num_filters[3][0], + num_filters[3][1], 1, 2, + self.prefix_name + "conv7_4") + return module11, module13, module14, module15, module16, module17 diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/mobilenet_v3.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/mobilenet_v3.py new file mode 100755 index 000000000..d4727449a --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/mobilenet_v3.py @@ -0,0 +1,567 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +import math +import numpy as np +from collections import OrderedDict + +from ppdet.core.workspace import register +from numbers import Integral + +__all__ = ['MobileNetV3', 'MobileNetV3RCNN'] + + +@register +class MobileNetV3(object): + """ + MobileNet v3, see https://arxiv.org/abs/1905.02244 + Args: + scale (float): scaling factor for convolution groups proportion of mobilenet_v3. + model_name (str): There are two modes, small and large. + norm_type (str): normalization type, 'bn' and 'sync_bn' are supported. + norm_decay (float): weight decay for normalization layer weights. + conv_decay (float): weight decay for convolution layer weights. + feature_maps (list): index of stages whose feature maps are returned. + extra_block_filters (list): number of filter for each extra block. + lr_mult_list (list): learning rate ratio of different blocks, lower learning rate ratio + is need for pretrained model got using distillation(default as + [1.0, 1.0, 1.0, 1.0, 1.0]). + freeze_norm (bool): freeze normalization layers. + multiplier (float): The multiplier by which to reduce the convolution expansion and + number of channels. 
+ """ + __shared__ = ['norm_type'] + + def __init__( + self, + scale=1.0, + model_name='small', + feature_maps=[5, 6, 7, 8, 9, 10], + conv_decay=0.0, + norm_type='bn', + norm_decay=0.0, + extra_block_filters=[[256, 512], [128, 256], [128, 256], [64, 128]], + lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0], + freeze_norm=False, + multiplier=1.0): + if isinstance(feature_maps, Integral): + feature_maps = [feature_maps] + + if norm_type == 'sync_bn' and freeze_norm: + raise ValueError( + "The norm_type should not be sync_bn when freeze_norm is True") + self.scale = scale + self.model_name = model_name + self.feature_maps = feature_maps + self.extra_block_filters = extra_block_filters + self.conv_decay = conv_decay + self.norm_decay = norm_decay + self.inplanes = 16 + self.end_points = [] + self.block_stride = 0 + + self.lr_mult_list = lr_mult_list + self.freeze_norm = freeze_norm + self.norm_type = norm_type + self.curr_stage = 0 + + if model_name == "large": + self.cfg = [ + # kernel_size, expand, channel, se_block, act_mode, stride + [3, 16, 16, False, 'relu', 1], + [3, 64, 24, False, 'relu', 2], + [3, 72, 24, False, 'relu', 1], + [5, 72, 40, True, 'relu', 2], + [5, 120, 40, True, 'relu', 1], + [5, 120, 40, True, 'relu', 1], + [3, 240, 80, False, 'hard_swish', 2], + [3, 200, 80, False, 'hard_swish', 1], + [3, 184, 80, False, 'hard_swish', 1], + [3, 184, 80, False, 'hard_swish', 1], + [3, 480, 112, True, 'hard_swish', 1], + [3, 672, 112, True, 'hard_swish', 1], + [5, 672, 160, True, 'hard_swish', 2], + [5, 960, 160, True, 'hard_swish', 1], + [5, 960, 160, True, 'hard_swish', 1], + ] + self.cls_ch_squeeze = 960 + self.cls_ch_expand = 1280 + elif model_name == "small": + self.cfg = [ + # kernel_size, expand, channel, se_block, act_mode, stride + [3, 16, 16, True, 'relu', 2], + [3, 72, 24, False, 'relu', 2], + [3, 88, 24, False, 'relu', 1], + [5, 96, 40, True, 'hard_swish', 2], + [5, 240, 40, True, 'hard_swish', 1], + [5, 240, 40, True, 'hard_swish', 1], + [5, 120, 48, True, 
'hard_swish', 1], + [5, 144, 48, True, 'hard_swish', 1], + [5, 288, 96, True, 'hard_swish', 2], + [5, 576, 96, True, 'hard_swish', 1], + [5, 576, 96, True, 'hard_swish', 1], + ] + self.cls_ch_squeeze = 576 + self.cls_ch_expand = 1280 + else: + raise NotImplementedError + + if multiplier != 1.0: + self.cfg[-3][2] = int(self.cfg[-3][2] * multiplier) + self.cfg[-2][1] = int(self.cfg[-2][1] * multiplier) + self.cfg[-2][2] = int(self.cfg[-2][2] * multiplier) + self.cfg[-1][1] = int(self.cfg[-1][1] * multiplier) + self.cfg[-1][2] = int(self.cfg[-1][2] * multiplier) + + def _conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + num_groups=1, + if_act=True, + act=None, + name=None, + use_cudnn=True): + lr_idx = self.curr_stage // 3 + lr_idx = min(lr_idx, len(self.lr_mult_list) - 1) + lr_mult = self.lr_mult_list[lr_idx] + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr( + name=name + '_weights', + learning_rate=lr_mult, + regularizer=L2Decay(self.conv_decay)), + bias_attr=False) + bn_name = name + '_bn' + bn = self._bn(conv, bn_name=bn_name) + + if if_act: + if act == 'relu': + bn = fluid.layers.relu(bn) + elif act == 'hard_swish': + bn = self._hard_swish(bn) + elif act == 'relu6': + bn = fluid.layers.relu6(bn) + return bn + + def _bn(self, input, act=None, bn_name=None): + lr_idx = self.curr_stage // 3 + lr_idx = min(lr_idx, len(self.lr_mult_list) - 1) + lr_mult = self.lr_mult_list[lr_idx] + norm_lr = 0. 
if self.freeze_norm else lr_mult + norm_decay = self.norm_decay + pattr = ParamAttr( + name=bn_name + '_scale', + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay)) + battr = ParamAttr( + name=bn_name + '_offset', + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay)) + + conv = input + + if self.norm_type in ['bn', 'sync_bn']: + global_stats = True if self.freeze_norm else False + out = fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name + '.output.1', + param_attr=pattr, + bias_attr=battr, + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + use_global_stats=global_stats) + scale = fluid.framework._get_var(pattr.name) + bias = fluid.framework._get_var(battr.name) + elif self.norm_type == 'affine_channel': + scale = fluid.layers.create_parameter( + shape=[conv.shape[1]], + dtype=conv.dtype, + attr=pattr, + default_initializer=fluid.initializer.Constant(1.)) + bias = fluid.layers.create_parameter( + shape=[conv.shape[1]], + dtype=conv.dtype, + attr=battr, + default_initializer=fluid.initializer.Constant(0.)) + out = fluid.layers.affine_channel( + x=conv, scale=scale, bias=bias, act=act) + + if self.freeze_norm: + scale.stop_gradient = True + bias.stop_gradient = True + + return out + + def _hard_swish(self, x): + return x * fluid.layers.relu6(x + 3) / 6. 
+ + def _se_block(self, input, num_out_filter, ratio=4, name=None): + lr_idx = self.curr_stage // 3 + lr_idx = min(lr_idx, len(self.lr_mult_list) - 1) + lr_mult = self.lr_mult_list[lr_idx] + + num_mid_filter = int(num_out_filter // ratio) + pool = fluid.layers.pool2d( + input=input, pool_type='avg', global_pooling=True, use_cudnn=False) + conv1 = fluid.layers.conv2d( + input=pool, + filter_size=1, + num_filters=num_mid_filter, + act='relu', + param_attr=ParamAttr( + name=name + '_1_weights', + learning_rate=lr_mult, + regularizer=L2Decay(self.conv_decay)), + bias_attr=ParamAttr( + name=name + '_1_offset', + learning_rate=lr_mult, + regularizer=L2Decay(self.conv_decay))) + conv2 = fluid.layers.conv2d( + input=conv1, + filter_size=1, + num_filters=num_out_filter, + act='hard_sigmoid', + param_attr=ParamAttr( + name=name + '_2_weights', + learning_rate=lr_mult, + regularizer=L2Decay(self.conv_decay)), + bias_attr=ParamAttr( + name=name + '_2_offset', + learning_rate=lr_mult, + regularizer=L2Decay(self.conv_decay))) + + scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0) + return scale + + def _residual_unit(self, + input, + num_in_filter, + num_mid_filter, + num_out_filter, + stride, + filter_size, + act=None, + use_se=False, + name=None): + input_data = input + conv0 = self._conv_bn_layer( + input=input, + filter_size=1, + num_filters=num_mid_filter, + stride=1, + padding=0, + if_act=True, + act=act, + name=name + '_expand') + + if self.block_stride == 4 and stride == 2: + self.block_stride += 1 + if self.block_stride in self.feature_maps: + self.end_points.append(conv0) + + with fluid.name_scope('res_conv1'): + conv1 = self._conv_bn_layer( + input=conv0, + filter_size=filter_size, + num_filters=num_mid_filter, + stride=stride, + padding=int((filter_size - 1) // 2), + if_act=True, + act=act, + num_groups=num_mid_filter, + use_cudnn=False, + name=name + '_depthwise') + + if use_se: + with fluid.name_scope('se_block'): + conv1 = self._se_block( + 
input=conv1, + num_out_filter=num_mid_filter, + name=name + '_se') + + conv2 = self._conv_bn_layer( + input=conv1, + filter_size=1, + num_filters=num_out_filter, + stride=1, + padding=0, + if_act=False, + name=name + '_linear') + if num_in_filter != num_out_filter or stride != 1: + return conv2 + else: + return fluid.layers.elementwise_add(x=input_data, y=conv2, act=None) + + def _extra_block_dw(self, + input, + num_filters1, + num_filters2, + stride, + name=None): + pointwise_conv = self._conv_bn_layer( + input=input, + filter_size=1, + num_filters=int(num_filters1), + stride=1, + padding="SAME", + act='relu6', + name=name + "_extra1") + depthwise_conv = self._conv_bn_layer( + input=pointwise_conv, + filter_size=3, + num_filters=int(num_filters2), + stride=stride, + padding="SAME", + num_groups=int(num_filters1), + act='relu6', + use_cudnn=False, + name=name + "_extra2_dw") + normal_conv = self._conv_bn_layer( + input=depthwise_conv, + filter_size=1, + num_filters=int(num_filters2), + stride=1, + padding="SAME", + act='relu6', + name=name + "_extra2_sep") + return normal_conv + + def _make_divisible(self, v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + def __call__(self, input): + scale = self.scale + inplanes = self.inplanes + cfg = self.cfg + blocks = [] + + #conv1 + conv = self._conv_bn_layer( + input, + filter_size=3, + num_filters=self._make_divisible(inplanes * scale), + stride=2, + padding=1, + num_groups=1, + if_act=True, + act='hard_swish', + name='conv1') + i = 0 + inplanes = self._make_divisible(inplanes * scale) + for layer_cfg in cfg: + if layer_cfg[5] == 2: + self.block_stride += 1 + if self.block_stride in self.feature_maps: + self.end_points.append(conv) + + conv = self._residual_unit( + input=conv, + num_in_filter=inplanes, + num_mid_filter=self._make_divisible(scale * layer_cfg[1]), + 
num_out_filter=self._make_divisible(scale * layer_cfg[2]), + act=layer_cfg[4], + stride=layer_cfg[5], + filter_size=layer_cfg[0], + use_se=layer_cfg[3], + name='conv' + str(i + 2)) + inplanes = self._make_divisible(scale * layer_cfg[2]) + i += 1 + self.curr_stage += 1 + self.block_stride += 1 + if self.block_stride in self.feature_maps: + self.end_points.append(conv) + + # extra block + # check whether conv_extra is needed + if self.block_stride < max(self.feature_maps): + conv_extra = self._conv_bn_layer( + conv, + filter_size=1, + num_filters=self._make_divisible(scale * cfg[-1][1]), + stride=1, + padding="SAME", + num_groups=1, + if_act=True, + act='hard_swish', + name='conv' + str(i + 2)) + self.block_stride += 1 + if self.block_stride in self.feature_maps: + self.end_points.append(conv_extra) + i += 1 + for block_filter in self.extra_block_filters: + conv_extra = self._extra_block_dw(conv_extra, block_filter[0], + block_filter[1], 2, + 'conv' + str(i + 2)) + self.block_stride += 1 + if self.block_stride in self.feature_maps: + self.end_points.append(conv_extra) + i += 1 + + return OrderedDict([('mbv3_{}'.format(idx), feat) + for idx, feat in enumerate(self.end_points)]) + + +@register +class MobileNetV3RCNN(MobileNetV3): + def __init__(self, + scale=1.0, + model_name='large', + conv_decay=0.0, + norm_type='bn', + norm_decay=0.0, + freeze_norm=True, + feature_maps=[2, 3, 4, 5], + lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]): + super(MobileNetV3RCNN, self).__init__( + scale=scale, + model_name=model_name, + conv_decay=conv_decay, + norm_type=norm_type, + norm_decay=norm_decay, + lr_mult_list=lr_mult_list, + feature_maps=feature_maps, + freeze_norm=freeze_norm) + self.curr_stage = 0 + self.block_stride = 1 + + def _residual_unit(self, + input, + num_in_filter, + num_mid_filter, + num_out_filter, + stride, + filter_size, + act=None, + use_se=False, + name=None): + input_data = input + conv0 = self._conv_bn_layer( + input=input, + filter_size=1, + 
num_filters=num_mid_filter, + stride=1, + padding=0, + if_act=True, + act=act, + name=name + '_expand') + + feature_level = int(np.log2(self.block_stride)) + if feature_level in self.feature_maps and stride == 2: + self.end_points.append(conv0) + + conv1 = self._conv_bn_layer( + input=conv0, + filter_size=filter_size, + num_filters=num_mid_filter, + stride=stride, + padding=int((filter_size - 1) // 2), + if_act=True, + act=act, + num_groups=num_mid_filter, + use_cudnn=False, + name=name + '_depthwise') + + if use_se: + conv1 = self._se_block( + input=conv1, num_out_filter=num_mid_filter, name=name + '_se') + + conv2 = self._conv_bn_layer( + input=conv1, + filter_size=1, + num_filters=num_out_filter, + stride=1, + padding=0, + if_act=False, + name=name + '_linear') + if num_in_filter != num_out_filter or stride != 1: + return conv2 + else: + return fluid.layers.elementwise_add(x=input_data, y=conv2, act=None) + + def __call__(self, input): + scale = self.scale + inplanes = self.inplanes + cfg = self.cfg + #conv1 + conv = self._conv_bn_layer( + input, + filter_size=3, + num_filters=self._make_divisible(inplanes * scale), + stride=2, + padding=1, + num_groups=1, + if_act=True, + act='hard_swish', + name='conv1') + i = 0 + inplanes = self._make_divisible(inplanes * scale) + for layer_cfg in cfg: + self.block_stride *= layer_cfg[5] + conv = self._residual_unit( + input=conv, + num_in_filter=inplanes, + num_mid_filter=self._make_divisible(scale * layer_cfg[1]), + num_out_filter=self._make_divisible(scale * layer_cfg[2]), + act=layer_cfg[4], + stride=layer_cfg[5], + filter_size=layer_cfg[0], + use_se=layer_cfg[3], + name='conv' + str(i + 2)) + inplanes = self._make_divisible(scale * layer_cfg[2]) + i += 1 + self.curr_stage += 1 + + if np.max(self.feature_maps) >= 5: + conv = self._conv_bn_layer( + input=conv, + filter_size=1, + num_filters=self._make_divisible(scale * cfg[-1][1]), + stride=1, + padding=0, + num_groups=1, + if_act=True, + act='hard_swish', + 
name='conv_last') + self.end_points.append(conv) + i += 1 + + res = OrderedDict([('mv3_{}'.format(idx), self.end_points[idx]) + for idx, feat_idx in enumerate(self.feature_maps)]) + return res diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/name_adapter.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/name_adapter.py new file mode 100755 index 000000000..2cb16d0c9 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/name_adapter.py @@ -0,0 +1,73 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class NameAdapter(object):
    """Map backbone layer names onto the naming scheme of the ImageNet
    pretrained weight files, so weights can be loaded by parameter name."""

    def __init__(self, model):
        super(NameAdapter, self).__init__()
        self.model = model

    @property
    def model_type(self):
        # Backbones advertise their family via a private attribute.
        return getattr(self.model, '_model_type', '')

    @property
    def variant(self):
        return getattr(self.model, 'variant', '')

    def fix_conv_norm_name(self, name):
        """Return the BN parameter prefix matching conv `name`."""
        # SE-ResNeXt checkpoints store BN params as "<conv>_bn"; the plain
        # ResNet family uses "bn_conv1" for the stem and "bn<suffix>"
        # elsewhere (the naming rule of the pretrained weights).
        if self.model_type == 'SEResNeXt':
            return name + "_bn"
        return "bn_" + name if name == "conv1" else "bn" + name[3:]

    def fix_shortcut_name(self, name):
        """Return the projection-shortcut conv name."""
        if self.model_type == 'SEResNeXt':
            return 'conv' + name + '_prj'
        return name

    def fix_bottleneck_name(self, name):
        """Return (conv1, conv2, conv3, shortcut) names for a bottleneck."""
        if self.model_type == 'SEResNeXt':
            return ('conv' + name + '_x1', 'conv' + name + '_x2',
                    'conv' + name + '_x3', name)
        return (name + "_branch2a", name + "_branch2b", name + "_branch2c",
                name + "_branch1")

    def fix_layer_warp_name(self, stage_num, count, i):
        """Return the name of block `i` in stage `stage_num`."""
        base = 'res' + str(stage_num)
        if count > 10 and stage_num == 4:
            # Very deep stage 4 (e.g. ResNet-101) is named res4a, res4b1...
            conv_name = base + ("a" if i == 0 else "b" + str(i))
        else:
            conv_name = base + chr(ord("a") + i)
        if self.model_type == 'SEResNeXt':
            conv_name = str(stage_num + 2) + '_' + str(i + 1)
        return conv_name

    def fix_c1_stage_name(self):
        """Return the stem conv name."""
        if self.model_type == 'ResNeXt':
            return "res_conv1"
        return "conv1"
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid.initializer import ConstantInitializer


def space_nonlocal(input,
                   dim_in,
                   dim_out,
                   prefix,
                   dim_inner,
                   with_bias=False,
                   with_scale=True):
    """Build a spatial non-local block (https://arxiv.org/abs/1711.07971).

    Args:
        input (Variable): NCHW feature map.
        dim_in (int): input channel count (unused directly; kept for
            signature symmetry with add_space_nonlocal).
        dim_out (int): output channel count.
        prefix (str): parameter-name prefix.
        dim_inner (int): embedding channel count for theta/phi/g.
        with_bias (bool): add zero-initialized biases to the 1x1 convs.
        with_scale (bool): scale the affinity matrix by dim_inner**-0.5.

    Returns:
        Variable: feature map of `dim_out` channels, same spatial size.
    """
    theta = fluid.layers.conv2d(
        input=input,
        num_filters=dim_inner,
        filter_size=1,
        stride=1,
        padding=0,
        param_attr=ParamAttr(name=prefix + '_theta_w'),
        bias_attr=ParamAttr(
            name=prefix + '_theta_b', initializer=ConstantInitializer(value=0.))
        if with_bias else False)
    theta_shape = theta.shape
    # Runtime shape tensor: supports a dynamic batch dimension.
    theta_shape_op = fluid.layers.shape(theta)
    theta_shape_op.stop_gradient = True

    # Flatten spatial dims, e.g. (N, C, H, W) => (N, C, H*W), then put the
    # spatial axis first for the affinity matmul.
    theta = fluid.layers.reshape(theta, shape=(0, 0, -1))
    theta = fluid.layers.transpose(theta, [0, 2, 1])

    phi = fluid.layers.conv2d(
        input=input,
        num_filters=dim_inner,
        filter_size=1,
        stride=1,
        padding=0,
        param_attr=ParamAttr(name=prefix + '_phi_w'),
        bias_attr=ParamAttr(
            name=prefix + '_phi_b', initializer=ConstantInitializer(value=0.))
        if with_bias else False,
        name=prefix + '_phi')
    phi = fluid.layers.reshape(phi, [0, 0, -1])

    theta_phi = fluid.layers.matmul(theta, phi)

    g = fluid.layers.conv2d(
        input=input,
        num_filters=dim_inner,
        filter_size=1,
        stride=1,
        padding=0,
        param_attr=ParamAttr(name=prefix + '_g_w'),
        bias_attr=ParamAttr(
            name=prefix + '_g_b', initializer=ConstantInitializer(value=0.))
        if with_bias else False,
        name=prefix + '_g')
    g = fluid.layers.reshape(g, [0, 0, -1])

    # Scaled dot-product attention over spatial positions.
    if with_scale:
        theta_phi = fluid.layers.scale(theta_phi, scale=dim_inner**-.5)
    p = fluid.layers.softmax(theta_phi)

    # g's axis[2] corresponds to p's axis[2]:
    # g(N, C, HW_2) * p(N, HW_1, HW_2) => (N, C, HW_1)
    p = fluid.layers.transpose(p, [0, 2, 1])
    t = fluid.layers.matmul(g, p)

    # Reshape back to NCHW using the runtime shape pieces.
    n = fluid.layers.slice(theta_shape_op, axes=[0], starts=[0], ends=[1])
    h = fluid.layers.slice(theta_shape_op, axes=[0], starts=[2], ends=[3])
    w = fluid.layers.slice(theta_shape_op, axes=[0], starts=[3], ends=[4])
    ch = int(theta_shape[1])

    t_re = fluid.layers.reshape(t, shape=[n, ch, h, w])
    # Output projection is zero-initialized so the block starts as identity
    # when wrapped by add_space_nonlocal.
    blob_out = fluid.layers.conv2d(
        input=t_re,
        num_filters=dim_out,
        filter_size=1,
        stride=1,
        padding=0,
        param_attr=ParamAttr(
            name=prefix + '_out_w', initializer=ConstantInitializer(value=0.0)),
        bias_attr=ParamAttr(
            name=prefix + '_out_b', initializer=ConstantInitializer(value=0.0))
        if with_bias else False,
        name=prefix + '_out')
    return blob_out


def add_space_nonlocal(input,
                       dim_in,
                       dim_out,
                       prefix,
                       dim_inner,
                       with_bias=False,
                       with_scale=True):
    '''
    add_space_nonlocal:
        Non-local Neural Networks: see https://arxiv.org/abs/1711.07971

    Wraps space_nonlocal with a residual connection, so the block is an
    identity mapping at initialization (the output conv starts at zero).
    '''
    conv = space_nonlocal(
        input,
        dim_in,
        dim_out,
        prefix,
        dim_inner,
        with_bias=with_bias,
        with_scale=with_scale)
    output = input + conv
    return output
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.framework import Variable +from paddle.fluid.regularizer import L2Decay +from paddle.fluid.initializer import Constant + +from ppdet.core.workspace import register, serializable +from numbers import Integral + +from .nonlocal_helper import add_space_nonlocal +from .name_adapter import NameAdapter +from .resnet import ResNet, ResNetC5 + +__all__ = ['Res2Net', 'Res2NetC5'] + + +@register +@serializable +class Res2Net(ResNet): + """ + Res2Net, see https://arxiv.org/abs/1904.01169 + Args: + depth (int): Res2Net depth, should be 50, 101, 152, 200. 
+ width (int): Res2Net width + scales (int): Res2Net scale + freeze_at (int): freeze the backbone at which stage + norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel' + freeze_norm (bool): freeze normalization layers + norm_decay (float): weight decay for normalization layer weights + variant (str): Res2Net variant, supports 'a', 'b', 'c', 'd' currently + feature_maps (list): index of stages whose feature maps are returned + dcn_v2_stages (list): index of stages who select deformable conv v2 + nonlocal_stages (list): index of stages who select nonlocal networks + """ + __shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name'] + + def __init__( + self, + depth=50, + width=26, + scales=4, + freeze_at=2, + norm_type='bn', + freeze_norm=True, + norm_decay=0., + variant='b', + feature_maps=[2, 3, 4, 5], + dcn_v2_stages=[], + weight_prefix_name='', + nonlocal_stages=[], ): + super(Res2Net, self).__init__( + depth=depth, + freeze_at=freeze_at, + norm_type=norm_type, + freeze_norm=freeze_norm, + norm_decay=norm_decay, + variant=variant, + feature_maps=feature_maps, + dcn_v2_stages=dcn_v2_stages, + weight_prefix_name=weight_prefix_name, + nonlocal_stages=nonlocal_stages) + + assert depth >= 50, "just support depth>=50 in res2net, but got depth=".format( + depth) + # res2net config + self.scales = scales + self.width = width + basic_width = self.width * self.scales + self.num_filters1 = [basic_width * t for t in [1, 2, 4, 8]] + self.num_filters2 = [256 * t for t in [1, 2, 4, 8]] + self.num_filters = [64, 128, 384, 768] + + def bottleneck(self, + input, + num_filters1, + num_filters2, + stride, + is_first, + name, + dcn_v2=False): + conv0 = self._conv_norm( + input=input, + num_filters=num_filters1, + filter_size=1, + stride=1, + act='relu', + name=name + '_branch2a') + + xs = fluid.layers.split(conv0, self.scales, 1) + ys = [] + for s in range(self.scales - 1): + if s == 0 or stride == 2: + ys.append( + self._conv_norm( + input=xs[s], + 
num_filters=num_filters1 // self.scales, + stride=stride, + filter_size=3, + act='relu', + name=name + '_branch2b_' + str(s + 1), + dcn_v2=dcn_v2)) + else: + ys.append( + self._conv_norm( + input=xs[s] + ys[-1], + num_filters=num_filters1 // self.scales, + stride=stride, + filter_size=3, + act='relu', + name=name + '_branch2b_' + str(s + 1), + dcn_v2=dcn_v2)) + + if stride == 1: + ys.append(xs[-1]) + else: + ys.append( + fluid.layers.pool2d( + input=xs[-1], + pool_size=3, + pool_stride=stride, + pool_padding=1, + pool_type='avg')) + + conv1 = fluid.layers.concat(ys, axis=1) + conv2 = self._conv_norm( + input=conv1, + num_filters=num_filters2, + filter_size=1, + act=None, + name=name + "_branch2c") + + short = self._shortcut( + input, num_filters2, stride, is_first, name=name + "_branch1") + + return fluid.layers.elementwise_add( + x=short, y=conv2, act='relu', name=name + ".add.output.5") + + def layer_warp(self, input, stage_num): + """ + Args: + input (Variable): input variable. + stage_num (int): the stage number, should be 2, 3, 4, 5 + + Returns: + The last variable in endpoint-th stage. 
+ """ + assert stage_num in [2, 3, 4, 5] + + stages, block_func = self.depth_cfg[self.depth] + count = stages[stage_num - 2] + + ch_out = self.stage_filters[stage_num - 2] + is_first = False if stage_num != 2 else True + dcn_v2 = True if stage_num in self.dcn_v2_stages else False + + num_filters1 = self.num_filters1[stage_num - 2] + num_filters2 = self.num_filters2[stage_num - 2] + + nonlocal_mod = 1000 + if stage_num in self.nonlocal_stages: + nonlocal_mod = self.nonlocal_mod_cfg[ + self.depth] if stage_num == 4 else 2 + + # Make the layer name and parameter name consistent + # with ImageNet pre-trained model + conv = input + for i in range(count): + conv_name = self.na.fix_layer_warp_name(stage_num, count, i) + if self.depth < 50: + is_first = True if i == 0 and stage_num == 2 else False + conv = block_func( + input=conv, + num_filters1=num_filters1, + num_filters2=num_filters2, + stride=2 if i == 0 and stage_num != 2 else 1, + is_first=is_first, + name=conv_name, + dcn_v2=dcn_v2) + + # add non local model + dim_in = conv.shape[1] + nonlocal_name = "nonlocal_conv{}".format(stage_num) + if i % nonlocal_mod == nonlocal_mod - 1: + conv = add_space_nonlocal(conv, dim_in, dim_in, + nonlocal_name + '_{}'.format(i), + int(dim_in / 2)) + return conv + + +@register +@serializable +class Res2NetC5(Res2Net): + __doc__ = Res2Net.__doc__ + + def __init__(self, + depth=50, + width=26, + scales=4, + freeze_at=2, + norm_type='bn', + freeze_norm=True, + norm_decay=0., + variant='b', + feature_maps=[5], + weight_prefix_name=''): + super(Res2NetC5, self).__init__(depth, width, scales, freeze_at, + norm_type, freeze_norm, norm_decay, + variant, feature_maps) + self.severed_head = True diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/resnet.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/resnet.py new file mode 100755 index 000000000..1b5d78bf0 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/resnet.py @@ -0,0 +1,503 
@@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.framework import Variable +from paddle.fluid.regularizer import L2Decay +from paddle.fluid.initializer import Constant + +from ppdet.core.workspace import register, serializable +from numbers import Integral + +from .nonlocal_helper import add_space_nonlocal +from .gc_block import add_gc_block +from .name_adapter import NameAdapter + +__all__ = ['ResNet', 'ResNetC5'] + + +@register +@serializable +class ResNet(object): + """ + Residual Network, see https://arxiv.org/abs/1512.03385 + Args: + depth (int): ResNet depth, should be 18, 34, 50, 101, 152. 
+ freeze_at (int): freeze the backbone at which stage + norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel' + freeze_norm (bool): freeze normalization layers + norm_decay (float): weight decay for normalization layer weights + variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently + feature_maps (list): index of stages whose feature maps are returned + dcn_v2_stages (list): index of stages who select deformable conv v2 + nonlocal_stages (list): index of stages who select nonlocal networks + gcb_stages (list): index of stages who select gc blocks + gcb_params (dict): gc blocks config, includes ratio(default as 1.0/16), + pooling_type(default as "att") and + fusion_types(default as ['channel_add']) + lr_mult_list (list): learning rate ratio of different resnet stages(2,3,4,5), + lower learning rate ratio is need for pretrained model + got using distillation(default as [1.0, 1.0, 1.0, 1.0]). + """ + __shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name'] + + def __init__(self, + depth=50, + freeze_at=2, + norm_type='affine_channel', + freeze_norm=True, + norm_decay=0., + variant='b', + feature_maps=[2, 3, 4, 5], + dcn_v2_stages=[], + weight_prefix_name='', + nonlocal_stages=[], + gcb_stages=[], + gcb_params=dict(), + lr_mult_list=[1., 1., 1., 1.]): + super(ResNet, self).__init__() + + if isinstance(feature_maps, Integral): + feature_maps = [feature_maps] + + assert depth in [18, 34, 50, 101, 152, 200], \ + "depth {} not in [18, 34, 50, 101, 152, 200]" + assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant" + assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4" + assert len(feature_maps) > 0, "need one or more feature maps" + assert norm_type in ['bn', 'sync_bn', 'affine_channel'] + assert not (len(nonlocal_stages)>0 and depth<50), \ + "non-local is not supported for resnet18 or resnet34" + assert len(lr_mult_list + ) == 4, "lr_mult_list length must be 4 but got {}".format( + len(lr_mult_list)) + + 
self.depth = depth + self.freeze_at = freeze_at + self.norm_type = norm_type + self.norm_decay = norm_decay + self.freeze_norm = freeze_norm + self.variant = variant + self._model_type = 'ResNet' + self.feature_maps = feature_maps + self.dcn_v2_stages = dcn_v2_stages + self.depth_cfg = { + 18: ([2, 2, 2, 2], self.basicblock), + 34: ([3, 4, 6, 3], self.basicblock), + 50: ([3, 4, 6, 3], self.bottleneck), + 101: ([3, 4, 23, 3], self.bottleneck), + 152: ([3, 8, 36, 3], self.bottleneck), + 200: ([3, 12, 48, 3], self.bottleneck), + } + self.stage_filters = [64, 128, 256, 512] + self._c1_out_chan_num = 64 + self.na = NameAdapter(self) + self.prefix_name = weight_prefix_name + + self.nonlocal_stages = nonlocal_stages + self.nonlocal_mod_cfg = { + 50: 2, + 101: 5, + 152: 8, + 200: 12, + } + + self.gcb_stages = gcb_stages + self.gcb_params = gcb_params + + self.lr_mult_list = lr_mult_list + # var denoting curr stage + self.stage_num = -1 + + def _conv_offset(self, + input, + filter_size, + stride, + padding, + act=None, + name=None): + out_channel = filter_size * filter_size * 3 + out = fluid.layers.conv2d( + input, + num_filters=out_channel, + filter_size=filter_size, + stride=stride, + padding=padding, + param_attr=ParamAttr( + initializer=Constant(0.0), name=name + ".w_0"), + bias_attr=ParamAttr( + initializer=Constant(0.0), name=name + ".b_0"), + act=act, + name=name) + return out + + def _conv_norm(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None, + dcn_v2=False): + _name = self.prefix_name + name if self.prefix_name != '' else name + + # need fine lr for distilled model, default as 1.0 + lr_mult = 1.0 + mult_idx = max(self.stage_num - 2, 0) + mult_idx = min(self.stage_num - 2, 3) + lr_mult = self.lr_mult_list[mult_idx] + + if not dcn_v2: + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + 
param_attr=ParamAttr( + name=_name + "_weights", learning_rate=lr_mult), + bias_attr=False, + name=_name + '.conv2d.output.1') + else: + # select deformable conv" + offset_mask = self._conv_offset( + input=input, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + act=None, + name=_name + "_conv_offset") + offset_channel = filter_size**2 * 2 + mask_channel = filter_size**2 + offset, mask = fluid.layers.split( + input=offset_mask, + num_or_sections=[offset_channel, mask_channel], + dim=1) + mask = fluid.layers.sigmoid(mask) + conv = fluid.layers.deformable_conv( + input=input, + offset=offset, + mask=mask, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + deformable_groups=1, + im2col_step=1, + param_attr=ParamAttr( + name=_name + "_weights", learning_rate=lr_mult), + bias_attr=False, + name=_name + ".conv2d.output.1") + + bn_name = self.na.fix_conv_norm_name(name) + bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name + + norm_lr = 0. 
if self.freeze_norm else lr_mult + norm_decay = self.norm_decay + pattr = ParamAttr( + name=bn_name + '_scale', + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay)) + battr = ParamAttr( + name=bn_name + '_offset', + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay)) + + if self.norm_type in ['bn', 'sync_bn']: + global_stats = True if self.freeze_norm else False + out = fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name + '.output.1', + param_attr=pattr, + bias_attr=battr, + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + use_global_stats=global_stats) + scale = fluid.framework._get_var(pattr.name) + bias = fluid.framework._get_var(battr.name) + elif self.norm_type == 'affine_channel': + scale = fluid.layers.create_parameter( + shape=[conv.shape[1]], + dtype=conv.dtype, + attr=pattr, + default_initializer=fluid.initializer.Constant(1.)) + bias = fluid.layers.create_parameter( + shape=[conv.shape[1]], + dtype=conv.dtype, + attr=battr, + default_initializer=fluid.initializer.Constant(0.)) + out = fluid.layers.affine_channel( + x=conv, scale=scale, bias=bias, act=act) + if self.freeze_norm: + scale.stop_gradient = True + bias.stop_gradient = True + return out + + def _shortcut(self, input, ch_out, stride, is_first, name): + max_pooling_in_short_cut = self.variant == 'd' + ch_in = input.shape[1] + # the naming rule is same as pretrained weight + name = self.na.fix_shortcut_name(name) + std_senet = getattr(self, 'std_senet', False) + if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first): + if std_senet: + if is_first: + return self._conv_norm(input, ch_out, 1, stride, name=name) + else: + return self._conv_norm(input, ch_out, 3, stride, name=name) + if max_pooling_in_short_cut and not is_first: + input = fluid.layers.pool2d( + input=input, + pool_size=2, + pool_stride=2, + pool_padding=0, + ceil_mode=True, + pool_type='avg') + return self._conv_norm(input, ch_out, 1, 1, name=name) + return 
self._conv_norm(input, ch_out, 1, stride, name=name) + else: + return input + + def bottleneck(self, + input, + num_filters, + stride, + is_first, + name, + dcn_v2=False, + gcb=False, + gcb_name=None): + if self.variant == 'a': + stride1, stride2 = stride, 1 + else: + stride1, stride2 = 1, stride + + # ResNeXt + groups = getattr(self, 'groups', 1) + group_width = getattr(self, 'group_width', -1) + if groups == 1: + expand = 4 + elif (groups * group_width) == 256: + expand = 1 + else: # FIXME hard code for now, handles 32x4d, 64x4d and 32x8d + num_filters = num_filters // 2 + expand = 2 + + conv_name1, conv_name2, conv_name3, \ + shortcut_name = self.na.fix_bottleneck_name(name) + std_senet = getattr(self, 'std_senet', False) + if std_senet: + conv_def = [ + [int(num_filters / 2), 1, stride1, 'relu', 1, conv_name1], + [num_filters, 3, stride2, 'relu', groups, conv_name2], + [num_filters * expand, 1, 1, None, 1, conv_name3] + ] + else: + conv_def = [[num_filters, 1, stride1, 'relu', 1, conv_name1], + [num_filters, 3, stride2, 'relu', groups, conv_name2], + [num_filters * expand, 1, 1, None, 1, conv_name3]] + + residual = input + for i, (c, k, s, act, g, _name) in enumerate(conv_def): + residual = self._conv_norm( + input=residual, + num_filters=c, + filter_size=k, + stride=s, + act=act, + groups=g, + name=_name, + dcn_v2=(i == 1 and dcn_v2)) + short = self._shortcut( + input, + num_filters * expand, + stride, + is_first=is_first, + name=shortcut_name) + # Squeeze-and-Excitation + if callable(getattr(self, '_squeeze_excitation', None)): + residual = self._squeeze_excitation( + input=residual, num_channels=num_filters, name='fc' + name) + if gcb: + residual = add_gc_block(residual, name=gcb_name, **self.gcb_params) + return fluid.layers.elementwise_add( + x=short, y=residual, act='relu', name=name + ".add.output.5") + + def basicblock(self, + input, + num_filters, + stride, + is_first, + name, + dcn_v2=False, + gcb=False, + gcb_name=None): + assert dcn_v2 is False, 
"Not implemented yet." + assert gcb is False, "Not implemented yet." + conv0 = self._conv_norm( + input=input, + num_filters=num_filters, + filter_size=3, + act='relu', + stride=stride, + name=name + "_branch2a") + conv1 = self._conv_norm( + input=conv0, + num_filters=num_filters, + filter_size=3, + act=None, + name=name + "_branch2b") + short = self._shortcut( + input, num_filters, stride, is_first, name=name + "_branch1") + return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') + + def layer_warp(self, input, stage_num): + """ + Args: + input (Variable): input variable. + stage_num (int): the stage number, should be 2, 3, 4, 5 + + Returns: + The last variable in endpoint-th stage. + """ + assert stage_num in [2, 3, 4, 5] + + self.stage_num = stage_num + + stages, block_func = self.depth_cfg[self.depth] + count = stages[stage_num - 2] + + ch_out = self.stage_filters[stage_num - 2] + is_first = False if stage_num != 2 else True + dcn_v2 = True if stage_num in self.dcn_v2_stages else False + + nonlocal_mod = 1000 + if stage_num in self.nonlocal_stages: + nonlocal_mod = self.nonlocal_mod_cfg[ + self.depth] if stage_num == 4 else 2 + + # Make the layer name and parameter name consistent + # with ImageNet pre-trained model + conv = input + for i in range(count): + conv_name = self.na.fix_layer_warp_name(stage_num, count, i) + if self.depth < 50: + is_first = True if i == 0 and stage_num == 2 else False + + gcb = stage_num in self.gcb_stages + gcb_name = "gcb_res{}_b{}".format(stage_num, i) + conv = block_func( + input=conv, + num_filters=ch_out, + stride=2 if i == 0 and stage_num != 2 else 1, + is_first=is_first, + name=conv_name, + dcn_v2=dcn_v2, + gcb=gcb, + gcb_name=gcb_name) + + # add non local model + dim_in = conv.shape[1] + nonlocal_name = "nonlocal_conv{}".format(stage_num) + if i % nonlocal_mod == nonlocal_mod - 1: + conv = add_space_nonlocal(conv, dim_in, dim_in, + nonlocal_name + '_{}'.format(i), + int(dim_in / 2)) + return conv + + def 
c1_stage(self, input): + out_chan = self._c1_out_chan_num + + conv1_name = self.na.fix_c1_stage_name() + + if self.variant in ['c', 'd']: + conv_def = [ + [out_chan // 2, 3, 2, "conv1_1"], + [out_chan // 2, 3, 1, "conv1_2"], + [out_chan, 3, 1, "conv1_3"], + ] + else: + conv_def = [[out_chan, 7, 2, conv1_name]] + + for (c, k, s, _name) in conv_def: + input = self._conv_norm( + input=input, + num_filters=c, + filter_size=k, + stride=s, + act='relu', + name=_name) + + output = fluid.layers.pool2d( + input=input, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + return output + + def __call__(self, input): + assert isinstance(input, Variable) + assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \ + "feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps) + + res_endpoints = [] + + res = input + feature_maps = self.feature_maps + severed_head = getattr(self, 'severed_head', False) + if not severed_head: + res = self.c1_stage(res) + feature_maps = range(2, max(self.feature_maps) + 1) + + for i in feature_maps: + res = self.layer_warp(res, i) + if i in self.feature_maps: + res_endpoints.append(res) + if self.freeze_at >= i: + res.stop_gradient = True + + return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat) + for idx, feat in enumerate(res_endpoints)]) + + +@register +@serializable +class ResNetC5(ResNet): + __doc__ = ResNet.__doc__ + + def __init__(self, + depth=50, + freeze_at=2, + norm_type='affine_channel', + freeze_norm=True, + norm_decay=0., + variant='b', + feature_maps=[5], + weight_prefix_name=''): + super(ResNetC5, self).__init__(depth, freeze_at, norm_type, freeze_norm, + norm_decay, variant, feature_maps) + self.severed_head = True diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/resnext.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/resnext.py new file mode 100755 index 000000000..545251137 --- /dev/null +++ 
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ppdet.core.workspace import register, serializable
from .resnet import ResNet

__all__ = ['ResNeXt']


@register
@serializable
class ResNeXt(ResNet):
    """
    ResNeXt, see https://arxiv.org/abs/1611.05431
    Args:
        depth (int): network depth, should be 50, 101, 152.
        groups (int): group convolution cardinality
        group_width (int): width of each group convolution
        freeze_at (int): freeze the backbone at which stage
        norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
        freeze_norm (bool): freeze normalization layers
        norm_decay (float): weight decay for normalization layer weights
        variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
        feature_maps (list): index of the stages whose feature maps are returned
        dcn_v2_stages (list): index of stages who select deformable conv v2
    """

    def __init__(self,
                 depth=50,
                 groups=64,
                 group_width=4,
                 freeze_at=2,
                 norm_type='affine_channel',
                 freeze_norm=True,
                 norm_decay=True,
                 variant='a',
                 feature_maps=[2, 3, 4, 5],
                 dcn_v2_stages=[],
                 weight_prefix_name=''):
        # FIX: the original message never interpolated `depth`.
        assert depth in [50, 101, 152], \
            "depth {} should be 50, 101 or 152".format(depth)
        # NOTE(review): norm_decay defaults to True here while ResNet uses a
        # float (0.); True fed to L2Decay behaves as decay 1.0. Kept for
        # backward compatibility — confirm whether 0. was intended.
        super(ResNeXt, self).__init__(depth, freeze_at, norm_type, freeze_norm,
                                      norm_decay, variant, feature_maps)
        self.depth_cfg = {
            50: ([3, 4, 6, 3], self.bottleneck),
            101: ([3, 4, 23, 3], self.bottleneck),
            152: ([3, 8, 36, 3], self.bottleneck)
        }
        self.stage_filters = [256, 512, 1024, 2048]
        self.groups = groups
        self.group_width = group_width
        self._model_type = 'ResNeXt'
        self.dcn_v2_stages = dcn_v2_stages


@register
@serializable
class ResNeXtC5(ResNeXt):
    __doc__ = ResNeXt.__doc__

    def __init__(self,
                 depth=50,
                 groups=64,
                 group_width=4,
                 freeze_at=2,
                 norm_type='affine_channel',
                 freeze_norm=True,
                 norm_decay=True,
                 variant='a',
                 feature_maps=[5],
                 weight_prefix_name=''):
        super(ResNeXtC5, self).__init__(depth, groups, group_width, freeze_at,
                                        norm_type, freeze_norm, norm_decay,
                                        variant, feature_maps)
        # Only stage 5 is exposed and the stem is skipped in __call__.
        self.severed_head = True
b/VisualFL/depends/PaddleDetection/ppdet/modeling/backbones/senet.py @@ -0,0 +1,124 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register, serializable +from .resnext import ResNeXt + +__all__ = ['SENet', 'SENetC5'] + + +@register +@serializable +class SENet(ResNeXt): + """ + Squeeze-and-Excitation Networks, see https://arxiv.org/abs/1709.01507 + Args: + depth (int): SENet depth, should be 50, 101, 152 + groups (int): group convolution cardinality + group_width (int): width of each group convolution + freeze_at (int): freeze the backbone at which stage + norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel' + freeze_norm (bool): freeze normalization layers + norm_decay (float): weight decay for normalization layer weights + variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently + feature_maps (list): index of the stages whose feature maps are returned + dcn_v2_stages (list): index of stages who select deformable conv v2 + """ + + def __init__(self, + depth=50, + groups=64, + group_width=4, + freeze_at=2, + norm_type='affine_channel', + freeze_norm=True, + norm_decay=0., + 
variant='d', + feature_maps=[2, 3, 4, 5], + dcn_v2_stages=[], + std_senet=False, + weight_prefix_name=''): + super(SENet, self).__init__(depth, groups, group_width, freeze_at, + norm_type, freeze_norm, norm_decay, variant, + feature_maps) + if depth < 152: + self.stage_filters = [128, 256, 512, 1024] + else: + self.stage_filters = [256, 512, 1024, 2048] + self.reduction_ratio = 16 + self.std_senet = std_senet + self._c1_out_chan_num = 128 + self._model_type = 'SEResNeXt' + self.dcn_v2_stages = dcn_v2_stages + + def _squeeze_excitation(self, input, num_channels, name=None): + mixed_precision_enabled = mixed_precision_global_state() is not None + pool = fluid.layers.pool2d( + input=input, + pool_size=0, + pool_type='avg', + global_pooling=True, + use_cudnn=mixed_precision_enabled) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + squeeze = fluid.layers.fc( + input=pool, + size=int(num_channels / self.reduction_ratio), + act='relu', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_sqz_weights'), + bias_attr=ParamAttr(name=name + '_sqz_offset')) + stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) + excitation = fluid.layers.fc( + input=squeeze, + size=num_channels, + act='sigmoid', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_exc_weights'), + bias_attr=ParamAttr(name=name + '_exc_offset')) + scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) + return scale + + +@register +@serializable +class SENetC5(SENet): + __doc__ = SENet.__doc__ + + def __init__(self, + depth=50, + groups=64, + group_width=4, + freeze_at=2, + norm_type='affine_channel', + freeze_norm=True, + norm_decay=0., + variant='d', + feature_maps=[5], + weight_prefix_name=''): + super(SENetC5, self).__init__(depth, groups, group_width, freeze_at, + norm_type, freeze_norm, norm_decay, + variant, feature_maps) + self.severed_head = True diff --git 
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr

from ppdet.core.workspace import register

__all__ = ['VGG']


@register
class VGG(object):
    """
    VGG, see https://arxiv.org/abs/1409.1556

    Args:
        depth (int): the VGG net depth (16 or 19)
        normalizations (list): params list of init scale in l2 norm, skip init
            scale if param is -1.
        with_extra_blocks (bool): whether or not extra blocks should be added
        extra_block_filters (list): in each extra block, params:
            [in_channel, out_channel, padding_size, stride_size, filter_size]
    """

    def __init__(self,
                 depth=16,
                 with_extra_blocks=False,
                 normalizations=[20., -1, -1, -1, -1, -1],
                 extra_block_filters=[[256, 512, 1, 2, 3], [128, 256, 1, 2, 3],
                                      [128, 256, 0, 1, 3],
                                      [128, 256, 0, 1, 3]]):
        # NOTE(review): the list defaults are shared across instances; they are
        # only read here, but must never be mutated in place.
        # Fix: interpolate depth into the assertion message (the "{}"
        # placeholder was never formatted in the original).
        assert depth in [16, 19], \
            "depth {} not in [16, 19]".format(depth)

        self.depth = depth
        # Number of conv layers in each of the five VGG stages.
        self.depth_cfg = {16: [2, 2, 3, 3, 3], 19: [2, 2, 4, 4, 4]}
        self.with_extra_blocks = with_extra_blocks
        self.normalizations = normalizations
        self.extra_block_filters = extra_block_filters

    def __call__(self, input):
        """Build the network.

        Returns the last feature map when extra blocks are off; otherwise the
        list of feature maps, each optionally L2-normalized per
        `self.normalizations`.
        """
        layers = []
        layers += self._vgg_block(input)

        if not self.with_extra_blocks:
            return layers[-1]

        layers += self._add_extras_block(layers[-1])
        norm_cfg = self.normalizations
        for k, v in enumerate(layers):
            # -1 means "skip normalization for this map".
            if norm_cfg[k] != -1:
                layers[k] = self._l2_norm_scale(v, init_scale=norm_cfg[k])

        return layers

    def _vgg_block(self, input):
        """VGG trunk: 5 conv stages + fc6/fc7 converted to convs (SSD style).

        Returns [conv4_x output, fc7 output].
        """
        nums = self.depth_cfg[self.depth]
        vgg_base = [64, 128, 256, 512, 512]
        conv = input
        layers = []
        for k, v in enumerate(vgg_base):
            conv = self._conv_block(
                conv, v, nums[k], name="conv{}_".format(k + 1))
            layers.append(conv)
            if k == 4:
                # Last stage: 3x3/stride-1 pool keeps spatial size (SSD trick).
                conv = self._pooling_block(conv, 3, 1, pool_padding=1)
            else:
                conv = self._pooling_block(conv, 2, 2)

        # Dilated fc6 + 1x1 fc7 replace the original fully-connected layers.
        fc6 = self._conv_layer(conv, 1024, 3, 1, 6, dilation=6, name="fc6")
        fc7 = self._conv_layer(fc6, 1024, 1, 1, 0, name="fc7")

        return [layers[3], fc7]

    def _add_extras_block(self, input):
        """Append the SSD extra feature blocks configured in
        `self.extra_block_filters`; returns the list of their outputs."""
        cfg = self.extra_block_filters
        conv = input
        layers = []
        for k, v in enumerate(cfg):
            assert len(v) == 5, "extra_block_filters size not fix"
            conv = self._extra_block(
                conv,
                v[0],
                v[1],
                v[2],
                v[3],
                v[4],
                name="conv{}_".format(6 + k))
            layers.append(conv)

        return layers

    def _conv_block(self, input, num_filter, groups, name=None):
        """Stack `groups` 3x3 relu convs with `num_filter` channels."""
        conv = input
        for i in range(groups):
            conv = self._conv_layer(
                input=conv,
                num_filters=num_filter,
                filter_size=3,
                stride=1,
                padding=1,
                act='relu',
                name=name + str(i + 1))
        return conv

    def _extra_block(self,
                     input,
                     num_filters1,
                     num_filters2,
                     padding_size,
                     stride_size,
                     filter_size,
                     name=None):
        """One SSD extra block: 1x1 bottleneck conv followed by a
        (possibly strided) larger conv."""
        # 1x1 conv
        conv_1 = self._conv_layer(
            input=input,
            num_filters=int(num_filters1),
            filter_size=1,
            stride=1,
            act='relu',
            padding=0,
            name=name + "1")

        # 3x3 conv
        conv_2 = self._conv_layer(
            input=conv_1,
            num_filters=int(num_filters2),
            filter_size=filter_size,
            stride=stride_size,
            act='relu',
            padding=padding_size,
            name=name + "2")
        return conv_2

    def _conv_layer(self,
                    input,
                    num_filters,
                    filter_size,
                    stride,
                    padding,
                    dilation=1,
                    act='relu',
                    use_cudnn=True,
                    name=None):
        """Named conv2d wrapper so pretrained weights map by `name`."""
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            act=act,
            use_cudnn=use_cudnn,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=ParamAttr(name=name + "_biases"),
            name=name + '.conv2d.output.1')
        return conv

    def _pooling_block(self,
                       conv,
                       pool_size,
                       pool_stride,
                       pool_padding=0,
                       ceil_mode=True):
        """Max-pooling wrapper; ceil_mode keeps SSD's expected map sizes."""
        pool = fluid.layers.pool2d(
            input=conv,
            pool_size=pool_size,
            pool_type='max',
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            ceil_mode=ceil_mode)
        return pool

    def _l2_norm_scale(self, input, init_scale=1.0, channel_shared=False):
        """L2-normalize `input` along channels and multiply by a learned
        per-channel (or shared) scale, initialized to `init_scale`."""
        from paddle.fluid.layer_helper import LayerHelper
        from paddle.fluid.initializer import Constant
        helper = LayerHelper("Scale")
        l2_norm = fluid.layers.l2_normalize(
            input, axis=1)  # l2 norm along channel
        shape = [1] if channel_shared else [input.shape[1]]
        scale = helper.create_parameter(
            attr=helper.param_attr,
            shape=shape,
            dtype=input.dtype,
            default_initializer=Constant(init_scale))
        out = fluid.layers.elementwise_mul(
            x=l2_norm,
            y=scale,
            axis=-1 if channel_shared else 1,
            name="conv4_3_norm_scale")
        return out
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

# Import every loss submodule so their @register-decorated classes get
# registered with the ppdet workspace as a side effect of import.
from . import yolo_loss
from . import smooth_l1_loss
from . import giou_loss
from . import diou_loss
from . import iou_loss
from . import balanced_l1_loss
from . import fcos_loss
from . import diou_loss_yolo
from . import iou_aware_loss
from . import ssd_with_lmk_loss
from . import solov2_loss

# Re-export the public names of each loss module at package level.
# NOTE(review): star-import order determines which name wins on a clash —
# keep this order unchanged.
from .iou_aware_loss import *
from .yolo_loss import *
from .smooth_l1_loss import *
from .giou_loss import *
from .diou_loss import *
from .iou_loss import *
from .balanced_l1_loss import *
from .fcos_loss import *
from .diou_loss_yolo import *
from .ssd_with_lmk_loss import *
from .solov2_loss import *
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from paddle import fluid +from ppdet.core.workspace import register, serializable + +__all__ = ['BalancedL1Loss'] + + +@register +@serializable +class BalancedL1Loss(object): + """ + Balanced L1 Loss, see https://arxiv.org/abs/1904.02701 + Args: + alpha (float): hyper parameter of BalancedL1Loss, see more details in the paper + gamma (float): hyper parameter of BalancedL1Loss, see more details in the paper + beta (float): hyper parameter of BalancedL1Loss, see more details in the paper + loss_weights (float): loss weight + """ + + def __init__(self, alpha=0.5, gamma=1.5, beta=1.0, loss_weight=1.0): + super(BalancedL1Loss, self).__init__() + self.alpha = alpha + self.gamma = gamma + self.beta = beta + self.loss_weight = loss_weight + + def __call__( + self, + x, + y, + inside_weight=None, + outside_weight=None, ): + alpha = self.alpha + gamma = self.gamma + beta = self.beta + loss_weight = self.loss_weight + diff = fluid.layers.abs(x - y) + b = np.e**(gamma / alpha) - 1 + less_beta = diff < beta + ge_beta = diff >= beta + less_beta = fluid.layers.cast(x=less_beta, dtype='float32') + ge_beta = fluid.layers.cast(x=ge_beta, dtype='float32') + less_beta.stop_gradient = True + ge_beta.stop_gradient = True + loss_1 = less_beta * ( + alpha / b * (b * diff + 1) * fluid.layers.log(b * diff / beta + 1) - + alpha * diff) + loss_2 = ge_beta * (gamma * diff + gamma / b - alpha * beta) + iou_weights = 1.0 + if inside_weight is not None and outside_weight is not None: + iou_weights = inside_weight * outside_weight + loss = (loss_1 + loss_2) * iou_weights + loss = fluid.layers.reduce_sum(loss, dim=-1) * loss_weight + return loss diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/diou_loss.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/diou_loss.py new file mode 100755 index 000000000..c3dbac433 --- /dev/null +++ 
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from paddle import fluid
from ppdet.core.workspace import register, serializable
from .giou_loss import GiouLoss

__all__ = ['DiouLoss']


@register
@serializable
class DiouLoss(GiouLoss):
    """
    Distance-IoU Loss, see https://arxiv.org/abs/1911.08287
    Args:
        loss_weight (float): diou loss weight, default as 10 in faster-rcnn
        is_cls_agnostic (bool): flag of class-agnostic
        num_classes (int): class num
        use_complete_iou_loss (bool): whether to use complete iou loss
    """

    def __init__(self,
                 loss_weight=10.,
                 is_cls_agnostic=False,
                 num_classes=81,
                 use_complete_iou_loss=True):
        super(DiouLoss, self).__init__(
            loss_weight=loss_weight,
            is_cls_agnostic=is_cls_agnostic,
            num_classes=num_classes)
        # When True, the aspect-ratio (CIoU) term is added on top of DIoU.
        self.use_complete_iou_loss = use_complete_iou_loss

    def __call__(self,
                 x,
                 y,
                 inside_weight=None,
                 outside_weight=None,
                 bbox_reg_weight=[0.1, 0.1, 0.2, 0.2]):
        """Compute DIoU (or CIoU) loss between predicted deltas `x` and
        target deltas `y`; both are decoded to corner boxes by the inherited
        `bbox_transform` using `bbox_reg_weight`."""
        eps = 1.e-10
        x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
        x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)

        # Predicted box centers and sizes.
        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        w = x2 - x1
        h = y2 - y1

        # Ground-truth box centers and sizes.
        cxg = (x1g + x2g) / 2
        cyg = (y1g + y2g) / 2
        wg = x2g - x1g
        hg = y2g - y1g

        # Clamp so x2 >= x1, y2 >= y1 (degenerate predictions).
        # NOTE(review): w/h above were computed before this clamp, so the
        # CIoU aspect-ratio term sees the unclamped sizes.
        x2 = fluid.layers.elementwise_max(x1, x2)
        y2 = fluid.layers.elementwise_max(y1, y2)

        # A and B (intersection box corners)
        xkis1 = fluid.layers.elementwise_max(x1, x1g)
        ykis1 = fluid.layers.elementwise_max(y1, y1g)
        xkis2 = fluid.layers.elementwise_min(x2, x2g)
        ykis2 = fluid.layers.elementwise_min(y2, y2g)

        # A or B (smallest enclosing box corners)
        xc1 = fluid.layers.elementwise_min(x1, x1g)
        yc1 = fluid.layers.elementwise_min(y1, y1g)
        xc2 = fluid.layers.elementwise_max(x2, x2g)
        yc2 = fluid.layers.elementwise_max(y2, y2g)

        # Zero out intersection area when the boxes do not overlap
        # (greater_than yields a 0/1 mask).
        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
        intsctk = intsctk * fluid.layers.greater_than(
            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
                                                        ) - intsctk + eps
        iouk = intsctk / unionk

        # DIOU term: squared center distance over squared enclosing diagonal.
        dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
        dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
        diou_term = (dist_intersection + eps) / (dist_union + eps)

        # CIOU term: aspect-ratio consistency penalty (paper eq. 8-9).
        ciou_term = 0
        if self.use_complete_iou_loss:
            ar_gt = wg / hg
            ar_pred = w / h
            arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
            ar_loss = 4. / np.pi / np.pi * arctan * arctan
            # alpha is a trade-off weight, treated as a constant (no grad).
            alpha = ar_loss / (1 - iouk + ar_loss + eps)
            alpha.stop_gradient = True
            ciou_term = alpha * ar_loss

        iou_weights = 1
        if inside_weight is not None and outside_weight is not None:
            inside_weight = fluid.layers.reshape(inside_weight, shape=(-1, 4))
            outside_weight = fluid.layers.reshape(outside_weight, shape=(-1, 4))

            inside_weight = fluid.layers.reduce_mean(inside_weight, dim=1)
            outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1)

            iou_weights = inside_weight * outside_weight

        # Scale by class count to match the magnitude of per-class reg losses.
        class_weight = 2 if self.is_cls_agnostic else self.num_classes
        diou = fluid.layers.reduce_mean(
            (1 - iouk + ciou_term + diou_term) * iou_weights) * class_weight

        return diou * self.loss_weight
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import NumpyArrayInitializer + +from paddle import fluid +from ppdet.core.workspace import register, serializable +from .iou_loss import IouLoss + +__all__ = ['DiouLossYolo'] + + +@register +@serializable +class DiouLossYolo(IouLoss): + """ + Distance-IoU Loss, see https://arxiv.org/abs/1911.08287 + Args: + loss_weight (float): diou loss weight, default is 5 + max_height (int): max height of input to support random shape input + max_width (int): max width of input to support random shape input + """ + + def __init__(self, loss_weight=5, max_height=608, max_width=608): + self._loss_weight = loss_weight + self._MAX_HI = max_height + self._MAX_WI = max_width + + def __call__(self, + x, + y, + w, + h, + tx, + ty, + tw, + th, + anchors, + downsample_ratio, + batch_size, + eps=1.e-10): + ''' + Args: + x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h + tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h + anchors ([float]): list of anchors for current output layer + downsample_ratio (float): the downsample ratio for current output layer + batch_size (int): training batch size + eps (float): the decimal to prevent the denominator eqaul zero + ''' + x1, y1, x2, y2 = self._bbox_transform( + x, y, w, h, anchors, downsample_ratio, batch_size, False, 1.0, eps) + x1g, y1g, x2g, y2g = self._bbox_transform(tx, ty, tw, th, anchors, + downsample_ratio, batch_size, + True, 1.0, eps) + + #central coordinates + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + w = x2 - x1 + h = y2 - y1 + + cxg = (x1g + x2g) / 2 + cyg = (y1g + y2g) / 2 + wg = x2g - x1g + hg = y2g - y1g + + x2 = fluid.layers.elementwise_max(x1, x2) + y2 = fluid.layers.elementwise_max(y1, y2) + # A and B + xkis1 = fluid.layers.elementwise_max(x1, x1g) + ykis1 = 
fluid.layers.elementwise_max(y1, y1g) + xkis2 = fluid.layers.elementwise_min(x2, x2g) + ykis2 = fluid.layers.elementwise_min(y2, y2g) + # A or B + xc1 = fluid.layers.elementwise_min(x1, x1g) + yc1 = fluid.layers.elementwise_min(y1, y1g) + xc2 = fluid.layers.elementwise_max(x2, x2g) + yc2 = fluid.layers.elementwise_max(y2, y2g) + + intsctk = (xkis2 - xkis1) * (ykis2 - ykis1) + intsctk = intsctk * fluid.layers.greater_than( + xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1) + unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g + ) - intsctk + eps + iouk = intsctk / unionk + + # diou_loss + dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg) + dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1) + diou_term = (dist_intersection + eps) / (dist_union + eps) + + loss_diou = 1. - iouk + diou_term + loss_diou = loss_diou * self._loss_weight + + return loss_diou diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/fcos_loss.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/fcos_loss.py new file mode 100755 index 000000000..7275874c6 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/fcos_loss.py @@ -0,0 +1,208 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from ppdet.core.workspace import register, serializable

INF = 1e8
__all__ = ['FCOSLoss']


@register
@serializable
class FCOSLoss(object):
    """
    FCOSLoss
    Args:
        loss_alpha (float): alpha in focal loss
        loss_gamma (float): gamma in focal loss
        iou_loss_type(str): location loss type, IoU/GIoU/LINEAR_IoU
        reg_weights(float): weight for location loss
    """

    def __init__(self,
                 loss_alpha=0.25,
                 loss_gamma=2.0,
                 iou_loss_type="IoU",
                 reg_weights=1.0):
        self.loss_alpha = loss_alpha
        self.loss_gamma = loss_gamma
        self.iou_loss_type = iou_loss_type
        self.reg_weights = reg_weights

    def __flatten_tensor(self, input, channel_first=False):
        """
        Flatten a Tensor
        Args:
            input (Variables): Input Tensor
            channel_first(bool): if true the dimension order of
                Tensor is [N, C, H, W], otherwise is [N, H, W, C]
        Return:
            input_channel_last (Variables): The flattened Tensor in channel_last style
        """
        if channel_first:
            # NCHW -> NHWC so flatten keeps the channel as the last axis.
            input_channel_last = fluid.layers.transpose(
                input, perm=[0, 2, 3, 1])
        else:
            input_channel_last = input
        # Collapse N*H*W into one leading dimension.
        input_channel_last = fluid.layers.flatten(input_channel_last, axis=3)
        return input_channel_last

    def __iou_loss(self, pred, targets, positive_mask, weights=None):
        """
        Calculate the loss for location prediction
        Args:
            pred (Variables): bounding boxes prediction
            targets (Variables): targets for positive samples
            positive_mask (Variables): mask of positive samples
            weights (Variables): weights for each positive samples
        Return:
            loss (Varialbes): location loss
        """
        # Columns 0..3 are left/top/right/bottom distances to the box edges;
        # mask out negatives so they contribute nothing.
        plw = fluid.layers.elementwise_mul(pred[:, 0], positive_mask, axis=0)
        pth = fluid.layers.elementwise_mul(pred[:, 1], positive_mask, axis=0)
        prw = fluid.layers.elementwise_mul(pred[:, 2], positive_mask, axis=0)
        pbh = fluid.layers.elementwise_mul(pred[:, 3], positive_mask, axis=0)
        tlw = fluid.layers.elementwise_mul(targets[:, 0], positive_mask, axis=0)
        tth = fluid.layers.elementwise_mul(targets[:, 1], positive_mask, axis=0)
        trw = fluid.layers.elementwise_mul(targets[:, 2], positive_mask, axis=0)
        tbh = fluid.layers.elementwise_mul(targets[:, 3], positive_mask, axis=0)
        # Targets are constants for the regression loss.
        tlw.stop_gradient = True
        trw.stop_gradient = True
        tth.stop_gradient = True
        tbh.stop_gradient = True
        area_target = (tlw + trw) * (tth + tbh)
        area_predict = (plw + prw) * (pth + pbh)
        # Intersection and enclosing extents along each direction.
        ilw = fluid.layers.elementwise_min(plw, tlw)
        irw = fluid.layers.elementwise_min(prw, trw)
        ith = fluid.layers.elementwise_min(pth, tth)
        ibh = fluid.layers.elementwise_min(pbh, tbh)
        clw = fluid.layers.elementwise_max(plw, tlw)
        crw = fluid.layers.elementwise_max(prw, trw)
        cth = fluid.layers.elementwise_max(pth, tth)
        cbh = fluid.layers.elementwise_max(pbh, tbh)
        area_inter = (ilw + irw) * (ith + ibh)
        # +1.0 smoothing keeps the ratio finite for zero-area boxes.
        ious = (area_inter + 1.0) / (
            area_predict + area_target - area_inter + 1.0)
        ious = fluid.layers.elementwise_mul(ious, positive_mask, axis=0)
        if self.iou_loss_type.lower() == "linear_iou":
            loss = 1.0 - ious
        elif self.iou_loss_type.lower() == "giou":
            area_uniou = area_predict + area_target - area_inter
            area_circum = (clw + crw) * (cth + cbh) + 1e-7
            giou = ious - (area_circum - area_uniou) / area_circum
            loss = 1.0 - giou
        elif self.iou_loss_type.lower() == "iou":
            loss = 0.0 - fluid.layers.log(ious)
        else:
            # Unknown iou_loss_type configuration.
            raise KeyError
        if weights is not None:
            loss = loss * weights
        return loss

    def __call__(self, cls_logits, bboxes_reg, centerness, tag_labels,
                 tag_bboxes, tag_center):
        """
        Calculate the loss for classification, location and centerness
        Args:
            cls_logits (list): list of Variables, which is predicted
                score for all anchor points with shape [N, M, C]
            bboxes_reg (list): list of Variables, which is predicted
                offsets for all anchor points with shape [N, M, 4]
            centerness (list): list of Variables, which is predicted
                centerness for all anchor points with shape [N, M, 1]
            tag_labels (list): list of Variables, which is category
                targets for each anchor point
            tag_bboxes (list): list of Variables, which is bounding
                boxes targets for positive samples
            tag_center (list): list of Variables, which is centerness
                targets for positive samples
        Return:
            loss (dict): loss composed by classification loss, bounding box
        """
        cls_logits_flatten_list = []
        bboxes_reg_flatten_list = []
        centerness_flatten_list = []
        tag_labels_flatten_list = []
        tag_bboxes_flatten_list = []
        tag_center_flatten_list = []
        num_lvl = len(cls_logits)
        for lvl in range(num_lvl):
            # Predictions are indexed in reverse so their level order matches
            # the targets' level order after flattening.
            cls_logits_flatten_list.append(
                self.__flatten_tensor(cls_logits[num_lvl - 1 - lvl], True))
            bboxes_reg_flatten_list.append(
                self.__flatten_tensor(bboxes_reg[num_lvl - 1 - lvl], True))
            centerness_flatten_list.append(
                self.__flatten_tensor(centerness[num_lvl - 1 - lvl], True))
            tag_labels_flatten_list.append(
                self.__flatten_tensor(tag_labels[lvl], False))
            tag_bboxes_flatten_list.append(
                self.__flatten_tensor(tag_bboxes[lvl], False))
            tag_center_flatten_list.append(
                self.__flatten_tensor(tag_center[lvl], False))

        # Concatenate all FPN levels into single flat tensors.
        cls_logits_flatten = fluid.layers.concat(
            cls_logits_flatten_list, axis=0)
        bboxes_reg_flatten = fluid.layers.concat(
            bboxes_reg_flatten_list, axis=0)
        centerness_flatten = fluid.layers.concat(
            centerness_flatten_list, axis=0)
        tag_labels_flatten = fluid.layers.concat(
            tag_labels_flatten_list, axis=0)
        tag_bboxes_flatten = fluid.layers.concat(
            tag_bboxes_flatten_list, axis=0)
        tag_center_flatten = fluid.layers.concat(
            tag_center_flatten_list, axis=0)
        tag_labels_flatten.stop_gradient = True
        tag_bboxes_flatten.stop_gradient = True
        tag_center_flatten.stop_gradient = True

        # label > 0 marks a positive (foreground) anchor point.
        mask_positive = tag_labels_flatten > 0
        mask_positive.stop_gradient = True
        mask_positive_float = fluid.layers.cast(mask_positive, dtype="float32")
        mask_positive_float.stop_gradient = True
        num_positive_fp32 = fluid.layers.reduce_sum(mask_positive_float)
        num_positive_int32 = fluid.layers.cast(num_positive_fp32, dtype="int32")
        # Force the int normalizer passed to sigmoid_focal_loss to 1; the real
        # normalization is the explicit division by num_positive_fp32 below.
        num_positive_int32 = num_positive_int32 * 0 + 1
        num_positive_fp32.stop_gradient = True
        num_positive_int32.stop_gradient = True
        # NOTE(review): `sum` then `reduce_sum(mask * sum)` effectively scales
        # the total centerness by the positive count — confirm this matches
        # the intended sum-of-centerness-over-positives normalizer.
        normalize_sum = fluid.layers.sum(tag_center_flatten)
        normalize_sum.stop_gradient = True
        normalize_sum = fluid.layers.reduce_sum(mask_positive_float *
                                                normalize_sum)

        normalize_sum.stop_gradient = True
        # Focal loss over all points, normalized by positive count.
        cls_loss = fluid.layers.sigmoid_focal_loss(
            cls_logits_flatten, tag_labels_flatten,
            num_positive_int32) / num_positive_fp32
        # IoU regression loss, weighted by centerness targets.
        reg_loss = self.__iou_loss(bboxes_reg_flatten, tag_bboxes_flatten,
                                   mask_positive_float, tag_center_flatten)
        reg_loss = fluid.layers.elementwise_mul(
            reg_loss, mask_positive_float, axis=0) / normalize_sum
        # BCE on centerness, restricted to positives.
        ctn_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
            x=centerness_flatten, label=tag_center_flatten)
        ctn_loss = fluid.layers.elementwise_mul(
            ctn_loss, mask_positive_float, axis=0) / num_positive_fp32
        loss_all = {
            "loss_centerness": fluid.layers.reduce_sum(ctn_loss),
            "loss_cls": fluid.layers.reduce_sum(cls_loss),
            "loss_box": fluid.layers.reduce_sum(reg_loss)
        }
        return loss_all
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from paddle import fluid
from ppdet.core.workspace import register, serializable

__all__ = ['GiouLoss']


@register
@serializable
class GiouLoss(object):
    '''
    Generalized Intersection over Union, see https://arxiv.org/abs/1902.09630
    Args:
        loss_weight (float): giou loss weight, default as 10 in faster-rcnn
        is_cls_agnostic (bool): flag of class-agnostic
        num_classes (int): class num
        do_average (bool): whether to average the loss
        use_class_weight (bool): whether to use class weight
    '''
    __shared__ = ['num_classes']

    def __init__(self,
                 loss_weight=10.,
                 is_cls_agnostic=False,
                 num_classes=81,
                 do_average=True,
                 use_class_weight=True):
        super(GiouLoss, self).__init__()
        self.loss_weight = loss_weight
        self.is_cls_agnostic = is_cls_agnostic
        self.num_classes = num_classes
        self.do_average = do_average
        # Class-agnostic regression predicts one box shared by all foreground
        # classes, hence only 2 effective "classes" (fg/bg).
        self.class_weight = 2 if is_cls_agnostic else num_classes
        self.use_class_weight = use_class_weight

    # deltas: NxMx4
    def bbox_transform(self, deltas, weights):
        """Decode (dx, dy, dw, dh) regression deltas into flattened corner
        coordinates (x1, y1, x2, y2).

        Args:
            deltas: encoded box deltas, reshaped here to (N, -1, 4).
            weights: per-coordinate scaling factors (wx, wy, ww, wh) that
                undo the encoding-time normalization.

        Returns:
            Tuple of four 1-D Variables: x1, y1, x2, y2.
        """
        wx, wy, ww, wh = weights

        # shape=(0, -1, 4): dim 0 is copied from the input, the rest is
        # inferred so each row holds one 4-element box delta.
        deltas = fluid.layers.reshape(deltas, shape=(0, -1, 4))

        dx = fluid.layers.slice(deltas, axes=[2], starts=[0], ends=[1]) * wx
        dy = fluid.layers.slice(deltas, axes=[2], starts=[1], ends=[2]) * wy
        dw = fluid.layers.slice(deltas, axes=[2], starts=[2], ends=[3]) * ww
        dh = fluid.layers.slice(deltas, axes=[2], starts=[3], ends=[4]) * wh

        # Clamp dw/dh so the exp() below cannot overflow; log(1000/16) is the
        # conventional Detectron bound on predicted size scaling.
        dw = fluid.layers.clip(dw, -1.e10, np.log(1000. / 16))
        dh = fluid.layers.clip(dh, -1.e10, np.log(1000. / 16))

        # NOTE(review): dx/dy are used directly as centers here (no anchor
        # center added), i.e. deltas are presumed already center-relative.
        pred_ctr_x = dx
        pred_ctr_y = dy
        pred_w = fluid.layers.exp(dw)
        pred_h = fluid.layers.exp(dh)

        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h

        x1 = fluid.layers.reshape(x1, shape=(-1, ))
        y1 = fluid.layers.reshape(y1, shape=(-1, ))
        x2 = fluid.layers.reshape(x2, shape=(-1, ))
        y2 = fluid.layers.reshape(y2, shape=(-1, ))

        return x1, y1, x2, y2

    def __call__(self,
                 x,
                 y,
                 inside_weight=None,
                 outside_weight=None,
                 bbox_reg_weight=[0.1, 0.1, 0.2, 0.2],
                 use_transform=True):
        """Compute the (weighted) GIoU loss between predictions `x` and
        targets `y`.

        Args:
            x: predicted boxes — encoded deltas if `use_transform` else
                corner coordinates split along dim 1.
            y: target boxes in the same representation as `x`.
            inside_weight / outside_weight: optional per-box weights, shape
                reshaped to (-1, 4) and averaged over the coordinate dim.
            bbox_reg_weight: decoding weights passed to `bbox_transform`.
            use_transform (bool): whether inputs need delta decoding.

        Returns:
            Scalar loss Variable, scaled by `class_weight` (optional) and
            `loss_weight`.
        """
        eps = 1.e-10
        if use_transform:
            x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
            x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)
        else:
            x1, y1, x2, y2 = fluid.layers.split(x, num_or_sections=4, dim=1)
            x1g, y1g, x2g, y2g = fluid.layers.split(y, num_or_sections=4, dim=1)

        # Guard against degenerate predictions where x2 < x1 (or y2 < y1).
        x2 = fluid.layers.elementwise_max(x1, x2)
        y2 = fluid.layers.elementwise_max(y1, y2)

        # Intersection rectangle of prediction and ground truth.
        xkis1 = fluid.layers.elementwise_max(x1, x1g)
        ykis1 = fluid.layers.elementwise_max(y1, y1g)
        xkis2 = fluid.layers.elementwise_min(x2, x2g)
        ykis2 = fluid.layers.elementwise_min(y2, y2g)

        # Smallest enclosing box of both rectangles (the "C" box of GIoU).
        xc1 = fluid.layers.elementwise_min(x1, x1g)
        yc1 = fluid.layers.elementwise_min(y1, y1g)
        xc2 = fluid.layers.elementwise_max(x2, x2g)
        yc2 = fluid.layers.elementwise_max(y2, y2g)

        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
        # Zero out the intersection where the boxes do not actually overlap
        # (greater_than yields a 0/1 mask).
        intsctk = intsctk * fluid.layers.greater_than(
            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
                                                        ) - intsctk + eps

        iouk = intsctk / unionk

        # GIoU = IoU - (|C| - |A ∪ B|) / |C|
        area_c = (xc2 - xc1) * (yc2 - yc1) + eps
        miouk = iouk - ((area_c - unionk) / area_c)

        iou_weights = 1
        if inside_weight is not None and outside_weight is not None:
            inside_weight = fluid.layers.reshape(inside_weight, shape=(-1, 4))
            outside_weight = fluid.layers.reshape(outside_weight, shape=(-1,
                                                                         4))

            inside_weight = fluid.layers.reduce_mean(inside_weight, dim=1)
            outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1)

            iou_weights = inside_weight * outside_weight
        elif outside_weight is not None:
            iou_weights = outside_weight

        if self.do_average:
            miouk = fluid.layers.reduce_mean((1 - miouk) * iou_weights)
        else:
            iou_distance = fluid.layers.elementwise_mul(
                1 - miouk, iou_weights, axis=0)
            miouk = fluid.layers.reduce_sum(iou_distance)

        if self.use_class_weight:
            miouk = miouk * self.class_weight

        return miouk * self.loss_weight
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import NumpyArrayInitializer + +from paddle import fluid +from ppdet.core.workspace import register, serializable +from .iou_loss import IouLoss + +__all__ = ['IouAwareLoss'] + + +@register +@serializable +class IouAwareLoss(IouLoss): + """ + iou aware loss, see https://arxiv.org/abs/1912.05992 + Args: + loss_weight (float): iou aware loss weight, default is 1.0 + max_height (int): max height of input to support random shape input + max_width (int): max width of input to support random shape input + """ + + def __init__(self, loss_weight=1.0, max_height=608, max_width=608): + super(IouAwareLoss, self).__init__( + loss_weight=loss_weight, max_height=max_height, max_width=max_width) + + def __call__(self, + ioup, + x, + y, + w, + h, + tx, + ty, + tw, + th, + anchors, + downsample_ratio, + batch_size, + scale_x_y, + eps=1.e-10): + ''' + Args: + ioup ([Variables]): the predicted iou + x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h + tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h + anchors ([float]): list of anchors for current output layer + downsample_ratio (float): the downsample ratio for current output layer + batch_size (int): training batch size + eps (float): the decimal to prevent the denominator eqaul zero + ''' + + pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio, + batch_size, False, scale_x_y, eps) + gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio, + batch_size, True, scale_x_y, eps) + iouk = self._iou(pred, gt, ioup, eps) + iouk.stop_gradient = True + + loss_iou_aware = fluid.layers.cross_entropy(ioup, iouk, soft_label=True) + loss_iou_aware = loss_iou_aware * self._loss_weight + return loss_iou_aware diff --git 
a/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/iou_loss.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/iou_loss.py new file mode 100755 index 000000000..590217000 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/iou_loss.py @@ -0,0 +1,235 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import NumpyArrayInitializer + +from paddle import fluid +from ppdet.core.workspace import register, serializable + +__all__ = ['IouLoss'] + + +@register +@serializable +class IouLoss(object): + """ + iou loss, see https://arxiv.org/abs/1908.03851 + loss = 1.0 - iou * iou + Args: + loss_weight (float): iou loss weight, default is 2.5 + max_height (int): max height of input to support random shape input + max_width (int): max width of input to support random shape input + ciou_term (bool): whether to add ciou_term + loss_square (bool): whether to square the iou term + """ + + def __init__(self, + loss_weight=2.5, + max_height=608, + max_width=608, + ciou_term=False, + loss_square=True): + self._loss_weight = loss_weight + self._MAX_HI = max_height + self._MAX_WI = max_width + self.ciou_term = ciou_term + self.loss_square = loss_square + + def __call__(self, + x, + 
y, + w, + h, + tx, + ty, + tw, + th, + anchors, + downsample_ratio, + batch_size, + scale_x_y=1., + ioup=None, + eps=1.e-10): + ''' + Args: + x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h + tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h + anchors ([float]): list of anchors for current output layer + downsample_ratio (float): the downsample ratio for current output layer + batch_size (int): training batch size + eps (float): the decimal to prevent the denominator eqaul zero + ''' + pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio, + batch_size, False, scale_x_y, eps) + gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio, + batch_size, True, scale_x_y, eps) + iouk = self._iou(pred, gt, ioup, eps) + if self.loss_square: + loss_iou = 1. - iouk * iouk + else: + loss_iou = 1. - iouk + loss_iou = loss_iou * self._loss_weight + + return loss_iou + + def _iou(self, pred, gt, ioup=None, eps=1.e-10): + x1, y1, x2, y2 = pred + x1g, y1g, x2g, y2g = gt + x2 = fluid.layers.elementwise_max(x1, x2) + y2 = fluid.layers.elementwise_max(y1, y2) + + xkis1 = fluid.layers.elementwise_max(x1, x1g) + ykis1 = fluid.layers.elementwise_max(y1, y1g) + xkis2 = fluid.layers.elementwise_min(x2, x2g) + ykis2 = fluid.layers.elementwise_min(y2, y2g) + + intsctk = (xkis2 - xkis1) * (ykis2 - ykis1) + intsctk = intsctk * fluid.layers.greater_than( + xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1) + unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g + ) - intsctk + eps + iouk = intsctk / unionk + if self.ciou_term: + ciou = self.get_ciou_term(pred, gt, iouk, eps) + iouk = iouk - ciou + return iouk + + def get_ciou_term(self, pred, gt, iouk, eps): + x1, y1, x2, y2 = pred + x1g, y1g, x2g, y2g = gt + + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + w = (x2 - x1) + fluid.layers.cast((x2 - x1) == 0, 'float32') + h = (y2 - y1) + fluid.layers.cast((y2 - y1) == 0, 'float32') + + cxg = (x1g + x2g) / 2 + cyg = (y1g + y2g) / 
2 + wg = x2g - x1g + hg = y2g - y1g + + # A or B + xc1 = fluid.layers.elementwise_min(x1, x1g) + yc1 = fluid.layers.elementwise_min(y1, y1g) + xc2 = fluid.layers.elementwise_max(x2, x2g) + yc2 = fluid.layers.elementwise_max(y2, y2g) + + # DIOU term + dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg) + dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1) + diou_term = (dist_intersection + eps) / (dist_union + eps) + # CIOU term + ciou_term = 0 + ar_gt = wg / hg + ar_pred = w / h + arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred) + ar_loss = 4. / np.pi / np.pi * arctan * arctan + alpha = ar_loss / (1 - iouk + ar_loss + eps) + alpha.stop_gradient = True + ciou_term = alpha * ar_loss + return diou_term + ciou_term + + def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio, + batch_size, is_gt, scale_x_y, eps): + grid_x = int(self._MAX_WI / downsample_ratio) + grid_y = int(self._MAX_HI / downsample_ratio) + an_num = len(anchors) // 2 + + shape_fmp = fluid.layers.shape(dcx) + shape_fmp.stop_gradient = True + # generate the grid_w x grid_h center of feature map + idx_i = np.array([[i for i in range(grid_x)]]) + idx_j = np.array([[j for j in range(grid_y)]]).transpose() + gi_np = np.repeat(idx_i, grid_y, axis=0) + gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x]) + gi_np = np.tile(gi_np, reps=[batch_size, an_num, 1, 1]) + gj_np = np.repeat(idx_j, grid_x, axis=1) + gj_np = np.reshape(gj_np, newshape=[1, 1, grid_y, grid_x]) + gj_np = np.tile(gj_np, reps=[batch_size, an_num, 1, 1]) + gi_max = self._create_tensor_from_numpy(gi_np.astype(np.float32)) + gi = fluid.layers.crop(x=gi_max, shape=dcx) + gi.stop_gradient = True + gj_max = self._create_tensor_from_numpy(gj_np.astype(np.float32)) + gj = fluid.layers.crop(x=gj_max, shape=dcx) + gj.stop_gradient = True + + grid_x_act = fluid.layers.cast(shape_fmp[3], dtype="float32") + grid_x_act.stop_gradient = True + grid_y_act = fluid.layers.cast(shape_fmp[2], 
dtype="float32") + grid_y_act.stop_gradient = True + if is_gt: + cx = fluid.layers.elementwise_add(dcx, gi) / grid_x_act + cx.gradient = True + cy = fluid.layers.elementwise_add(dcy, gj) / grid_y_act + cy.gradient = True + else: + dcx_sig = fluid.layers.sigmoid(dcx) + dcy_sig = fluid.layers.sigmoid(dcy) + if (abs(scale_x_y - 1.0) > eps): + dcx_sig = scale_x_y * dcx_sig - 0.5 * (scale_x_y - 1) + dcy_sig = scale_x_y * dcy_sig - 0.5 * (scale_x_y - 1) + cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act + cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act + + anchor_w_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 0] + anchor_w_np = np.array(anchor_w_) + anchor_w_np = np.reshape(anchor_w_np, newshape=[1, an_num, 1, 1]) + anchor_w_np = np.tile(anchor_w_np, reps=[batch_size, 1, grid_y, grid_x]) + anchor_w_max = self._create_tensor_from_numpy( + anchor_w_np.astype(np.float32)) + anchor_w = fluid.layers.crop(x=anchor_w_max, shape=dcx) + anchor_w.stop_gradient = True + anchor_h_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 1] + anchor_h_np = np.array(anchor_h_) + anchor_h_np = np.reshape(anchor_h_np, newshape=[1, an_num, 1, 1]) + anchor_h_np = np.tile(anchor_h_np, reps=[batch_size, 1, grid_y, grid_x]) + anchor_h_max = self._create_tensor_from_numpy( + anchor_h_np.astype(np.float32)) + anchor_h = fluid.layers.crop(x=anchor_h_max, shape=dcx) + anchor_h.stop_gradient = True + # e^tw e^th + exp_dw = fluid.layers.exp(dw) + exp_dh = fluid.layers.exp(dh) + pw = fluid.layers.elementwise_mul(exp_dw, anchor_w) / \ + (grid_x_act * downsample_ratio) + ph = fluid.layers.elementwise_mul(exp_dh, anchor_h) / \ + (grid_y_act * downsample_ratio) + if is_gt: + exp_dw.stop_gradient = True + exp_dh.stop_gradient = True + pw.stop_gradient = True + ph.stop_gradient = True + + x1 = cx - 0.5 * pw + y1 = cy - 0.5 * ph + x2 = cx + 0.5 * pw + y2 = cy + 0.5 * ph + if is_gt: + x1.stop_gradient = True + y1.stop_gradient = True + x2.stop_gradient = True + 
y2.stop_gradient = True + + return x1, y1, x2, y2 + + def _create_tensor_from_numpy(self, numpy_array): + paddle_array = fluid.layers.create_global_var( + shape=numpy_array.shape, value=0., dtype=numpy_array.dtype) + fluid.layers.assign(numpy_array, paddle_array) + return paddle_array diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/smooth_l1_loss.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/smooth_l1_loss.py new file mode 100755 index 000000000..22c87b624 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/losses/smooth_l1_loss.py @@ -0,0 +1,44 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from ppdet.core.workspace import register, serializable

__all__ = ['SmoothL1Loss']


@register
@serializable
class SmoothL1Loss(object):
    '''
    Smooth L1 loss
    Args:
        sigma (float): hyper param in smooth l1 loss
    '''

    def __init__(self, sigma=1.0):
        super(SmoothL1Loss, self).__init__()
        self.sigma = sigma

    def __call__(self, x, y, inside_weight=None, outside_weight=None):
        # Thin wrapper over the fluid op; `sigma` controls the switch point
        # between the quadratic (L2) and linear (L1) regions of the loss.
        return fluid.layers.smooth_l1(
            x,
            y,
            inside_weight=inside_weight,
            outside_weight=outside_weight,
            sigma=self.sigma)


import paddle
from paddle import fluid
from ppdet.core.workspace import register, serializable

__all__ = ['SOLOv2Loss']


@register
@serializable
class SOLOv2Loss(object):
    """
    SOLOv2Loss
    Args:
        ins_loss_weight (float): Weight of instance loss.
        focal_loss_gamma (float): Gamma parameter for focal loss.
        focal_loss_alpha (float): Alpha parameter for focal loss.
    """

    def __init__(self,
                 ins_loss_weight=3.0,
                 focal_loss_gamma=2.0,
                 focal_loss_alpha=0.25):
        self.ins_loss_weight = ins_loss_weight
        self.focal_loss_gamma = focal_loss_gamma
        self.focal_loss_alpha = focal_loss_alpha

    def _dice_loss(self, input, target):
        """Dice loss between a predicted mask and its target mask; each
        instance mask is flattened to a (N, H*W) vector first."""
        input = fluid.layers.reshape(
            input, shape=(fluid.layers.shape(input)[0], -1))
        target = fluid.layers.reshape(
            target, shape=(fluid.layers.shape(target)[0], -1))
        target = fluid.layers.cast(target, 'float32')
        a = fluid.layers.reduce_sum(input * target, dim=1)
        # The 0.001 terms keep the denominator away from zero for empty masks.
        b = fluid.layers.reduce_sum(input * input, dim=1) + 0.001
        c = fluid.layers.reduce_sum(target * target, dim=1) + 0.001
        d = (2 * a) / (b + c)
        return 1 - d

    def __call__(self, ins_pred_list, ins_label_list, cate_preds, cate_labels,
                 num_ins):
        """
        Get loss of network of SOLOv2.
        Args:
            ins_pred_list (list): Variable list of instance branch output.
            ins_label_list (list): List of instance labels pre batch.
            cate_preds (list): Concat Variable list of categroy branch output.
            cate_labels (list): Concat list of categroy labels pre batch.
            num_ins (int): Number of positive samples in a mini-batch.
        Returns:
            loss_ins (Variable): The instance loss Variable of SOLOv2 network.
            loss_cate (Variable): The category loss Variable of SOLOv2 network.
        """

        # Use dice_loss to calculate instance loss
        loss_ins = []
        total_weights = fluid.layers.zeros(shape=[1], dtype='float32')
        for input, target in zip(ins_pred_list, ins_label_list):
            # Weight = 1 for grid cells that contain a positive mask,
            # 0 otherwise, so empty targets do not contribute.
            weights = fluid.layers.cast(
                fluid.layers.reduce_sum(
                    target, dim=[1, 2]) > 0, 'float32')
            input = fluid.layers.sigmoid(input)
            dice_out = fluid.layers.elementwise_mul(
                self._dice_loss(input, target), weights)
            total_weights += fluid.layers.reduce_sum(weights)
            loss_ins.append(dice_out)
        # Normalize by the number of positive masks across all levels.
        loss_ins = fluid.layers.reduce_sum(fluid.layers.concat(
            loss_ins)) / total_weights
        loss_ins = loss_ins * self.ins_loss_weight

        # Use sigmoid_focal_loss to calculate category loss
        loss_cate = fluid.layers.sigmoid_focal_loss(
            x=cate_preds,
            label=cate_labels,
            # +1 presumably guards against fg_num == 0 — TODO confirm against
            # fluid.layers.sigmoid_focal_loss normalization semantics.
            fg_num=num_ins + 1,
            gamma=self.focal_loss_gamma,
            alpha=self.focal_loss_alpha)
        loss_cate = fluid.layers.reduce_sum(loss_cate)

        return loss_ins, loss_cate
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Variable
import paddle.fluid.layers as layers
from paddle.fluid.layers import (tensor, iou_similarity, bipartite_match,
                                 target_assign, box_coder)
from ppdet.core.workspace import register, serializable

__all__ = ['SSDWithLmkLoss']


@register
@serializable
class SSDWithLmkLoss(object):
    """
    ssd_with_lmk_loss function: SSD multibox loss extended with a facial
    landmark regression term (weighted 0.4 in the final sum).
    Args:
        background_label (int): The index of background label, 0 by default.
        overlap_threshold (float): If match_type is `per_prediction`,
            use `overlap_threshold` to determine the extra matching bboxes
            when finding matched boxes. 0.5 by default.
        neg_pos_ratio (float): The ratio of the negative boxes to the positive
            boxes, used only when mining_type is `max_negative`, 3.0 by default.
        neg_overlap (float): The negative overlap upper bound for the unmatched
            predictions. Use only when mining_type is `max_negative`, 0.5 by default.
        loc_loss_weight (float): Weight for localization loss, 1.0 by default.
        conf_loss_weight (float): Weight for confidence loss, 1.0 by default.
        match_type (str): The type of matching method during training, should be
            `bipartite` or `per_prediction`, `per_prediction` by default.
        normalize (bool): Whether to normalize the loss by the total number of
            output locations, True by default.
    """

    def __init__(self,
                 background_label=0,
                 overlap_threshold=0.5,
                 neg_pos_ratio=3.0,
                 neg_overlap=0.5,
                 loc_loss_weight=1.0,
                 conf_loss_weight=1.0,
                 match_type='per_prediction',
                 normalize=True):
        super(SSDWithLmkLoss, self).__init__()
        self.background_label = background_label
        self.overlap_threshold = overlap_threshold
        self.neg_pos_ratio = neg_pos_ratio
        self.neg_overlap = neg_overlap
        self.loc_loss_weight = loc_loss_weight
        self.conf_loss_weight = conf_loss_weight
        self.match_type = match_type
        self.normalize = normalize

    def __call__(self,
                 location,
                 confidence,
                 gt_box,
                 gt_label,
                 landmark_predict,
                 lmk_label,
                 lmk_ignore_flag,
                 prior_box,
                 prior_box_var=None):
        """Build the combined confidence + localization + landmark loss.

        Returns:
            Variable of shape (N, 1): per-image summed loss, optionally
            normalized by the number of matched (positive) locations.
        """
        def _reshape_to_2d(var):
            # Collapse all dims from axis 2 on: (N, P, C) -> (N*P, C).
            return layers.flatten(x=var, axis=2)

        helper = LayerHelper('ssd_loss')  #, **locals())
        # Only support mining_type == 'max_negative' now.
        mining_type = 'max_negative'
        # The max `sample_size` of negative box, used only
        # when mining_type is `hard_example`.
        sample_size = None
        num, num_prior, num_class = confidence.shape
        conf_shape = layers.shape(confidence)

        # 1. Find matched bounding box by prior box.
        # 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
        iou = iou_similarity(x=gt_box, y=prior_box)
        # 1.2 Compute matched bounding box by bipartite matching algorithm.
        matched_indices, matched_dist = bipartite_match(iou, self.match_type,
                                                        self.overlap_threshold)

        # 2. Compute confidence for mining hard examples
        # 2.1. Get the target label based on matched indices
        gt_label = layers.reshape(
            x=gt_label, shape=(len(gt_label.shape) - 1) * (0, ) + (-1, 1))
        gt_label.stop_gradient = True
        target_label, _ = target_assign(
            gt_label, matched_indices, mismatch_value=self.background_label)
        # 2.2. Compute confidence loss.
        # Reshape confidence to 2D tensor.
        confidence = _reshape_to_2d(confidence)
        target_label = tensor.cast(x=target_label, dtype='int64')
        target_label = _reshape_to_2d(target_label)
        target_label.stop_gradient = True
        # This first conf_loss is only used to rank negatives for mining
        # (note stop_gradient below); the trained loss is recomputed in step 5.
        conf_loss = layers.softmax_with_cross_entropy(confidence, target_label)
        # 3. Mining hard examples
        actual_shape = layers.slice(conf_shape, axes=[0], starts=[0], ends=[2])
        actual_shape.stop_gradient = True
        conf_loss = layers.reshape(
            x=conf_loss, shape=(-1, 0), actual_shape=actual_shape)
        conf_loss.stop_gradient = True
        neg_indices = helper.create_variable_for_type_inference(dtype='int32')
        updated_matched_indices = helper.create_variable_for_type_inference(
            dtype=matched_indices.dtype)
        # mine_hard_examples has no layers.* wrapper, so append the raw op.
        helper.append_op(
            type='mine_hard_examples',
            inputs={
                'ClsLoss': conf_loss,
                'LocLoss': None,
                'MatchIndices': matched_indices,
                'MatchDist': matched_dist,
            },
            outputs={
                'NegIndices': neg_indices,
                'UpdatedMatchIndices': updated_matched_indices
            },
            attrs={
                'neg_pos_ratio': self.neg_pos_ratio,
                'neg_dist_threshold': self.neg_overlap,
                'mining_type': mining_type,
                'sample_size': sample_size,
            })

        # 4. Assign classification and regression targets
        # 4.1. Encoded bbox according to the prior boxes.
        encoded_bbox = box_coder(
            prior_box=prior_box,
            prior_box_var=prior_box_var,
            target_box=gt_box,
            code_type='encode_center_size')
        # 4.2. Assign regression targets
        target_bbox, target_loc_weight = target_assign(
            encoded_bbox,
            updated_matched_indices,
            mismatch_value=self.background_label)
        # 4.3. Assign classification targets
        target_label, target_conf_weight = target_assign(
            gt_label,
            updated_matched_indices,
            negative_indices=neg_indices,
            mismatch_value=self.background_label)

        # NOTE(review): multiplying by target_label presumably masks out
        # background priors — this assumes binary labels (face=1, bg=0);
        # with multi-class labels it would also scale the weight. Confirm.
        target_loc_weight = target_loc_weight * target_label
        encoded_lmk_label = self.decode_lmk(lmk_label, prior_box, prior_box_var)

        target_lmk, target_lmk_weight = target_assign(
            encoded_lmk_label,
            updated_matched_indices,
            mismatch_value=self.background_label)
        lmk_ignore_flag = layers.reshape(
            x=lmk_ignore_flag,
            shape=(len(lmk_ignore_flag.shape) - 1) * (0, ) + (-1, 1))
        target_ignore, nouse = target_assign(
            lmk_ignore_flag,
            updated_matched_indices,
            mismatch_value=self.background_label)

        # Zero landmark loss for priors whose gt landmarks are flagged ignore.
        target_lmk_weight = target_lmk_weight * target_ignore
        landmark_predict = _reshape_to_2d(landmark_predict)
        target_lmk = _reshape_to_2d(target_lmk)
        target_lmk_weight = _reshape_to_2d(target_lmk_weight)
        lmk_loss = layers.smooth_l1(landmark_predict, target_lmk)
        lmk_loss = lmk_loss * target_lmk_weight
        target_lmk.stop_gradient = True
        target_lmk_weight.stop_gradient = True
        target_ignore.stop_gradient = True
        nouse.stop_gradient = True

        # 5. Compute loss.
        # 5.1 Compute confidence loss.
        target_label = _reshape_to_2d(target_label)
        target_label = tensor.cast(x=target_label, dtype='int64')

        conf_loss = layers.softmax_with_cross_entropy(confidence, target_label)
        target_conf_weight = _reshape_to_2d(target_conf_weight)
        conf_loss = conf_loss * target_conf_weight

        # the target_label and target_conf_weight do not have gradient.
        target_label.stop_gradient = True
        target_conf_weight.stop_gradient = True

        # 5.2 Compute regression loss.
        location = _reshape_to_2d(location)
        target_bbox = _reshape_to_2d(target_bbox)

        loc_loss = layers.smooth_l1(location, target_bbox)
        target_loc_weight = _reshape_to_2d(target_loc_weight)
        loc_loss = loc_loss * target_loc_weight

        # the target_bbox and target_loc_weight do not have gradient.
        target_bbox.stop_gradient = True
        target_loc_weight.stop_gradient = True

        # 5.3 Compute overall weighted loss.
        loss = self.conf_loss_weight * conf_loss + self.loc_loss_weight * loc_loss + 0.4 * lmk_loss
        # reshape to [N, Np], N is the batch size and Np is the prior box number.
        loss = layers.reshape(x=loss, shape=(-1, 0), actual_shape=actual_shape)
        loss = layers.reduce_sum(loss, dim=1, keep_dim=True)
        if self.normalize:
            # +1 keeps the normalizer nonzero when there are no positives.
            normalizer = layers.reduce_sum(target_loc_weight) + 1
            loss = loss / normalizer

        return loss

    def decode_lmk(self, lmk_label, prior_box, prior_box_var):
        """Encode the 5 landmark points against the prior boxes.

        Each landmark point (x, y) is duplicated into a fake 4-coordinate
        "box" so box_coder's center-size encoding can be reused; the
        duplicate half of the encoded result is then discarded via split.
        """
        label0, label1, label2, label3, label4 = fluid.layers.split(
            lmk_label, num_or_sections=5, dim=1)
        lmk_labels_list = [label0, label1, label2, label3, label4]
        encoded_lmk_list = []
        for label in lmk_labels_list:
            concat_label = fluid.layers.concat([label, label], axis=1)
            encoded_label = box_coder(
                prior_box=prior_box,
                prior_box_var=prior_box_var,
                target_box=concat_label,
                code_type='encode_center_size')
            encoded_lmk_label, _ = fluid.layers.split(
                encoded_label, num_or_sections=2, dim=2)
            encoded_lmk_list.append(encoded_lmk_label)

        encoded_lmk_concat = fluid.layers.concat(
            [
                encoded_lmk_list[0], encoded_lmk_list[1], encoded_lmk_list[2],
                encoded_lmk_list[3], encoded_lmk_list[4]
            ],
            axis=2)
        return encoded_lmk_concat
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid +from ppdet.core.workspace import register +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence + +import logging +logger = logging.getLogger(__name__) + +__all__ = ['YOLOv3Loss'] + + +@register +class YOLOv3Loss(object): + """ + Combined loss for YOLOv3 network + + Args: + train_batch_size (int): training batch size + ignore_thresh (float): threshold to ignore confidence loss + label_smooth (bool): whether to use label smoothing + use_fine_grained_loss (bool): whether use fine grained YOLOv3 loss + instead of fluid.layers.yolov3_loss + """ + __inject__ = ['iou_loss', 'iou_aware_loss'] + __shared__ = ['use_fine_grained_loss', 'train_batch_size'] + + def __init__( + self, + train_batch_size=8, + batch_size=-1, # stub for backward compatable + ignore_thresh=0.7, + label_smooth=True, + use_fine_grained_loss=False, + iou_loss=None, + iou_aware_loss=None, + downsample=[32, 16, 8], + scale_x_y=1., + match_score=False): + self._train_batch_size = train_batch_size + self._ignore_thresh = ignore_thresh + self._label_smooth = label_smooth + self._use_fine_grained_loss = use_fine_grained_loss + self._iou_loss = iou_loss + self._iou_aware_loss = iou_aware_loss + self.downsample = downsample + self.scale_x_y = scale_x_y + self.match_score = match_score + + if batch_size != -1: + logger.warn( + "config YOLOv3Loss.batch_size is deprecated, " + "training batch size should be set by 
TrainReader.batch_size") + + def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors, + anchor_masks, mask_anchors, num_classes, prefix_name): + if self._use_fine_grained_loss: + return self._get_fine_grained_loss( + outputs, targets, gt_box, self._train_batch_size, num_classes, + mask_anchors, self._ignore_thresh) + else: + losses = [] + for i, output in enumerate(outputs): + scale_x_y = self.scale_x_y if not isinstance( + self.scale_x_y, Sequence) else self.scale_x_y[i] + anchor_mask = anchor_masks[i] + loss = fluid.layers.yolov3_loss( + x=output, + gt_box=gt_box, + gt_label=gt_label, + gt_score=gt_score, + anchors=anchors, + anchor_mask=anchor_mask, + class_num=num_classes, + ignore_thresh=self._ignore_thresh, + downsample_ratio=self.downsample[i], + use_label_smooth=self._label_smooth, + scale_x_y=scale_x_y, + name=prefix_name + "yolo_loss" + str(i)) + + losses.append(fluid.layers.reduce_mean(loss)) + + return {'loss': sum(losses)} + + def _get_fine_grained_loss(self, + outputs, + targets, + gt_box, + train_batch_size, + num_classes, + mask_anchors, + ignore_thresh, + eps=1.e-10): + """ + Calculate fine grained YOLOv3 loss + + Args: + outputs ([Variables]): List of Variables, output of backbone stages + targets ([Variables]): List of Variables, The targets for yolo + loss calculatation. + gt_box (Variable): The ground-truth boudding boxes. + train_batch_size (int): The training batch size + num_classes (int): class num of dataset + mask_anchors ([[float]]): list of anchors in each output layer + ignore_thresh (float): prediction bbox overlap any gt_box greater + than ignore_thresh, objectness loss will + be ignored. 
+ + Returns: + Type: dict + xy_loss (Variable): YOLOv3 (x, y) coordinates loss + wh_loss (Variable): YOLOv3 (w, h) coordinates loss + obj_loss (Variable): YOLOv3 objectness score loss + cls_loss (Variable): YOLOv3 classification loss + + """ + + assert len(outputs) == len(targets), \ + "YOLOv3 output layer number not equal target number" + + loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], [] + if self._iou_loss is not None: + loss_ious = [] + if self._iou_aware_loss is not None: + loss_iou_awares = [] + for i, (output, target, + anchors) in enumerate(zip(outputs, targets, mask_anchors)): + downsample = self.downsample[i] + an_num = len(anchors) // 2 + if self._iou_aware_loss is not None: + ioup, output = self._split_ioup(output, an_num, num_classes) + x, y, w, h, obj, cls = self._split_output(output, an_num, + num_classes) + tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target) + + tscale_tobj = tscale * tobj + + scale_x_y = self.scale_x_y if not isinstance( + self.scale_x_y, Sequence) else self.scale_x_y[i] + + if (abs(scale_x_y - 1.0) < eps): + loss_x = fluid.layers.sigmoid_cross_entropy_with_logits( + x, tx) * tscale_tobj + loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3]) + loss_y = fluid.layers.sigmoid_cross_entropy_with_logits( + y, ty) * tscale_tobj + loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3]) + else: + dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y - + 1.0) + dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y - + 1.0) + loss_x = fluid.layers.abs(dx - tx) * tscale_tobj + loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3]) + loss_y = fluid.layers.abs(dy - ty) * tscale_tobj + loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3]) + + # NOTE: we refined loss function of (w, h) as L1Loss + loss_w = fluid.layers.abs(w - tw) * tscale_tobj + loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3]) + loss_h = fluid.layers.abs(h - th) * tscale_tobj + loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3]) 
+ if self._iou_loss is not None: + loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors, + downsample, self._train_batch_size, + scale_x_y) + loss_iou = loss_iou * tscale_tobj + loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3]) + loss_ious.append(fluid.layers.reduce_mean(loss_iou)) + + if self._iou_aware_loss is not None: + loss_iou_aware = self._iou_aware_loss( + ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample, + self._train_batch_size, scale_x_y) + loss_iou_aware = loss_iou_aware * tobj + loss_iou_aware = fluid.layers.reduce_sum( + loss_iou_aware, dim=[1, 2, 3]) + loss_iou_awares.append(fluid.layers.reduce_mean(loss_iou_aware)) + + loss_obj_pos, loss_obj_neg = self._calc_obj_loss( + output, obj, tobj, gt_box, self._train_batch_size, anchors, + num_classes, downsample, self._ignore_thresh, scale_x_y) + + loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls) + loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0) + loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4]) + + loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y)) + loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h)) + loss_objs.append( + fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg)) + loss_clss.append(fluid.layers.reduce_mean(loss_cls)) + + losses_all = { + "loss_xy": fluid.layers.sum(loss_xys), + "loss_wh": fluid.layers.sum(loss_whs), + "loss_obj": fluid.layers.sum(loss_objs), + "loss_cls": fluid.layers.sum(loss_clss), + } + if self._iou_loss is not None: + losses_all["loss_iou"] = fluid.layers.sum(loss_ious) + if self._iou_aware_loss is not None: + losses_all["loss_iou_aware"] = fluid.layers.sum(loss_iou_awares) + return losses_all + + def _split_ioup(self, output, an_num, num_classes): + """ + Split output feature map to output, predicted iou + along channel dimension + """ + ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num]) + ioup = fluid.layers.sigmoid(ioup) + oriout = fluid.layers.slice( + output, 
+ axes=[1], + starts=[an_num], + ends=[an_num * (num_classes + 6)]) + return (ioup, oriout) + + def _split_output(self, output, an_num, num_classes): + """ + Split output feature map to x, y, w, h, objectness, classification + along channel dimension + """ + x = fluid.layers.strided_slice( + output, + axes=[1], + starts=[0], + ends=[output.shape[1]], + strides=[5 + num_classes]) + y = fluid.layers.strided_slice( + output, + axes=[1], + starts=[1], + ends=[output.shape[1]], + strides=[5 + num_classes]) + w = fluid.layers.strided_slice( + output, + axes=[1], + starts=[2], + ends=[output.shape[1]], + strides=[5 + num_classes]) + h = fluid.layers.strided_slice( + output, + axes=[1], + starts=[3], + ends=[output.shape[1]], + strides=[5 + num_classes]) + obj = fluid.layers.strided_slice( + output, + axes=[1], + starts=[4], + ends=[output.shape[1]], + strides=[5 + num_classes]) + clss = [] + stride = output.shape[1] // an_num + for m in range(an_num): + clss.append( + fluid.layers.slice( + output, + axes=[1], + starts=[stride * m + 5], + ends=[stride * m + 5 + num_classes])) + cls = fluid.layers.transpose( + fluid.layers.stack( + clss, axis=1), perm=[0, 1, 3, 4, 2]) + + return (x, y, w, h, obj, cls) + + def _split_target(self, target): + """ + split target to x, y, w, h, objectness, classification + along dimension 2 + + target is in shape [N, an_num, 6 + class_num, H, W] + """ + tx = target[:, :, 0, :, :] + ty = target[:, :, 1, :, :] + tw = target[:, :, 2, :, :] + th = target[:, :, 3, :, :] + + tscale = target[:, :, 4, :, :] + tobj = target[:, :, 5, :, :] + + tcls = fluid.layers.transpose( + target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2]) + tcls.stop_gradient = True + + return (tx, ty, tw, th, tscale, tobj, tcls) + + def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors, + num_classes, downsample, ignore_thresh, scale_x_y): + # A prediction bbox overlap any gt_bbox over ignore_thresh, + # objectness loss will be ignored, process as follows: + + # 1. 
get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here + # NOTE: img_size is set as 1.0 to get noramlized pred bbox + bbox, prob = fluid.layers.yolo_box( + x=output, + img_size=fluid.layers.ones( + shape=[batch_size, 2], dtype="int32"), + anchors=anchors, + class_num=num_classes, + conf_thresh=0., + downsample_ratio=downsample, + clip_bbox=False, + scale_x_y=scale_x_y) + + # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox + # and gt bbox in each sample + if batch_size > 1: + preds = fluid.layers.split(bbox, batch_size, dim=0) + gts = fluid.layers.split(gt_box, batch_size, dim=0) + else: + preds = [bbox] + gts = [gt_box] + probs = [prob] + ious = [] + for pred, gt in zip(preds, gts): + + def box_xywh2xyxy(box): + x = box[:, 0] + y = box[:, 1] + w = box[:, 2] + h = box[:, 3] + return fluid.layers.stack( + [ + x - w / 2., + y - h / 2., + x + w / 2., + y + h / 2., + ], axis=1) + + pred = fluid.layers.squeeze(pred, axes=[0]) + gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0])) + ious.append(fluid.layers.iou_similarity(pred, gt)) + + iou = fluid.layers.stack(ious, axis=0) + # 3. 
Get iou_mask by IoU between gt bbox and prediction bbox, + # Get obj_mask by tobj(holds gt_score), calculate objectness loss + + max_iou = fluid.layers.reduce_max(iou, dim=-1) + iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32") + if self.match_score: + max_prob = fluid.layers.reduce_max(prob, dim=-1) + iou_mask = iou_mask * fluid.layers.cast( + max_prob <= 0.25, dtype="float32") + output_shape = fluid.layers.shape(output) + an_num = len(anchors) // 2 + iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2], + output_shape[3])) + iou_mask.stop_gradient = True + + # NOTE: tobj holds gt_score, obj_mask holds object existence mask + obj_mask = fluid.layers.cast(tobj > 0., dtype="float32") + obj_mask.stop_gradient = True + + # For positive objectness grids, objectness loss should be calculated + # For negative objectness grids, objectness loss is calculated only iou_mask == 1.0 + loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj, obj_mask) + loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3]) + loss_obj_neg = fluid.layers.reduce_sum( + loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3]) + + return loss_obj_pos, loss_obj_neg diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/mask_head/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/mask_head/__init__.py new file mode 100755 index 000000000..ba2b3b30b --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/mask_head/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import fluid

from ppdet.core.workspace import register
from ppdet.modeling.ops import ConvNorm, DeformConvNorm

__all__ = ['SOLOv2MaskHead']


@register
class SOLOv2MaskHead(object):
    """
    MaskHead of SOLOv2

    Builds one conv tower per FPN level, upsamples them to a common
    resolution, sums them, and applies a final 1x1 prediction conv.

    Args:
        in_channels (int): The channel number of input variable.
        out_channels (int): The channel number of output variable.
        start_level (int): The position where the input starts.
        end_level (int): The position where the input ends.
        use_dcn_in_tower: Whether to use dcn in tower or not.
    """

    def __init__(self,
                 in_channels=128,
                 out_channels=128,
                 start_level=0,
                 end_level=3,
                 use_dcn_in_tower=False):
        super(SOLOv2MaskHead, self).__init__()
        assert start_level >= 0 and end_level >= start_level
        self.out_channels = out_channels
        self.start_level = start_level
        self.end_level = end_level
        self.in_channels = in_channels
        self.use_dcn_in_tower = use_dcn_in_tower
        # index 0 = plain conv+norm, index 1 = deformable conv+norm
        self.conv_type = [ConvNorm, DeformConvNorm]

    def _convs_levels(self, conv_feat, level, name=None):
        """Apply `level` 3x3 conv+GN blocks, upsampling 2x after each one.

        Level 0 is a single conv with no upsampling. The `name` prefix
        determines parameter names, so it must stay stable across calls.
        """
        conv_func = self.conv_type[0]
        if self.use_dcn_in_tower:
            conv_func = self.conv_type[1]

        if level == 0:
            return conv_func(
                input=conv_feat,
                num_filters=self.in_channels,
                filter_size=3,
                stride=1,
                norm_type='gn',
                norm_groups=32,
                freeze_norm=False,
                act='relu',
                initializer=fluid.initializer.NormalInitializer(scale=0.01),
                norm_name=name + '.conv' + str(level) + '.gn',
                name=name + '.conv' + str(level))

        for j in range(level):
            conv_feat = conv_func(
                input=conv_feat,
                num_filters=self.in_channels,
                filter_size=3,
                stride=1,
                norm_type='gn',
                norm_groups=32,
                freeze_norm=False,
                act='relu',
                initializer=fluid.initializer.NormalInitializer(scale=0.01),
                norm_name=name + '.conv' + str(j) + '.gn',
                name=name + '.conv' + str(j))
            # 2x bilinear upsample after every conv so level i ends up at
            # the level-0 spatial resolution
            conv_feat = fluid.layers.resize_bilinear(
                conv_feat,
                scale=2,
                name='upsample' + str(level) + str(j),
                align_corners=False,
                align_mode=0)
        return conv_feat

    def _conv_pred(self, conv_feat):
        """Final 1x1 conv+GN producing the mask feature output."""
        conv_func = self.conv_type[0]
        if self.use_dcn_in_tower:
            conv_func = self.conv_type[1]
        conv_feat = conv_func(
            input=conv_feat,
            num_filters=self.out_channels,
            filter_size=1,
            stride=1,
            norm_type='gn',
            norm_groups=32,
            freeze_norm=False,
            act='relu',
            initializer=fluid.initializer.NormalInitializer(scale=0.01),
            norm_name='mask_feat_head.conv_pred.0.gn',
            name='mask_feat_head.conv_pred.0')

        return conv_feat

    def get_output(self, inputs):
        """
        Get SOLOv2MaskHead output.

        Args:
            inputs(list[Variable]): feature map from each necks with shape
                of [N, C, H, W]
        Returns:
            ins_pred(Variable): Output of SOLOv2MaskHead head
        """
        range_level = self.end_level - self.start_level + 1
        feature_add_all_level = self._convs_levels(
            inputs[0], 0, name='mask_feat_head.convs_all_levels.0')
        for i in range(1, range_level):
            input_p = inputs[i]
            if i == (range_level - 1):
                # coordinate-conv trick: append normalized x/y coordinate
                # channels to the deepest level before its tower
                input_feat = input_p
                x_range = paddle.linspace(
                    -1, 1, fluid.layers.shape(input_feat)[-1], dtype='float32')
                y_range = paddle.linspace(
                    -1, 1, fluid.layers.shape(input_feat)[-2], dtype='float32')
                y, x = paddle.tensor.meshgrid([y_range, x_range])
                x = fluid.layers.unsqueeze(x, [0, 1])
                y = fluid.layers.unsqueeze(y, [0, 1])
                # broadcast the coordinate maps across the batch dimension
                y = fluid.layers.expand(
                    y,
                    expand_times=[fluid.layers.shape(input_feat)[0], 1, 1, 1])
                x = fluid.layers.expand(
                    x,
                    expand_times=[fluid.layers.shape(input_feat)[0], 1, 1, 1])
                coord_feat = fluid.layers.concat([x, y], axis=1)
                input_p = fluid.layers.concat([input_p, coord_feat], axis=1)
            feature_add_all_level = fluid.layers.elementwise_add(
                feature_add_all_level,
                self._convs_levels(
                    input_p,
                    i,
                    name='mask_feat_head.convs_all_levels.{}'.format(i)))
        ins_pred = self._conv_pred(feature_add_all_level)

        return ins_pred
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from numbers import Integral
import math
import six

import paddle
from paddle import fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from ppdet.utils.bbox_utils import bbox_overlaps, box_to_delta

__all__ = [
    'AnchorGenerator', 'AnchorGrid', 'DropBlock', 'RPNTargetAssign',
    'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner',
    'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead',
    'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm',
    'DeformConvNorm', 'MultiClassSoftNMS', 'MatrixNMS', 'LibraBBoxAssigner',
    'DeformConv'
]


def _conv_offset(input, filter_size, stride, padding, act=None, name=None):
    """Zero-initialized conv producing the offset/mask channels for
    deformable convolution (2 offsets + 1 mask per kernel position)."""
    out_channel = filter_size * filter_size * 3
    out = fluid.layers.conv2d(
        input,
        num_filters=out_channel,
        filter_size=filter_size,
        stride=stride,
        padding=padding,
        param_attr=ParamAttr(
            initializer=fluid.initializer.Constant(0), name=name + ".w_0"),
        bias_attr=ParamAttr(
            initializer=fluid.initializer.Constant(0),
            learning_rate=2.,
            regularizer=L2Decay(0.),
            name=name + ".b_0"),
        act=act,
        name=name)
    return out


def DeformConv(input,
               num_filters,
               filter_size,
               stride=1,
               groups=1,
               dilation=1,
               lr_scale=1,
               initializer=None,
               bias_attr=False,
               name=None):
    """Modulated deformable convolution (offsets and mask predicted by a
    companion conv). `name` determines all created parameter names."""
    if bias_attr:
        bias_para = ParamAttr(
            name=name + "_bias",
            initializer=fluid.initializer.Constant(0),
            regularizer=L2Decay(0.),
            learning_rate=lr_scale * 2)
    else:
        bias_para = False
    offset_mask = _conv_offset(
        input=input,
        filter_size=filter_size,
        stride=stride,
        padding=(filter_size - 1) // 2,
        act=None,
        name=name + "_conv_offset")
    # first 2*k*k channels are (x, y) offsets, last k*k are the modulation mask
    offset_channel = filter_size**2 * 2
    mask_channel = filter_size**2
    offset, mask = fluid.layers.split(
        input=offset_mask,
        num_or_sections=[offset_channel, mask_channel],
        dim=1)
    mask = fluid.layers.sigmoid(mask)
    conv = fluid.layers.deformable_conv(
        input=input,
        offset=offset,
        mask=mask,
        num_filters=num_filters,
        filter_size=filter_size,
        stride=stride,
        padding=(filter_size - 1) // 2 * dilation,
        dilation=dilation,
        groups=groups,
        deformable_groups=1,
        im2col_step=1,
        param_attr=ParamAttr(
            name=name + "_weights",
            initializer=initializer,
            learning_rate=lr_scale),
        bias_attr=bias_para,
        name=name + ".conv2d.output.1")

    return conv


def DeformConvNorm(input,
                   num_filters,
                   filter_size,
                   stride=1,
                   groups=1,
                   norm_decay=0.,
                   norm_type='affine_channel',
                   norm_groups=32,
                   dilation=1,
                   lr_scale=1,
                   freeze_norm=False,
                   act=None,
                   norm_name=None,
                   initializer=None,
                   bias_attr=False,
                   name=None):
    """DeformConv followed by a normalization layer (bn / sync_bn / gn /
    affine_channel). When freeze_norm is True the norm parameters are
    created with zero learning rate and stop_gradient set."""
    assert norm_type in ['bn', 'sync_bn', 'affine_channel', 'gn']
    conv = DeformConv(input, num_filters, filter_size, stride, groups, dilation,
                      lr_scale, initializer, bias_attr, name)

    norm_lr = 0. if freeze_norm else 1.
    pattr = ParamAttr(
        name=norm_name + '_scale',
        learning_rate=norm_lr * lr_scale,
        regularizer=L2Decay(norm_decay))
    battr = ParamAttr(
        name=norm_name + '_offset',
        learning_rate=norm_lr * lr_scale,
        regularizer=L2Decay(norm_decay))

    if norm_type in ['bn', 'sync_bn']:
        # frozen norm uses the stored global statistics instead of batch stats
        global_stats = True if freeze_norm else False
        out = fluid.layers.batch_norm(
            input=conv,
            act=act,
            name=norm_name + '.output.1',
            param_attr=pattr,
            bias_attr=battr,
            moving_mean_name=norm_name + '_mean',
            moving_variance_name=norm_name + '_variance',
            use_global_stats=global_stats)
        scale = fluid.framework._get_var(pattr.name)
        bias = fluid.framework._get_var(battr.name)
    elif norm_type == 'gn':
        out = fluid.layers.group_norm(
            input=conv,
            act=act,
            name=norm_name + '.output.1',
            groups=norm_groups,
            param_attr=pattr,
            bias_attr=battr)
        scale = fluid.framework._get_var(pattr.name)
        bias = fluid.framework._get_var(battr.name)
    elif norm_type == 'affine_channel':
        scale = fluid.layers.create_parameter(
            shape=[conv.shape[1]],
            dtype=conv.dtype,
            attr=pattr,
            default_initializer=fluid.initializer.Constant(1.))
        bias = fluid.layers.create_parameter(
            shape=[conv.shape[1]],
            dtype=conv.dtype,
            attr=battr,
            default_initializer=fluid.initializer.Constant(0.))
        out = fluid.layers.affine_channel(
            x=conv, scale=scale, bias=bias, act=act)

    if freeze_norm:
        scale.stop_gradient = True
        bias.stop_gradient = True
    return out


def ConvNorm(input,
             num_filters,
             filter_size,
             stride=1,
             groups=1,
             norm_decay=0.,
             norm_type='affine_channel',
             norm_groups=32,
             dilation=1,
             lr_scale=1,
             freeze_norm=False,
             act=None,
             norm_name=None,
             initializer=None,
             bias_attr=False,
             name=None):
    """Plain conv2d followed by a normalization layer; mirrors
    DeformConvNorm but with a regular convolution."""
    fan = num_filters
    if bias_attr:
        bias_para = ParamAttr(
            name=name + "_bias",
            initializer=fluid.initializer.Constant(value=0),
            learning_rate=lr_scale * 2)
    else:
        bias_para = False
    conv = fluid.layers.conv2d(
        input=input,
        num_filters=num_filters,
        filter_size=filter_size,
        stride=stride,
        padding=((filter_size - 1) // 2) * dilation,
        dilation=dilation,
        groups=groups,
        act=None,
        param_attr=ParamAttr(
            name=name + "_weights",
            initializer=initializer,
            learning_rate=lr_scale),
        bias_attr=bias_para,
        name=name + '.conv2d.output.1')

    norm_lr = 0. if freeze_norm else 1.
    pattr = ParamAttr(
        name=norm_name + '_scale',
        learning_rate=norm_lr * lr_scale,
        regularizer=L2Decay(norm_decay))
    battr = ParamAttr(
        name=norm_name + '_offset',
        learning_rate=norm_lr * lr_scale,
        regularizer=L2Decay(norm_decay))

    if norm_type in ['bn', 'sync_bn']:
        global_stats = True if freeze_norm else False
        out = fluid.layers.batch_norm(
            input=conv,
            act=act,
            name=norm_name + '.output.1',
            param_attr=pattr,
            bias_attr=battr,
            moving_mean_name=norm_name + '_mean',
            moving_variance_name=norm_name + '_variance',
            use_global_stats=global_stats)
        scale = fluid.framework._get_var(pattr.name)
        bias = fluid.framework._get_var(battr.name)
    elif norm_type == 'gn':
        out = fluid.layers.group_norm(
            input=conv,
            act=act,
            name=norm_name + '.output.1',
            groups=norm_groups,
            param_attr=pattr,
            bias_attr=battr)
        scale = fluid.framework._get_var(pattr.name)
        bias = fluid.framework._get_var(battr.name)
    elif norm_type == 'affine_channel':
        scale = fluid.layers.create_parameter(
            shape=[conv.shape[1]],
            dtype=conv.dtype,
            attr=pattr,
            default_initializer=fluid.initializer.Constant(1.))
        bias = fluid.layers.create_parameter(
            shape=[conv.shape[1]],
            dtype=conv.dtype,
            attr=battr,
            default_initializer=fluid.initializer.Constant(0.))
        out = fluid.layers.affine_channel(
            x=conv, scale=scale, bias=bias, act=act)
    if freeze_norm:
        scale.stop_gradient = True
        bias.stop_gradient = True
    return out


def DropBlock(input, block_size, keep_prob, is_test):
    """DropBlock regularization: drop contiguous block_size x block_size
    regions with probability derived from keep_prob; identity at test time.
    Output is rescaled so the expected activation sum is preserved."""
    if is_test:
        return input

    def CalculateGamma(input, block_size, keep_prob):
        # gamma = drop rate per seed point such that, after max-pooling the
        # seeds into blocks, roughly (1 - keep_prob) of the area is dropped
        input_shape = fluid.layers.shape(input)
        feat_shape_tmp = fluid.layers.slice(input_shape, [0], [3], [4])
        feat_shape_tmp = fluid.layers.cast(feat_shape_tmp, dtype="float32")
        feat_shape_t = fluid.layers.reshape(feat_shape_tmp, [1, 1, 1, 1])
        feat_area = fluid.layers.pow(feat_shape_t, factor=2)

        block_shape_t = fluid.layers.fill_constant(
            shape=[1, 1, 1, 1], value=block_size, dtype='float32')
        block_area = fluid.layers.pow(block_shape_t, factor=2)

        useful_shape_t = feat_shape_t - block_shape_t + 1
        useful_area = fluid.layers.pow(useful_shape_t, factor=2)

        upper_t = feat_area * (1 - keep_prob)
        bottom_t = block_area * useful_area
        output = upper_t / bottom_t
        return output

    gamma = CalculateGamma(input, block_size=block_size, keep_prob=keep_prob)
    input_shape = fluid.layers.shape(input)
    p = fluid.layers.expand_as(gamma, input)

    input_shape_tmp = fluid.layers.cast(input_shape, dtype="int64")
    random_matrix = fluid.layers.uniform_random(
        input_shape_tmp, dtype='float32', min=0.0, max=1.0)
    one_zero_m = fluid.layers.less_than(random_matrix, p)
    one_zero_m.stop_gradient = True
    one_zero_m = fluid.layers.cast(one_zero_m, dtype="float32")

    # max-pool expands each dropped seed into a block_size x block_size block
    mask_flag = fluid.layers.pool2d(
        one_zero_m,
        pool_size=block_size,
        pool_type='max',
        pool_stride=1,
        pool_padding=block_size // 2)
    mask = 1.0 - mask_flag

    # rescale by (total elements / kept elements) to keep activations unbiased
    elem_numel = fluid.layers.reduce_prod(input_shape)
    elem_numel_m = fluid.layers.cast(elem_numel, dtype="float32")
    elem_numel_m.stop_gradient = True

    elem_sum = fluid.layers.reduce_sum(mask)
    elem_sum_m = fluid.layers.cast(elem_sum, dtype="float32")
    elem_sum_m.stop_gradient = True

    output = input * mask * elem_numel_m / elem_sum_m
    return output


@register
@serializable
class AnchorGenerator(object):
    # Config wrapper: fields are forwarded to fluid.layers.anchor_generator
    __op__ = fluid.layers.anchor_generator
    __append_doc__ = True

    def __init__(self,
                 stride=[16.0, 16.0],
                 anchor_sizes=[32, 64, 128, 256, 512],
                 aspect_ratios=[0.5, 1., 2.],
                 variance=[1., 1., 1., 1.]):
        super(AnchorGenerator, self).__init__()
        self.anchor_sizes = anchor_sizes
        self.aspect_ratios = aspect_ratios
        self.variance = variance
        self.stride = stride


@register
@serializable
class AnchorGrid(object):
    """Generate anchor grid

    Args:
        image_size (int or list): input image size, may be a single integer or
            list of [h, w]. Default: 512
        min_level (int): min level of the feature pyramid. Default: 3
        max_level (int): max level of the feature pyramid. Default: 7
        anchor_base_scale: base anchor scale. Default: 4
        num_scales: number of anchor scales. Default: 3
        aspect_ratios: aspect ratios. default: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
    """

    def __init__(self,
                 image_size=512,
                 min_level=3,
                 max_level=7,
                 anchor_base_scale=4,
                 num_scales=3,
                 aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]):
        super(AnchorGrid, self).__init__()
        if isinstance(image_size, Integral):
            self.image_size = [image_size, image_size]
        else:
            self.image_size = image_size
        for dim in self.image_size:
            assert dim % 2 ** max_level == 0, \
                "image size should be multiple of the max level stride"
        self.min_level = min_level
        self.max_level = max_level
        self.anchor_base_scale = anchor_base_scale
        self.num_scales = num_scales
        self.aspect_ratios = aspect_ratios

    @property
    def base_cell(self):
        # lazily computed and cached base anchor set
        if not hasattr(self, '_base_cell'):
            self._base_cell = self.make_cell()
        return self._base_cell

    def make_cell(self):
        """Build the (num_scales * num_ratios, 4) base anchors centered at 0."""
        scales = [2**(i / self.num_scales) for i in range(self.num_scales)]
        scales = np.array(scales)
        ratios = np.array(self.aspect_ratios)
        ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1)
        hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1)
        anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))
        return anchors

    def make_grid(self, stride):
        """Tile the base cell over the image at the given stride."""
        cell = self.base_cell * stride * self.anchor_base_scale
        x_steps = np.arange(stride // 2, self.image_size[1], stride)
        y_steps = np.arange(stride // 2, self.image_size[0], stride)
        offset_x, offset_y = np.meshgrid(x_steps, y_steps)
        offset_x = offset_x.flatten()
        offset_y = offset_y.flatten()
        offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1)
        offsets = offsets[:, np.newaxis, :]
        return (cell + offsets).reshape(-1, 4)

    def generate(self):
        """Anchor grids for every pyramid level as numpy arrays."""
        return [
            self.make_grid(2**l)
            for l in range(self.min_level, self.max_level + 1)
        ]

    def __call__(self):
        # materialize the grids once as persistable, non-trainable parameters
        if not hasattr(self, '_anchor_vars'):
            anchor_vars = []
            helper = LayerHelper('anchor_grid')
            for idx, l in enumerate(range(self.min_level, self.max_level + 1)):
                stride = 2**l
                anchors = self.make_grid(stride)
                var = helper.create_parameter(
                    attr=ParamAttr(name='anchors_{}'.format(idx)),
                    shape=anchors.shape,
                    dtype='float32',
                    stop_gradient=True,
                    default_initializer=NumpyArrayInitializer(anchors))
                anchor_vars.append(var)
                var.persistable = True
            self._anchor_vars = anchor_vars

        return self._anchor_vars


@register
@serializable
class RPNTargetAssign(object):
    # Config wrapper: fields are forwarded to fluid.layers.rpn_target_assign
    __op__ = fluid.layers.rpn_target_assign
    __append_doc__ = True

    def __init__(self,
                 rpn_batch_size_per_im=256,
                 rpn_straddle_thresh=0.,
                 rpn_fg_fraction=0.5,
                 rpn_positive_overlap=0.7,
                 rpn_negative_overlap=0.3,
                 use_random=True):
        super(RPNTargetAssign, self).__init__()
        self.rpn_batch_size_per_im = rpn_batch_size_per_im
        self.rpn_straddle_thresh = rpn_straddle_thresh
        self.rpn_fg_fraction = rpn_fg_fraction
        self.rpn_positive_overlap = rpn_positive_overlap
        self.rpn_negative_overlap = rpn_negative_overlap
        self.use_random = use_random


@register
@serializable
class GenerateProposals(object):
    # Config wrapper: fields are forwarded to fluid.layers.generate_proposals
    __op__ = fluid.layers.generate_proposals
    __append_doc__ = True

    def __init__(self,
                 pre_nms_top_n=6000,
                 post_nms_top_n=1000,
                 nms_thresh=.5,
                 min_size=.1,
                 eta=1.):
        super(GenerateProposals, self).__init__()
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.min_size = min_size
        self.eta = eta


@register
class MaskAssigner(object):
    # Config wrapper: fields are forwarded to fluid.layers.generate_mask_labels
    __op__ = fluid.layers.generate_mask_labels
    __append_doc__ = True
    __shared__ = ['num_classes']

    def __init__(self, num_classes=81, resolution=14):
        super(MaskAssigner, self).__init__()
        self.num_classes = num_classes
        self.resolution = resolution


@register
@serializable
class MultiClassNMS(object):
    # Config wrapper: fields are forwarded to fluid.layers.multiclass_nms
    __op__ = fluid.layers.multiclass_nms
    __append_doc__ = True

    def __init__(self,
                 score_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 nms_threshold=.5,
                 normalized=False,
                 nms_eta=1.0,
                 background_label=0):
        super(MultiClassNMS, self).__init__()
        self.score_threshold = score_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.nms_threshold = nms_threshold
        self.normalized = normalized
        self.nms_eta = nms_eta
        self.background_label = background_label


@register
@serializable
class MatrixNMS(object):
    # NOTE: __op__ is a dotted-path string here (resolved lazily),
    # unlike the direct function references in the wrappers above.
    __op__ = 'paddle.fluid.layers.matrix_nms'
    __append_doc__ = True

    def __init__(self,
                 score_threshold=.05,
                 post_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 use_gaussian=False,
                 gaussian_sigma=2.,
                 normalized=False,
                 background_label=0):
        super(MatrixNMS, self).__init__()
        self.score_threshold = score_threshold
        self.post_threshold = post_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.normalized = normalized
        self.use_gaussian = use_gaussian
        self.gaussian_sigma = gaussian_sigma
        self.background_label = background_label


@register
@serializable
class MultiClassSoftNMS(object):
    """Soft-NMS over multiclass detections, run on CPU via py_func."""

    def __init__(
            self,
            score_threshold=0.01,
            keep_top_k=300,
            softnms_sigma=0.5,
            normalized=False,
            background_label=0, ):
        super(MultiClassSoftNMS, self).__init__()
        self.score_threshold = score_threshold
        self.keep_top_k = keep_top_k
        self.softnms_sigma = softnms_sigma
        self.normalized = normalized
        self.background_label = background_label

    def __call__(self, bboxes, scores):
        def create_tmp_var(program, name, dtype, shape, lod_level):
            return program.current_block().create_var(
                name=name, dtype=dtype, shape=shape, lod_level=lod_level)

        def _soft_nms_for_cls(dets, sigma, thres):
            """soft_nms_for_cls"""
            # iteratively pick the top-scoring det and gaussian-decay the
            # scores of overlapping dets instead of removing them outright
            dets_final = []
            while len(dets) > 0:
                maxpos = np.argmax(dets[:, 0])
                dets_final.append(dets[maxpos].copy())
                ts, tx1, ty1, tx2, ty2 = dets[maxpos]
                scores = dets[:, 0]
                # force remove bbox at maxpos
                scores[maxpos] = -1
                x1 = dets[:, 1]
                y1 = dets[:, 2]
                x2 = dets[:, 3]
                y2 = dets[:, 4]
                eta = 0 if self.normalized else 1
                areas = (x2 - x1 + eta) * (y2 - y1 + eta)
                xx1 = np.maximum(tx1, x1)
                yy1 = np.maximum(ty1, y1)
                xx2 = np.minimum(tx2, x2)
                yy2 = np.minimum(ty2, y2)
                w = np.maximum(0.0, xx2 - xx1 + eta)
                h = np.maximum(0.0, yy2 - yy1 + eta)
                inter = w * h
                ovr = inter / (areas + areas[maxpos] - inter)
                weight = np.exp(-(ovr * ovr) / sigma)
                scores = scores * weight
                idx_keep = np.where(scores >= thres)
                dets[:, 0] = scores
                dets = dets[idx_keep]
            dets_final = np.array(dets_final).reshape(-1, 5)
            return dets_final

        def _soft_nms(bboxes, scores):
            # per-class soft-NMS for one image, then global top-k
            class_nums = scores.shape[-1]

            softnms_thres = self.score_threshold
            softnms_sigma = self.softnms_sigma
            keep_top_k = self.keep_top_k

            cls_boxes = [[] for _ in range(class_nums)]
            cls_ids = [[] for _ in range(class_nums)]

            start_idx = 1 if self.background_label == 0 else 0
            for j in range(start_idx, class_nums):
                inds = np.where(scores[:, j] >= softnms_thres)[0]
                scores_j = scores[inds, j]
                rois_j = bboxes[inds, j, :] if len(
                    bboxes.shape) > 2 else bboxes[inds, :]
                dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
                    np.float32, copy=False)
                cls_rank = np.argsort(-dets_j[:, 0])
                dets_j = dets_j[cls_rank]

                cls_boxes[j] = _soft_nms_for_cls(
                    dets_j, sigma=softnms_sigma, thres=softnms_thres)
                cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1,
                                                                           1)

            cls_boxes = np.vstack(cls_boxes[start_idx:])
            cls_ids = np.vstack(cls_ids[start_idx:])
            pred_result = np.hstack([cls_ids, cls_boxes])

            # Limit to max_per_image detections **over all classes**
            image_scores = cls_boxes[:, 0]
            if len(image_scores) > keep_top_k:
                image_thresh = np.sort(image_scores)[-keep_top_k]
                keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
                pred_result = pred_result[keep, :]

            return pred_result

        def _batch_softnms(bboxes, scores):
            # split a LoD batch (or dense 3-D batch) per image and run
            # _soft_nms on each; results are re-packed as a LoDTensor
            batch_offsets = bboxes.lod()
            bboxes = np.array(bboxes)
            scores = np.array(scores)
            out_offsets = [0]
            pred_res = []
            if len(batch_offsets) > 0:
                batch_offset = batch_offsets[0]
                for i in range(len(batch_offset) - 1):
                    s, e = batch_offset[i], batch_offset[i + 1]
                    pred = _soft_nms(bboxes[s:e], scores[s:e])
                    out_offsets.append(pred.shape[0] + out_offsets[-1])
                    pred_res.append(pred)
            else:
                assert len(bboxes.shape) == 3
                assert len(scores.shape) == 3
                for i in range(bboxes.shape[0]):
                    pred = _soft_nms(bboxes[i], scores[i])
                    out_offsets.append(pred.shape[0] + out_offsets[-1])
                    pred_res.append(pred)

            res = fluid.LoDTensor()
            res.set_lod([out_offsets])
            if len(pred_res) == 0:
                # placeholder so np.vstack below has something to stack
                pred_res = np.array([[1]], dtype=np.float32)
            res.set(np.vstack(pred_res).astype(np.float32), fluid.CPUPlace())
            return res

        pred_result = create_tmp_var(
            fluid.default_main_program(),
            name='softnms_pred_result',
            dtype='float32',
            shape=[-1, 6],
            lod_level=1)
        fluid.layers.py_func(
            func=_batch_softnms, x=[bboxes, scores], out=pred_result)
        return pred_result
program.current_block().create_var( + name=name, dtype=dtype, shape=shape, lod_level=lod_level) + + def _calc_diou_term(dets1, dets2): + eps = 1.e-10 + eta = 0 if self.normalized else 1 + + x1, y1, x2, y2 = dets1[0], dets1[1], dets1[2], dets1[3] + x1g, y1g, x2g, y2g = dets2[0], dets2[1], dets2[2], dets2[3] + + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + w = x2 - x1 + eta + h = y2 - y1 + eta + + cxg = (x1g + x2g) / 2 + cyg = (y1g + y2g) / 2 + wg = x2g - x1g + eta + hg = y2g - y1g + eta + + x2 = np.maximum(x1, x2) + y2 = np.maximum(y1, y2) + + # A or B + xc1 = np.minimum(x1, x1g) + yc1 = np.minimum(y1, y1g) + xc2 = np.maximum(x2, x2g) + yc2 = np.maximum(y2, y2g) + + # DIOU term + dist_intersection = (cx - cxg)**2 + (cy - cyg)**2 + dist_union = (xc2 - xc1)**2 + (yc2 - yc1)**2 + diou_term = (dist_intersection + eps) / (dist_union + eps) + return diou_term + + def _diou_nms_for_cls(dets, thres): + """_diou_nms_for_cls""" + scores = dets[:, 0] + x1 = dets[:, 1] + y1 = dets[:, 2] + x2 = dets[:, 3] + y2 = dets[:, 4] + eta = 0 if self.normalized else 1 + areas = (x2 - x1 + eta) * (y2 - y1 + eta) + dt_num = dets.shape[0] + order = np.array(range(dt_num)) + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + eta) + h = np.maximum(0.0, yy2 - yy1 + eta) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + diou_term = _calc_diou_term([x1[i], y1[i], x2[i], y2[i]], [ + x1[order[1:]], y1[order[1:]], x2[order[1:]], y2[order[1:]] + ]) + + inds = np.where(ovr - diou_term <= thres)[0] + + order = order[inds + 1] + + dets_final = dets[keep] + return dets_final + + def _diou_nms(bboxes, scores): + bboxes = np.array(bboxes) + scores = np.array(scores) + class_nums = scores.shape[-1] + + score_threshold = self.score_threshold + nms_threshold = 
self.nms_threshold + keep_top_k = self.keep_top_k + + cls_boxes = [[] for _ in range(class_nums)] + cls_ids = [[] for _ in range(class_nums)] + + start_idx = 1 if self.background_label == 0 else 0 + for j in range(start_idx, class_nums): + inds = np.where(scores[:, j] >= score_threshold)[0] + scores_j = scores[inds, j] + rois_j = bboxes[inds, j, :] + dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype( + np.float32, copy=False) + cls_rank = np.argsort(-dets_j[:, 0]) + dets_j = dets_j[cls_rank] + + cls_boxes[j] = _diou_nms_for_cls(dets_j, thres=nms_threshold) + cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1, + 1) + + cls_boxes = np.vstack(cls_boxes[start_idx:]) + cls_ids = np.vstack(cls_ids[start_idx:]) + pred_result = np.hstack([cls_ids, cls_boxes]).astype(np.float32) + + # Limit to max_per_image detections **over all classes** + image_scores = cls_boxes[:, 0] + if len(image_scores) > keep_top_k: + image_thresh = np.sort(image_scores)[-keep_top_k] + keep = np.where(cls_boxes[:, 0] >= image_thresh)[0] + pred_result = pred_result[keep, :] + + res = fluid.LoDTensor() + res.set_lod([[0, pred_result.shape[0]]]) + if pred_result.shape[0] == 0: + pred_result = np.array([[1]], dtype=np.float32) + res.set(pred_result, fluid.CPUPlace()) + + return res + + pred_result = create_tmp_var( + fluid.default_main_program(), + name='diou_nms_pred_result', + dtype='float32', + shape=[-1, 6], + lod_level=0) + fluid.layers.py_func( + func=_diou_nms, x=[bboxes, scores], out=pred_result) + return pred_result + + +@register +class BBoxAssigner(object): + __op__ = fluid.layers.generate_proposal_labels + __append_doc__ = True + __shared__ = ['num_classes'] + + def __init__(self, + batch_size_per_im=512, + fg_fraction=.25, + fg_thresh=.5, + bg_thresh_hi=.5, + bg_thresh_lo=0., + bbox_reg_weights=[0.1, 0.1, 0.2, 0.2], + num_classes=81, + shuffle_before_sample=True): + super(BBoxAssigner, self).__init__() + self.batch_size_per_im = batch_size_per_im + self.fg_fraction = 
fg_fraction + self.fg_thresh = fg_thresh + self.bg_thresh_hi = bg_thresh_hi + self.bg_thresh_lo = bg_thresh_lo + self.bbox_reg_weights = bbox_reg_weights + self.class_nums = num_classes + self.use_random = shuffle_before_sample + + +@register +class LibraBBoxAssigner(object): + __shared__ = ['num_classes'] + + def __init__(self, + batch_size_per_im=512, + fg_fraction=.25, + fg_thresh=.5, + bg_thresh_hi=.5, + bg_thresh_lo=0., + bbox_reg_weights=[0.1, 0.1, 0.2, 0.2], + num_classes=81, + shuffle_before_sample=True, + is_cls_agnostic=False, + num_bins=3): + super(LibraBBoxAssigner, self).__init__() + self.batch_size_per_im = batch_size_per_im + self.fg_fraction = fg_fraction + self.fg_thresh = fg_thresh + self.bg_thresh_hi = bg_thresh_hi + self.bg_thresh_lo = bg_thresh_lo + self.bbox_reg_weights = bbox_reg_weights + self.class_nums = num_classes + self.use_random = shuffle_before_sample + self.is_cls_agnostic = is_cls_agnostic + self.num_bins = num_bins + + def __call__( + self, + rpn_rois, + gt_classes, + is_crowd, + gt_boxes, + im_info, ): + return self.generate_proposal_label_libra( + rpn_rois=rpn_rois, + gt_classes=gt_classes, + is_crowd=is_crowd, + gt_boxes=gt_boxes, + im_info=im_info, + batch_size_per_im=self.batch_size_per_im, + fg_fraction=self.fg_fraction, + fg_thresh=self.fg_thresh, + bg_thresh_hi=self.bg_thresh_hi, + bg_thresh_lo=self.bg_thresh_lo, + bbox_reg_weights=self.bbox_reg_weights, + class_nums=self.class_nums, + use_random=self.use_random, + is_cls_agnostic=self.is_cls_agnostic, + is_cascade_rcnn=False) + + def generate_proposal_label_libra( + self, rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, + batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi, + bg_thresh_lo, bbox_reg_weights, class_nums, use_random, + is_cls_agnostic, is_cascade_rcnn): + num_bins = self.num_bins + + def create_tmp_var(program, name, dtype, shape, lod_level=None): + return program.current_block().create_var( + name=name, dtype=dtype, shape=shape, lod_level=lod_level) + 
+ def _sample_pos(max_overlaps, max_classes, pos_inds, num_expected): + if len(pos_inds) <= num_expected: + return pos_inds + else: + unique_gt_inds = np.unique(max_classes[pos_inds]) + num_gts = len(unique_gt_inds) + num_per_gt = int(round(num_expected / float(num_gts)) + 1) + + sampled_inds = [] + for i in unique_gt_inds: + inds = np.nonzero(max_classes == i)[0] + before_len = len(inds) + inds = list(set(inds) & set(pos_inds)) + after_len = len(inds) + if len(inds) > num_per_gt: + inds = np.random.choice( + inds, size=num_per_gt, replace=False) + sampled_inds.extend(list(inds)) # combine as a new sampler + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.array( + list(set(pos_inds) - set(sampled_inds))) + assert len(sampled_inds)+len(extra_inds) == len(pos_inds), \ + "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format( + len(sampled_inds), len(extra_inds), len(pos_inds)) + if len(extra_inds) > num_extra: + extra_inds = np.random.choice( + extra_inds, size=num_extra, replace=False) + sampled_inds.extend(extra_inds.tolist()) + elif len(sampled_inds) > num_expected: + sampled_inds = np.random.choice( + sampled_inds, size=num_expected, replace=False) + return sampled_inds + + def sample_via_interval(max_overlaps, full_set, num_expected, floor_thr, + num_bins, bg_thresh_hi): + max_iou = max_overlaps.max() + iou_interval = (max_iou - floor_thr) / num_bins + per_num_expected = int(num_expected / num_bins) + + sampled_inds = [] + for i in range(num_bins): + start_iou = floor_thr + i * iou_interval + end_iou = floor_thr + (i + 1) * iou_interval + + tmp_set = set( + np.where( + np.logical_and(max_overlaps >= start_iou, max_overlaps < + end_iou))[0]) + tmp_inds = list(tmp_set & full_set) + + if len(tmp_inds) > per_num_expected: + tmp_sampled_set = np.random.choice( + tmp_inds, size=per_num_expected, replace=False) + else: + tmp_sampled_set = np.array(tmp_inds, dtype=np.int) + 
sampled_inds.append(tmp_sampled_set) + + sampled_inds = np.concatenate(sampled_inds) + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.array(list(full_set - set(sampled_inds))) + assert len(sampled_inds)+len(extra_inds) == len(full_set), \ + "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format( + len(sampled_inds), len(extra_inds), len(full_set)) + + if len(extra_inds) > num_extra: + extra_inds = np.random.choice( + extra_inds, num_extra, replace=False) + sampled_inds = np.concatenate([sampled_inds, extra_inds]) + + return sampled_inds + + def _sample_neg(max_overlaps, + max_classes, + neg_inds, + num_expected, + floor_thr=-1, + floor_fraction=0, + num_bins=3, + bg_thresh_hi=0.5): + if len(neg_inds) <= num_expected: + return neg_inds + else: + # balance sampling for negative samples + neg_set = set(neg_inds) + if floor_thr > 0: + floor_set = set( + np.where( + np.logical_and(max_overlaps >= 0, max_overlaps < + floor_thr))[0]) + iou_sampling_set = set( + np.where(max_overlaps >= floor_thr)[0]) + elif floor_thr == 0: + floor_set = set(np.where(max_overlaps == 0)[0]) + iou_sampling_set = set( + np.where(max_overlaps > floor_thr)[0]) + else: + floor_set = set() + iou_sampling_set = set( + np.where(max_overlaps > floor_thr)[0]) + floor_thr = 0 + + floor_neg_inds = list(floor_set & neg_set) + iou_sampling_neg_inds = list(iou_sampling_set & neg_set) + + num_expected_iou_sampling = int(num_expected * + (1 - floor_fraction)) + if len(iou_sampling_neg_inds) > num_expected_iou_sampling: + if num_bins >= 2: + iou_sampled_inds = sample_via_interval( + max_overlaps, + set(iou_sampling_neg_inds), + num_expected_iou_sampling, floor_thr, num_bins, + bg_thresh_hi) + else: + iou_sampled_inds = np.random.choice( + iou_sampling_neg_inds, + size=num_expected_iou_sampling, + replace=False) + else: + iou_sampled_inds = np.array( + iou_sampling_neg_inds, dtype=np.int) + num_expected_floor = 
num_expected - len(iou_sampled_inds) + if len(floor_neg_inds) > num_expected_floor: + sampled_floor_inds = np.random.choice( + floor_neg_inds, size=num_expected_floor, replace=False) + else: + sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int) + sampled_inds = np.concatenate( + (sampled_floor_inds, iou_sampled_inds)) + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.array(list(neg_set - set(sampled_inds))) + if len(extra_inds) > num_extra: + extra_inds = np.random.choice( + extra_inds, size=num_extra, replace=False) + sampled_inds = np.concatenate((sampled_inds, extra_inds)) + return sampled_inds + + def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, + batch_size_per_im, fg_fraction, fg_thresh, + bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, + class_nums, use_random, is_cls_agnostic, + is_cascade_rcnn): + rois_per_image = int(batch_size_per_im) + fg_rois_per_im = int(np.round(fg_fraction * rois_per_image)) + + # Roidb + im_scale = im_info[2] + inv_im_scale = 1. 
/ im_scale + rpn_rois = rpn_rois * inv_im_scale + if is_cascade_rcnn: + rpn_rois = rpn_rois[gt_boxes.shape[0]:, :] + boxes = np.vstack([gt_boxes, rpn_rois]) + gt_overlaps = np.zeros((boxes.shape[0], class_nums)) + box_to_gt_ind_map = np.zeros((boxes.shape[0]), dtype=np.int32) + if len(gt_boxes) > 0: + proposal_to_gt_overlaps = bbox_overlaps(boxes, gt_boxes) + + overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1) + overlaps_max = proposal_to_gt_overlaps.max(axis=1) + # Boxes which with non-zero overlap with gt boxes + overlapped_boxes_ind = np.where(overlaps_max > 0)[0] + + overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[ + overlapped_boxes_ind]] + + for idx in range(len(overlapped_boxes_ind)): + gt_overlaps[overlapped_boxes_ind[ + idx], overlapped_boxes_gt_classes[idx]] = overlaps_max[ + overlapped_boxes_ind[idx]] + box_to_gt_ind_map[overlapped_boxes_ind[ + idx]] = overlaps_argmax[overlapped_boxes_ind[idx]] + + crowd_ind = np.where(is_crowd)[0] + gt_overlaps[crowd_ind] = -1 + + max_overlaps = gt_overlaps.max(axis=1) + max_classes = gt_overlaps.argmax(axis=1) + + # Cascade RCNN Decode Filter + if is_cascade_rcnn: + ws = boxes[:, 2] - boxes[:, 0] + 1 + hs = boxes[:, 3] - boxes[:, 1] + 1 + keep = np.where((ws > 0) & (hs > 0))[0] + boxes = boxes[keep] + max_overlaps = max_overlaps[keep] + fg_inds = np.where(max_overlaps >= fg_thresh)[0] + bg_inds = np.where((max_overlaps < bg_thresh_hi) & ( + max_overlaps >= bg_thresh_lo))[0] + fg_rois_per_this_image = fg_inds.shape[0] + bg_rois_per_this_image = bg_inds.shape[0] + else: + # Foreground + fg_inds = np.where(max_overlaps >= fg_thresh)[0] + fg_rois_per_this_image = np.minimum(fg_rois_per_im, + fg_inds.shape[0]) + # Sample foreground if there are too many + if fg_inds.shape[0] > fg_rois_per_this_image: + if use_random: + fg_inds = _sample_pos(max_overlaps, max_classes, + fg_inds, fg_rois_per_this_image) + fg_inds = fg_inds[:fg_rois_per_this_image] + + # Background + bg_inds = np.where((max_overlaps < 
bg_thresh_hi) & ( + max_overlaps >= bg_thresh_lo))[0] + bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image + bg_rois_per_this_image = np.minimum(bg_rois_per_this_image, + bg_inds.shape[0]) + assert bg_rois_per_this_image >= 0, "bg_rois_per_this_image must be >= 0 but got {}".format( + bg_rois_per_this_image) + + # Sample background if there are too many + if bg_inds.shape[0] > bg_rois_per_this_image: + if use_random: + # libra neg sample + bg_inds = _sample_neg( + max_overlaps, + max_classes, + bg_inds, + bg_rois_per_this_image, + num_bins=num_bins, + bg_thresh_hi=bg_thresh_hi) + bg_inds = bg_inds[:bg_rois_per_this_image] + + keep_inds = np.append(fg_inds, bg_inds) + sampled_labels = max_classes[keep_inds] # N x 1 + sampled_labels[fg_rois_per_this_image:] = 0 + sampled_boxes = boxes[keep_inds] # N x 324 + sampled_gts = gt_boxes[box_to_gt_ind_map[keep_inds]] + sampled_gts[fg_rois_per_this_image:, :] = gt_boxes[0] + bbox_label_targets = _compute_targets( + sampled_boxes, sampled_gts, sampled_labels, bbox_reg_weights) + bbox_targets, bbox_inside_weights = _expand_bbox_targets( + bbox_label_targets, class_nums, is_cls_agnostic) + bbox_outside_weights = np.array( + bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype) + # Scale rois + sampled_rois = sampled_boxes * im_scale + + # Faster RCNN blobs + frcn_blobs = dict( + rois=sampled_rois, + labels_int32=sampled_labels, + bbox_targets=bbox_targets, + bbox_inside_weights=bbox_inside_weights, + bbox_outside_weights=bbox_outside_weights) + return frcn_blobs + + def _compute_targets(roi_boxes, gt_boxes, labels, bbox_reg_weights): + assert roi_boxes.shape[0] == gt_boxes.shape[0] + assert roi_boxes.shape[1] == 4 + assert gt_boxes.shape[1] == 4 + + targets = np.zeros(roi_boxes.shape) + bbox_reg_weights = np.asarray(bbox_reg_weights) + targets = box_to_delta( + ex_boxes=roi_boxes, gt_boxes=gt_boxes, weights=bbox_reg_weights) + + return np.hstack([labels[:, np.newaxis], targets]).astype( + np.float32, 
copy=False) + + def _expand_bbox_targets(bbox_targets_input, class_nums, + is_cls_agnostic): + class_labels = bbox_targets_input[:, 0] + fg_inds = np.where(class_labels > 0)[0] + bbox_targets = np.zeros((class_labels.shape[0], 4 * class_nums + if not is_cls_agnostic else 4 * 2)) + bbox_inside_weights = np.zeros(bbox_targets.shape) + for ind in fg_inds: + class_label = int(class_labels[ + ind]) if not is_cls_agnostic else 1 + start_ind = class_label * 4 + end_ind = class_label * 4 + 4 + bbox_targets[ind, start_ind:end_ind] = bbox_targets_input[ind, + 1:] + bbox_inside_weights[ind, start_ind:end_ind] = (1.0, 1.0, 1.0, + 1.0) + return bbox_targets, bbox_inside_weights + + def generate_func( + rpn_rois, + gt_classes, + is_crowd, + gt_boxes, + im_info, ): + rpn_rois_lod = rpn_rois.lod()[0] + gt_classes_lod = gt_classes.lod()[0] + + # convert + rpn_rois = np.array(rpn_rois) + gt_classes = np.array(gt_classes) + is_crowd = np.array(is_crowd) + gt_boxes = np.array(gt_boxes) + im_info = np.array(im_info) + + rois = [] + labels_int32 = [] + bbox_targets = [] + bbox_inside_weights = [] + bbox_outside_weights = [] + lod = [0] + + for idx in range(len(rpn_rois_lod) - 1): + rois_si = rpn_rois_lod[idx] + rois_ei = rpn_rois_lod[idx + 1] + + gt_si = gt_classes_lod[idx] + gt_ei = gt_classes_lod[idx + 1] + frcn_blobs = _sample_rois( + rpn_rois[rois_si:rois_ei], gt_classes[gt_si:gt_ei], + is_crowd[gt_si:gt_ei], gt_boxes[gt_si:gt_ei], im_info[idx], + batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi, + bg_thresh_lo, bbox_reg_weights, class_nums, use_random, + is_cls_agnostic, is_cascade_rcnn) + lod.append(frcn_blobs['rois'].shape[0] + lod[-1]) + rois.append(frcn_blobs['rois']) + labels_int32.append(frcn_blobs['labels_int32'].reshape(-1, 1)) + bbox_targets.append(frcn_blobs['bbox_targets']) + bbox_inside_weights.append(frcn_blobs['bbox_inside_weights']) + bbox_outside_weights.append(frcn_blobs['bbox_outside_weights']) + + rois = np.vstack(rois) + labels_int32 = 
np.vstack(labels_int32) + bbox_targets = np.vstack(bbox_targets) + bbox_inside_weights = np.vstack(bbox_inside_weights) + bbox_outside_weights = np.vstack(bbox_outside_weights) + + # create lod-tensor for return + # notice that the func create_lod_tensor does not work well here + ret_rois = fluid.LoDTensor() + ret_rois.set_lod([lod]) + ret_rois.set(rois.astype("float32"), fluid.CPUPlace()) + + ret_labels_int32 = fluid.LoDTensor() + ret_labels_int32.set_lod([lod]) + ret_labels_int32.set(labels_int32.astype("int32"), fluid.CPUPlace()) + + ret_bbox_targets = fluid.LoDTensor() + ret_bbox_targets.set_lod([lod]) + ret_bbox_targets.set( + bbox_targets.astype("float32"), fluid.CPUPlace()) + + ret_bbox_inside_weights = fluid.LoDTensor() + ret_bbox_inside_weights.set_lod([lod]) + ret_bbox_inside_weights.set( + bbox_inside_weights.astype("float32"), fluid.CPUPlace()) + + ret_bbox_outside_weights = fluid.LoDTensor() + ret_bbox_outside_weights.set_lod([lod]) + ret_bbox_outside_weights.set( + bbox_outside_weights.astype("float32"), fluid.CPUPlace()) + + return ret_rois, ret_labels_int32, ret_bbox_targets, ret_bbox_inside_weights, ret_bbox_outside_weights + + rois = create_tmp_var( + fluid.default_main_program(), + name=None, + dtype='float32', + shape=[-1, 4], ) + bbox_inside_weights = create_tmp_var( + fluid.default_main_program(), + name=None, + dtype='float32', + shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], ) + bbox_outside_weights = create_tmp_var( + fluid.default_main_program(), + name=None, + dtype='float32', + shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], ) + bbox_targets = create_tmp_var( + fluid.default_main_program(), + name=None, + dtype='float32', + shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], ) + labels_int32 = create_tmp_var( + fluid.default_main_program(), + name=None, + dtype='int32', + shape=[-1, 1], ) + + outs = [ + rois, labels_int32, bbox_targets, bbox_inside_weights, + bbox_outside_weights + ] + + 
fluid.layers.py_func( + func=generate_func, + x=[rpn_rois, gt_classes, is_crowd, gt_boxes, im_info], + out=outs) + return outs + + +@register +class RoIAlign(object): + __op__ = fluid.layers.roi_align + __append_doc__ = True + + def __init__(self, resolution=7, spatial_scale=1. / 16, sampling_ratio=0): + super(RoIAlign, self).__init__() + if isinstance(resolution, Integral): + resolution = [resolution, resolution] + self.pooled_height = resolution[0] + self.pooled_width = resolution[1] + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + + +@register +class RoIPool(object): + __op__ = fluid.layers.roi_pool + __append_doc__ = True + + def __init__(self, resolution=7, spatial_scale=1. / 16): + super(RoIPool, self).__init__() + if isinstance(resolution, Integral): + resolution = [resolution, resolution] + self.pooled_height = resolution[0] + self.pooled_width = resolution[1] + self.spatial_scale = spatial_scale + + +@register +class MultiBoxHead(object): + __op__ = fluid.layers.multi_box_head + __append_doc__ = True + + def __init__(self, + min_ratio=20, + max_ratio=90, + base_size=300, + min_sizes=[60.0, 105.0, 150.0, 195.0, 240.0, 285.0], + max_sizes=[[], 150.0, 195.0, 240.0, 285.0, 300.0], + aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], + [2., 3.]], + steps=None, + offset=0.5, + flip=True, + min_max_aspect_ratios_order=False, + kernel_size=1, + pad=0): + super(MultiBoxHead, self).__init__() + self.min_ratio = min_ratio + self.max_ratio = max_ratio + self.base_size = base_size + self.min_sizes = min_sizes + self.max_sizes = max_sizes + self.aspect_ratios = aspect_ratios + self.steps = steps + self.offset = offset + self.flip = flip + self.min_max_aspect_ratios_order = min_max_aspect_ratios_order + self.kernel_size = kernel_size + self.pad = pad + + +@register +@serializable +class SSDLiteMultiBoxHead(object): + def __init__(self, + min_ratio=20, + max_ratio=90, + base_size=300, + min_sizes=None, + max_sizes=None, + 
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], + [2., 3.]], + steps=None, + offset=0.5, + flip=True, + clip=False, + pad=0, + conv_decay=0.0): + super(SSDLiteMultiBoxHead, self).__init__() + self.min_ratio = min_ratio + self.max_ratio = max_ratio + self.base_size = base_size + self.min_sizes = min_sizes + self.max_sizes = max_sizes + self.aspect_ratios = aspect_ratios + self.steps = steps + self.offset = offset + self.flip = flip + self.pad = pad + self.clip = clip + self.conv_decay = conv_decay + + def _separable_conv(self, input, num_filters, name): + dwconv_param_attr = ParamAttr( + name=name + 'dw_weights', regularizer=L2Decay(self.conv_decay)) + num_filter1 = input.shape[1] + depthwise_conv = fluid.layers.conv2d( + input=input, + num_filters=num_filter1, + filter_size=3, + stride=1, + padding="SAME", + groups=int(num_filter1), + act=None, + use_cudnn=False, + param_attr=dwconv_param_attr, + bias_attr=False) + bn_name = name + '_bn' + bn_param_attr = ParamAttr( + name=bn_name + "_scale", regularizer=L2Decay(0.0)) + bn_bias_attr = ParamAttr( + name=bn_name + "_offset", regularizer=L2Decay(0.0)) + bn = fluid.layers.batch_norm( + input=depthwise_conv, + param_attr=bn_param_attr, + bias_attr=bn_bias_attr, + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + bn = fluid.layers.relu6(bn) + pwconv_param_attr = ParamAttr( + name=name + 'pw_weights', regularizer=L2Decay(self.conv_decay)) + pointwise_conv = fluid.layers.conv2d( + input=bn, + num_filters=num_filters, + filter_size=1, + stride=1, + act=None, + use_cudnn=True, + param_attr=pwconv_param_attr, + bias_attr=False) + return pointwise_conv + + def __call__(self, inputs, image, num_classes): + def _permute_and_reshape(input, last_dim): + trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1]) + compile_shape = [0, -1, last_dim] + return fluid.layers.reshape(trans, shape=compile_shape) + + def _is_list_or_tuple_(data): + return (isinstance(data, list) or 
isinstance(data, tuple)) + + if self.min_sizes is None and self.max_sizes is None: + num_layer = len(inputs) + self.min_sizes = [] + self.max_sizes = [] + step = int( + math.floor(((self.max_ratio - self.min_ratio)) / (num_layer - 2 + ))) + for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1, + step): + self.min_sizes.append(self.base_size * ratio / 100.) + self.max_sizes.append(self.base_size * (ratio + step) / 100.) + self.min_sizes = [self.base_size * .10] + self.min_sizes + self.max_sizes = [self.base_size * .20] + self.max_sizes + + locs, confs = [], [] + boxes, mvars = [], [] + + for i, input in enumerate(inputs): + min_size = self.min_sizes[i] + max_size = self.max_sizes[i] + if not _is_list_or_tuple_(min_size): + min_size = [min_size] + if not _is_list_or_tuple_(max_size): + max_size = [max_size] + step = [ + self.steps[i] if self.steps else 0.0, self.steps[i] + if self.steps else 0.0 + ] + box, var = fluid.layers.prior_box( + input, + image, + min_sizes=min_size, + max_sizes=max_size, + steps=step, + aspect_ratios=self.aspect_ratios[i], + variance=[0.1, 0.1, 0.2, 0.2], + clip=self.clip, + flip=self.flip, + offset=0.5) + + num_boxes = box.shape[2] + box = fluid.layers.reshape(box, shape=[-1, 4]) + var = fluid.layers.reshape(var, shape=[-1, 4]) + num_loc_output = num_boxes * 4 + num_conf_output = num_boxes * num_classes + # get loc + mbox_loc = self._separable_conv(input, num_loc_output, + "loc_{}".format(i + 1)) + loc = _permute_and_reshape(mbox_loc, 4) + # get conf + mbox_conf = self._separable_conv(input, num_conf_output, + "conf_{}".format(i + 1)) + conf = _permute_and_reshape(mbox_conf, num_classes) + + locs.append(loc) + confs.append(conf) + boxes.append(box) + mvars.append(var) + + ssd_mbox_loc = fluid.layers.concat(locs, axis=1) + ssd_mbox_conf = fluid.layers.concat(confs, axis=1) + prior_boxes = fluid.layers.concat(boxes) + box_vars = fluid.layers.concat(mvars) + + prior_boxes.stop_gradient = True + box_vars.stop_gradient = True + return 
ssd_mbox_loc, ssd_mbox_conf, prior_boxes, box_vars + + +@register +@serializable +class SSDOutputDecoder(object): + __op__ = fluid.layers.detection_output + __append_doc__ = True + + def __init__(self, + nms_threshold=0.45, + nms_top_k=400, + keep_top_k=200, + score_threshold=0.01, + nms_eta=1.0, + background_label=0, + return_index=False): + super(SSDOutputDecoder, self).__init__() + self.nms_threshold = nms_threshold + self.background_label = background_label + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + self.score_threshold = score_threshold + self.nms_eta = nms_eta + self.return_index = return_index + + +@register +@serializable +class RetinaTargetAssign(object): + __op__ = fluid.layers.retinanet_target_assign + __append_doc__ = True + + def __init__(self, positive_overlap=0.5, negative_overlap=0.4): + super(RetinaTargetAssign, self).__init__() + self.positive_overlap = positive_overlap + self.negative_overlap = negative_overlap + + +@register +@serializable +class RetinaOutputDecoder(object): + __op__ = fluid.layers.retinanet_detection_output + __append_doc__ = True + + def __init__(self, + score_thresh=0.05, + nms_thresh=0.3, + pre_nms_top_n=1000, + detections_per_im=100, + nms_eta=1.0): + super(RetinaOutputDecoder, self).__init__() + self.score_threshold = score_thresh + self.nms_threshold = nms_thresh + self.nms_top_k = pre_nms_top_n + self.keep_top_k = detections_per_im + self.nms_eta = nms_eta + + +@register +@serializable +class MaskMatrixNMS(object): + """ + Matrix NMS for multi-class masks. + Args: + update_threshold (float): Updated threshold of categroy score in second time. + pre_nms_top_n (int): Number of total instance to be kept per image before NMS + post_nms_top_n (int): Number of total instance to be kept per image after NMS. + kernel (str): 'linear' or 'gaussian'. + sigma (float): std in gaussian method. 
+ Input: + seg_preds (Variable): shape (n, h, w), segmentation feature maps + seg_masks (Variable): shape (n, h, w), segmentation feature maps + cate_labels (Variable): shape (n), mask labels in descending order + cate_scores (Variable): shape (n), mask scores in descending order + sum_masks (Variable): a float tensor of the sum of seg_masks + Returns: + Variable: cate_scores, tensors of shape (n) + """ + + def __init__(self, + update_threshold=0.05, + pre_nms_top_n=500, + post_nms_top_n=100, + kernel='gaussian', + sigma=2.0): + super(MaskMatrixNMS, self).__init__() + self.update_threshold = update_threshold + self.pre_nms_top_n = pre_nms_top_n + self.post_nms_top_n = post_nms_top_n + self.kernel = kernel + self.sigma = sigma + + def _sort_score(self, scores, top_num): + self.case_scores = scores + + def fn_1(): + return fluid.layers.topk(self.case_scores, top_num) + + def fn_2(): + return fluid.layers.argsort(self.case_scores, descending=True) + + sort_inds = fluid.layers.case( + pred_fn_pairs=[(fluid.layers.shape(scores)[0] > top_num, fn_1)], + default=fn_2) + return sort_inds + + def __call__(self, + seg_preds, + seg_masks, + cate_labels, + cate_scores, + sum_masks=None): + # sort and keep top nms_pre + sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n) + + seg_masks = fluid.layers.gather(seg_masks, index=sort_inds[1]) + seg_preds = fluid.layers.gather(seg_preds, index=sort_inds[1]) + sum_masks = fluid.layers.gather(sum_masks, index=sort_inds[1]) + cate_scores = sort_inds[0] + cate_labels = fluid.layers.gather(cate_labels, index=sort_inds[1]) + + seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1) + # inter. + inter_matrix = paddle.mm(seg_masks, + fluid.layers.transpose(seg_masks, [1, 0])) + n_samples = fluid.layers.shape(cate_labels) + # union. + sum_masks_x = fluid.layers.reshape( + fluid.layers.expand( + sum_masks, expand_times=[n_samples]), + shape=[n_samples, n_samples]) + # iou. 
+ iou_matrix = (inter_matrix / (sum_masks_x + fluid.layers.transpose( + sum_masks_x, [1, 0]) - inter_matrix)) + iou_matrix = paddle.triu(iou_matrix, diagonal=1) + # label_specific matrix. + cate_labels_x = fluid.layers.reshape( + fluid.layers.expand( + cate_labels, expand_times=[n_samples]), + shape=[n_samples, n_samples]) + label_matrix = fluid.layers.cast( + (cate_labels_x == fluid.layers.transpose(cate_labels_x, [1, 0])), + 'float32') + label_matrix = paddle.triu(label_matrix, diagonal=1) + + # IoU compensation + compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0) + compensate_iou = fluid.layers.reshape( + fluid.layers.expand( + compensate_iou, expand_times=[n_samples]), + shape=[n_samples, n_samples]) + compensate_iou = fluid.layers.transpose(compensate_iou, [1, 0]) + + # IoU decay + decay_iou = iou_matrix * label_matrix + + # matrix nms + if self.kernel == 'gaussian': + decay_matrix = fluid.layers.exp(-1 * self.sigma * (decay_iou**2)) + compensate_matrix = fluid.layers.exp(-1 * self.sigma * + (compensate_iou**2)) + decay_coefficient = paddle.min(decay_matrix / compensate_matrix, + axis=0) + elif self.kernel == 'linear': + decay_matrix = (1 - decay_iou) / (1 - compensate_iou) + decay_coefficient = paddle.min(decay_matrix, axis=0) + else: + raise NotImplementedError + + # update the score. 
+ cate_scores = cate_scores * decay_coefficient + + keep = fluid.layers.where(cate_scores >= self.update_threshold) + keep = fluid.layers.squeeze(keep, axes=[1]) + # Prevent empty and increase fake data + keep = fluid.layers.concat([ + keep, fluid.layers.cast( + fluid.layers.shape(cate_scores)[0] - 1, 'int64') + ]) + + seg_preds = fluid.layers.gather(seg_preds, index=keep) + cate_scores = fluid.layers.gather(cate_scores, index=keep) + cate_labels = fluid.layers.gather(cate_labels, index=keep) + + # sort and keep top_k + sort_inds = self._sort_score(cate_scores, self.post_nms_top_n) + + seg_preds = fluid.layers.gather(seg_preds, index=sort_inds[1]) + cate_scores = sort_inds[0] + cate_labels = fluid.layers.gather(cate_labels, index=sort_inds[1]) + return seg_preds, cate_scores, cate_labels diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/roi_extractors/__init__.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/roi_extractors/__init__.py new file mode 100755 index 000000000..15d2525db --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/roi_extractors/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from . 
import roi_extractor +from .roi_extractor import * diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/roi_extractors/roi_extractor.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/roi_extractors/roi_extractor.py new file mode 100755 index 000000000..1caf3936f --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/roi_extractors/roi_extractor.py @@ -0,0 +1,97 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.fluid as fluid

from ppdet.core.workspace import register
from ppdet.modeling.ops import RoIAlign, RoIPool

__all__ = ['RoIPool', 'RoIAlign', 'FPNRoIAlign']


@register
class FPNRoIAlign(object):
    """
    RoI align pooling for FPN feature maps.

    Args:
        sampling_ratio (int): number of sampling points per output bin
            (0 lets the op choose adaptively)
        min_level (int): lowest FPN level used for RoI extraction
        max_level (int): highest FPN level used for RoI extraction
        canconical_level (int): the canonical FPN level (historical
            spelling kept — it is a public config key)
        canonical_size (int): the canonical FPN feature map size
        box_resolution (int): output resolution for box RoIs
        mask_resolution (int): output resolution for mask RoIs
    """

    def __init__(self,
                 sampling_ratio=0,
                 min_level=2,
                 max_level=5,
                 canconical_level=4,
                 canonical_size=224,
                 box_resolution=7,
                 mask_resolution=14):
        super(FPNRoIAlign, self).__init__()
        self.sampling_ratio = sampling_ratio
        self.min_level = min_level
        self.max_level = max_level
        self.canconical_level = canconical_level
        self.canonical_size = canonical_size
        self.box_resolution = box_resolution
        self.mask_resolution = mask_resolution

    def __call__(self, head_inputs, rois, spatial_scale, is_mask=False):
        """
        Distribute RoIs to FPN levels by area, run RoI align against each
        level's feature map, then restore the original RoI order.

        Returns:
            roi_feat (Variable): RoI features with shape [M, C, R, R],
                where M is the number of RoIs and R is the RoI resolution.
        """
        k_min = self.min_level
        k_max = self.max_level
        num_roi_lvls = k_max - k_min + 1
        name_list = list(head_inputs.keys())
        # The last num_roi_lvls entries of head_inputs correspond to
        # levels k_min..k_max (highest-resolution maps come first).
        input_name_list = name_list[-num_roi_lvls:]
        spatial_scale = spatial_scale[-num_roi_lvls:]
        rois_dist, restore_index = fluid.layers.distribute_fpn_proposals(
            rois, k_min, k_max, self.canconical_level, self.canonical_size)
        # rois_dist is in ascending level order
        roi_out_list = []
        # Was `is_mask and self.mask_resolution or self.box_resolution`:
        # that idiom silently falls through to box_resolution whenever
        # mask_resolution is falsy (0); the conditional expression has no
        # such trap and reads the intent directly.
        resolution = self.mask_resolution if is_mask else self.box_resolution
        for lvl in range(num_roi_lvls):
            name_index = num_roi_lvls - lvl - 1
            rois_input = rois_dist[lvl]
            head_input = head_inputs[input_name_list[name_index]]
            sc = spatial_scale[name_index]
            roi_out = fluid.layers.roi_align(
                input=head_input,
                rois=rois_input,
                pooled_height=resolution,
                pooled_width=resolution,
                spatial_scale=sc,
                sampling_ratio=self.sampling_ratio)
            roi_out_list.append(roi_out)
        # Concatenate per-level outputs, undo the distribution shuffle,
        # and re-attach the LoD of the incoming rois.
        roi_feat_shuffle = fluid.layers.concat(roi_out_list)
        roi_feat_ = fluid.layers.gather(roi_feat_shuffle, restore_index)
        roi_feat = fluid.layers.lod_reset(roi_feat_, rois)

        return roi_feat
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from . import bbox_head +from . import mask_head +from . import cascade_head +from . import htc_bbox_head +from . import htc_mask_head +from . import htc_semantic_head + +from .bbox_head import * +from .mask_head import * +from .cascade_head import * +from .htc_bbox_head import * +from .htc_mask_head import * +from .htc_semantic_head import * diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/roi_heads/bbox_head.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/roi_heads/bbox_head.py new file mode 100755 index 000000000..d33c6248d --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/roi_heads/bbox_head.py @@ -0,0 +1,323 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Xavier
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import MSRA

from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.ops import ConvNorm
from ppdet.modeling.losses import SmoothL1Loss
from ppdet.core.workspace import register, serializable
from ppdet.experimental import mixed_precision_global_state

__all__ = ['BBoxHead', 'TwoFCHead', 'XConvNormHead']


@register
@serializable
class BoxCoder(object):
    # Thin, serializable configuration wrapper around fluid.layers.box_coder.
    __op__ = fluid.layers.box_coder
    __append_doc__ = True

    def __init__(self,
                 prior_box_var=[0.1, 0.1, 0.2, 0.2],
                 code_type='decode_center_size',
                 box_normalized=False,
                 axis=1):
        # NOTE(review): the list default is shared across instances; it is
        # only stored, never mutated, so this is safe — kept for config
        # serialization compatibility.
        super(BoxCoder, self).__init__()
        self.prior_box_var = prior_box_var
        self.code_type = code_type
        self.box_normalized = box_normalized
        self.axis = axis


@register
class XConvNormHead(object):
    """
    RCNN head with several convolution layers.

    Args:
        num_conv (int): num of convolution layers for the rcnn head
        conv_dim (int): num of filters for the conv layers
        mlp_dim (int): num of filters for the fc layers
        norm_type (str|None): normalization type for ConvNorm
        freeze_norm (bool): whether to freeze normalization parameters
    """
    __shared__ = ['norm_type', 'freeze_norm']

    def __init__(self,
                 num_conv=4,
                 conv_dim=256,
                 mlp_dim=1024,
                 norm_type=None,
                 freeze_norm=False):
        super(XConvNormHead, self).__init__()
        self.conv_dim = conv_dim
        self.mlp_dim = mlp_dim
        self.num_conv = num_conv
        self.norm_type = norm_type
        self.freeze_norm = freeze_norm

    def __call__(self, roi_feat):
        """Apply num_conv ConvNorm layers then one fc layer to roi_feat."""
        conv = roi_feat
        fan = self.conv_dim * 3 * 3
        initializer = MSRA(uniform=False, fan_in=fan)
        for i in range(self.num_conv):
            name = 'bbox_head_conv' + str(i)
            conv = ConvNorm(
                conv,
                self.conv_dim,
                3,
                act='relu',
                initializer=initializer,
                norm_type=self.norm_type,
                freeze_norm=self.freeze_norm,
                name=name,
                norm_name=name)
        fan = conv.shape[1] * conv.shape[2] * conv.shape[3]
        # NOTE(review): `name` still holds the last conv layer's name here,
        # so the fc parameters are named fc6bbox_head_conv<N-1>_w/_b. Kept
        # as-is: renaming would invalidate existing pretrained weights.
        head_heat = fluid.layers.fc(input=conv,
                                    size=self.mlp_dim,
                                    act='relu',
                                    name='fc6' + name,
                                    param_attr=ParamAttr(
                                        name='fc6%s_w' % name,
                                        initializer=Xavier(fan_out=fan)),
                                    bias_attr=ParamAttr(
                                        name='fc6%s_b' % name,
                                        learning_rate=2,
                                        regularizer=L2Decay(0.)))
        return head_heat


@register
class TwoFCHead(object):
    """
    RCNN head with two Fully Connected layers.

    Args:
        mlp_dim (int): num of filters for the fc layers
    """

    def __init__(self, mlp_dim=1024):
        super(TwoFCHead, self).__init__()
        self.mlp_dim = mlp_dim

    def __call__(self, roi_feat):
        fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3]

        mixed_precision_enabled = mixed_precision_global_state() is not None

        # Under mixed precision the fc layers run in float16; the output is
        # cast back to float32 below.
        if mixed_precision_enabled:
            roi_feat = fluid.layers.cast(roi_feat, 'float16')

        fc6 = fluid.layers.fc(input=roi_feat,
                              size=self.mlp_dim,
                              act='relu',
                              name='fc6',
                              param_attr=ParamAttr(
                                  name='fc6_w',
                                  initializer=Xavier(fan_out=fan)),
                              bias_attr=ParamAttr(
                                  name='fc6_b',
                                  learning_rate=2.,
                                  regularizer=L2Decay(0.)))
        head_feat = fluid.layers.fc(input=fc6,
                                    size=self.mlp_dim,
                                    act='relu',
                                    name='fc7',
                                    param_attr=ParamAttr(
                                        name='fc7_w', initializer=Xavier()),
                                    bias_attr=ParamAttr(
                                        name='fc7_b',
                                        learning_rate=2.,
                                        regularizer=L2Decay(0.)))

        if mixed_precision_enabled:
            head_feat = fluid.layers.cast(head_feat, 'float32')

        return head_feat


@register
class BBoxHead(object):
    """
    RCNN bbox head.

    Args:
        head (object): the head module instance, e.g., `ResNetC5`, `TwoFCHead`
        box_coder (object): `BoxCoder` instance
        nms (object): `MultiClassNMS` instance
        bbox_loss (object): bbox regression loss, `SmoothL1Loss` instance
        num_classes (int): number of output classes
    """
    __inject__ = ['head', 'box_coder', 'nms', 'bbox_loss']
    __shared__ = ['num_classes']

    def __init__(self,
                 head,
                 box_coder=BoxCoder().__dict__,
                 nms=MultiClassNMS().__dict__,
                 bbox_loss=SmoothL1Loss().__dict__,
                 num_classes=81):
        super(BBoxHead, self).__init__()
        self.head = head
        self.num_classes = num_classes
        self.box_coder = box_coder
        self.nms = nms
        self.bbox_loss = bbox_loss
        # Injected components may arrive as plain config dicts; instantiate
        # them here so the rest of the class can assume real objects.
        if isinstance(box_coder, dict):
            self.box_coder = BoxCoder(**box_coder)
        if isinstance(nms, dict):
            self.nms = MultiClassNMS(**nms)
        if isinstance(bbox_loss, dict):
            self.bbox_loss = SmoothL1Loss(**bbox_loss)
        self.head_feat = None

    def get_head_feat(self, input=None):
        """
        Get (and cache) the bbox head feature map.

        When `input` is None the previously computed feature is returned.
        """
        if input is not None:
            feat = self.head(input)
            if isinstance(feat, OrderedDict):
                feat = list(feat.values())[0]
            self.head_feat = feat
        return self.head_feat

    def _get_output(self, roi_feat):
        """
        Get bbox head output.

        Args:
            roi_feat (Variable): RoI feature from RoIExtractor.

        Returns:
            cls_score (Variable): classification logits with shape
                [P, num_classes], P being the number of RoIs.
            bbox_pred (Variable): box regression output with shape
                [P, 4 * num_classes].
        """
        head_feat = self.get_head_feat(roi_feat)
        # When the head is ResNetC5 it outputs a spatial feature map, so
        # global-average-pool it before the fc layers.
        if not isinstance(self.head, TwoFCHead) and not isinstance(
                self.head, XConvNormHead):
            head_feat = fluid.layers.pool2d(
                head_feat, pool_type='avg', global_pooling=True)
        cls_score = fluid.layers.fc(input=head_feat,
                                    size=self.num_classes,
                                    act=None,
                                    name='cls_score',
                                    param_attr=ParamAttr(
                                        name='cls_score_w',
                                        initializer=Normal(
                                            loc=0.0, scale=0.01)),
                                    bias_attr=ParamAttr(
                                        name='cls_score_b',
                                        learning_rate=2.,
                                        regularizer=L2Decay(0.)))
        bbox_pred = fluid.layers.fc(input=head_feat,
                                    size=4 * self.num_classes,
                                    act=None,
                                    name='bbox_pred',
                                    param_attr=ParamAttr(
                                        name='bbox_pred_w',
                                        initializer=Normal(
                                            loc=0.0, scale=0.001)),
                                    bias_attr=ParamAttr(
                                        name='bbox_pred_b',
                                        learning_rate=2.,
                                        regularizer=L2Decay(0.)))
        return cls_score, bbox_pred

    def get_loss(self, roi_feat, labels_int32, bbox_targets,
                 bbox_inside_weights, bbox_outside_weights):
        """
        Get bbox_head loss.

        Args:
            roi_feat (Variable): RoI feature from RoIExtractor.
            labels_int32 (Variable): Class label of a RoI with shape [P, 1].
                P is the number of RoI.
            bbox_targets (Variable): Box label of a RoI with shape
                [P, 4 * class_nums].
            bbox_inside_weights (Variable): Indicates whether a box should
                contribute to loss. Same shape as bbox_targets.
            bbox_outside_weights (Variable): Indicates whether a box should
                contribute to loss. Same shape as bbox_targets.

        Return:
            Type: Dict
                loss_cls (Variable): classification loss.
                loss_bbox (Variable): box regression loss.
        """
        cls_score, bbox_pred = self._get_output(roi_feat)

        labels_int64 = fluid.layers.cast(x=labels_int32, dtype='int64')
        labels_int64.stop_gradient = True
        loss_cls = fluid.layers.softmax_with_cross_entropy(
            logits=cls_score, label=labels_int64, numeric_stable_mode=True)
        loss_cls = fluid.layers.reduce_mean(loss_cls)
        loss_bbox = self.bbox_loss(
            x=bbox_pred,
            y=bbox_targets,
            inside_weight=bbox_inside_weights,
            outside_weight=bbox_outside_weights)
        loss_bbox = fluid.layers.reduce_mean(loss_bbox)
        return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox}

    def get_prediction(self,
                       roi_feat,
                       rois,
                       im_info,
                       im_shape,
                       return_box_score=False):
        """
        Get prediction bounding box in test stage.

        Args:
            roi_feat (Variable): RoI feature from RoIExtractor.
            rois (Variable): Output of generate_proposals in rpn head.
            im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the
                number of input images, each element consists of im_height,
                im_width, im_scale.
            im_shape (Variable): Actual shape of original image with shape
                [B, 3]. B is the number of images, each element consists of
                original_height, original_width, 1.
            return_box_score (bool): return raw boxes and scores instead of
                running NMS.

        Returns:
            pred_result (Variable): Prediction result with shape [N, 6]. Each
                row has 6 values: [label, confidence, xmin, ymin, xmax, ymax].
                N is the total number of prediction.
        """
        cls_score, bbox_pred = self._get_output(roi_feat)

        # Map proposals back to the original image scale before decoding.
        im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
        im_scale = fluid.layers.sequence_expand(im_scale, rois)
        boxes = rois / im_scale
        cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False)
        bbox_pred = fluid.layers.reshape(bbox_pred, (-1, self.num_classes, 4))
        decoded_box = self.box_coder(prior_box=boxes, target_box=bbox_pred)
        cliped_box = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
        if return_box_score:
            return {'bbox': cliped_box, 'score': cls_prob}
        pred_result = self.nms(bboxes=cliped_box, scores=cls_prob)
        return {'bbox': pred_result}
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Xavier
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import MSRA

from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.ops import ConvNorm
from ppdet.modeling.losses import SmoothL1Loss
from ppdet.core.workspace import register

__all__ = ['CascadeBBoxHead']


@register
class CascadeBBoxHead(object):
    """
    Cascade RCNN bbox head.

    Args:
        head (object): the head module instance
        nms (object): `MultiClassNMS` instance
        bbox_loss (object): bbox regression loss, `SmoothL1Loss` instance
        num_classes (int): number of output classes
    """
    __inject__ = ['head', 'nms', 'bbox_loss']
    __shared__ = ['num_classes']

    def __init__(
            self,
            head,
            nms=MultiClassNMS().__dict__,
            bbox_loss=SmoothL1Loss().__dict__,
            num_classes=81, ):
        super(CascadeBBoxHead, self).__init__()
        self.head = head
        self.nms = nms
        self.bbox_loss = bbox_loss
        self.num_classes = num_classes
        # Injected components may arrive as plain config dicts.
        if isinstance(nms, dict):
            self.nms = MultiClassNMS(**nms)
        if isinstance(bbox_loss, dict):
            self.bbox_loss = SmoothL1Loss(**bbox_loss)

    def get_output(self,
                   roi_feat,
                   cls_agnostic_bbox_reg=2,
                   wb_scalar=1.0,
                   name=''):
        """
        Get bbox head output.

        Args:
            roi_feat (Variable): RoI feature from RoIExtractor.
            cls_agnostic_bbox_reg (int): BBox regressor are class agnostic.
            wb_scalar (float): Weights and Bias's learning rate.
            name (str): Layer's name suffix (one per cascade stage).

        Returns:
            cls_score (Variable): cls score.
            bbox_pred (Variable): bbox regression.
        """
        head_feat = self.head(roi_feat, wb_scalar, name)
        cls_score = fluid.layers.fc(input=head_feat,
                                    size=self.num_classes,
                                    act=None,
                                    name='cls_score' + name,
                                    param_attr=ParamAttr(
                                        name='cls_score%s_w' % name,
                                        initializer=Normal(
                                            loc=0.0, scale=0.01),
                                        learning_rate=wb_scalar),
                                    bias_attr=ParamAttr(
                                        name='cls_score%s_b' % name,
                                        learning_rate=wb_scalar * 2,
                                        regularizer=L2Decay(0.)))
        bbox_pred = fluid.layers.fc(input=head_feat,
                                    size=4 * cls_agnostic_bbox_reg,
                                    act=None,
                                    name='bbox_pred' + name,
                                    param_attr=ParamAttr(
                                        name='bbox_pred%s_w' % name,
                                        initializer=Normal(
                                            loc=0.0, scale=0.001),
                                        learning_rate=wb_scalar),
                                    bias_attr=ParamAttr(
                                        name='bbox_pred%s_b' % name,
                                        learning_rate=wb_scalar * 2,
                                        regularizer=L2Decay(0.)))
        return cls_score, bbox_pred

    def get_loss(self, rcnn_pred_list, rcnn_target_list, rcnn_loss_weight_list):
        """
        Get bbox_head loss, one weighted (cls, loc) pair per cascade stage.

        Args:
            rcnn_pred_list (List): Cascade RCNN's head's output including
                bbox_pred and cls_score
            rcnn_target_list (List): Cascade rcnn's bbox and label target
            rcnn_loss_weight_list (List): The weight of location and class loss

        Return:
            Dict of per-stage losses keyed 'loss_cls_%d' / 'loss_loc_%d'.
        """
        loss_dict = {}
        for i, (rcnn_pred, rcnn_target
                ) in enumerate(zip(rcnn_pred_list, rcnn_target_list)):
            labels_int64 = fluid.layers.cast(x=rcnn_target[1], dtype='int64')
            labels_int64.stop_gradient = True

            loss_cls = fluid.layers.softmax_with_cross_entropy(
                logits=rcnn_pred[0],
                label=labels_int64,
                numeric_stable_mode=True, )
            loss_cls = fluid.layers.reduce_mean(
                loss_cls, name='loss_cls_' + str(i)) * rcnn_loss_weight_list[i]

            loss_bbox = self.bbox_loss(
                x=rcnn_pred[1],
                y=rcnn_target[2],
                inside_weight=rcnn_target[3],
                outside_weight=rcnn_target[4])
            loss_bbox = fluid.layers.reduce_mean(
                loss_bbox,
                name='loss_bbox_' + str(i)) * rcnn_loss_weight_list[i]

            loss_dict['loss_cls_%d' % i] = loss_cls
            loss_dict['loss_loc_%d' % i] = loss_bbox

        return loss_dict

    def get_prediction(self,
                       im_info,
                       im_shape,
                       roi_feat_list,
                       rcnn_pred_list,
                       proposal_list,
                       cascade_bbox_reg_weights,
                       cls_agnostic_bbox_reg=2,
                       return_box_score=False):
        """
        Get prediction bounding box in test stage.

        Args:
            im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the
                number of input images, each element consists
                of im_height, im_width, im_scale.
            im_shape (Variable): Actual shape of original image with shape
                [B, 3]. B is the number of images, each element consists of
                original_height, original_width, 1.
            roi_feat_list (List): RoI feature from RoIExtractor.
            rcnn_pred_list (List): Cascade rcnn's head's output
                including bbox_pred and cls_score
            proposal_list (List): RPN proposal boxes.
            cascade_bbox_reg_weights (List): BBox decode var.
            cls_agnostic_bbox_reg (int): BBox regressor are class agnostic

        Returns:
            pred_result (Variable): Prediction result with shape [N, 6]. Each
                row has 6 values: [label, confidence, xmin, ymin, xmax, ymax].
                N is the total number of prediction.
        """
        self.im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
        boxes_cls_prob_l = []

        rcnn_pred = rcnn_pred_list[-1]  # stage 3
        # Was `repreat_num = 1` immediately overwritten by `repreat_num = 3`:
        # the first assignment was dead code and the local was misspelled.
        repeat_num = 3
        bbox_reg_w = cascade_bbox_reg_weights[-1]
        # Average the class probabilities of all three cascade stages,
        # each evaluated on the last stage's RoI features.
        for i in range(repeat_num):
            if i < 2:
                cls_score, _ = self.get_output(
                    roi_feat_list[-1],  # roi_feat_3
                    name='_' + str(i + 1) if i > 0 else '')
            else:
                cls_score = rcnn_pred[0]
            cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False)
            boxes_cls_prob_l.append(cls_prob)

        boxes_cls_prob_mean = (
            boxes_cls_prob_l[0] + boxes_cls_prob_l[1] + boxes_cls_prob_l[2]
        ) / 3.0

        # Decode only the last stage's box regression against its proposals,
        # rescaled back to the original image size.
        proposals_boxes = proposal_list[-1]
        im_scale_lod = fluid.layers.sequence_expand(self.im_scale,
                                                    proposals_boxes)
        proposals_boxes = proposals_boxes / im_scale_lod
        bbox_pred = rcnn_pred[1]
        bbox_pred_new = fluid.layers.reshape(bbox_pred,
                                             (-1, cls_agnostic_bbox_reg, 4))
        if cls_agnostic_bbox_reg == 2:
            # only use fg box delta to decode box
            bbox_pred_new = fluid.layers.slice(
                bbox_pred_new, axes=[1], starts=[1], ends=[2])
            bbox_pred_new = fluid.layers.expand(bbox_pred_new,
                                                [1, self.num_classes, 1])
        decoded_box = fluid.layers.box_coder(
            prior_box=proposals_boxes,
            prior_box_var=bbox_reg_w,
            target_box=bbox_pred_new,
            code_type='decode_center_size',
            box_normalized=False,
            axis=1)

        box_out = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
        if return_box_score:
            return {'bbox': box_out, 'score': boxes_cls_prob_mean}
        pred_result = self.nms(bboxes=box_out, scores=boxes_cls_prob_mean)
        return {"bbox": pred_result}

    def get_prediction_cls_aware(self,
                                 im_info,
                                 im_shape,
                                 cascade_cls_prob,
                                 cascade_decoded_box,
                                 cascade_bbox_reg_weights,
                                 return_box_score=False):
        """Predict a bbox per class by weight-averaging the three cascade
        stages' class probabilities and decoded boxes."""
        cascade_eval_weight = [0.2, 0.3, 0.5]
        # merge 3 stages results
        sum_cascade_cls_prob = sum([
            prob * cascade_eval_weight[idx]
            for idx, prob in enumerate(cascade_cls_prob)
        ])
        sum_cascade_decoded_box = sum([
            bbox * cascade_eval_weight[idx]
            for idx, bbox in enumerate(cascade_decoded_box)
        ])
        self.im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
        im_scale_lod = fluid.layers.sequence_expand(self.im_scale,
                                                    sum_cascade_decoded_box)

        sum_cascade_decoded_box = sum_cascade_decoded_box / im_scale_lod

        decoded_bbox = sum_cascade_decoded_box
        decoded_bbox = fluid.layers.reshape(
            decoded_bbox, shape=(-1, self.num_classes, 4))

        box_out = fluid.layers.box_clip(input=decoded_bbox, im_info=im_shape)
        if return_box_score:
            return {'bbox': box_out, 'score': sum_cascade_cls_prob}
        pred_result = self.nms(bboxes=box_out, scores=sum_cascade_cls_prob)
        return {"bbox": pred_result}


@register
class CascadeXConvNormHead(object):
    """
    RCNN head with several convolution layers.

    Args:
        num_conv (int): num of convolution layers for the rcnn head
        conv_dim (int): num of filters for the conv layers
        mlp_dim (int): num of filters for the fc layers
        norm_type (str|None): normalization type for ConvNorm
        freeze_norm (bool): whether to freeze normalization parameters
    """
    __shared__ = ['norm_type', 'freeze_norm']

    def __init__(self,
                 num_conv=4,
                 conv_dim=256,
                 mlp_dim=1024,
                 norm_type=None,
                 freeze_norm=False):
        super(CascadeXConvNormHead, self).__init__()
        self.conv_dim = conv_dim
        self.mlp_dim = mlp_dim
        self.num_conv = num_conv
        self.norm_type = norm_type
        self.freeze_norm = freeze_norm

    def __call__(self, roi_feat, wb_scalar=1.0, name=''):
        conv = roi_feat
        fan = self.conv_dim * 3 * 3
        initializer = MSRA(uniform=False, fan_in=fan)
        # NOTE(review): the loop rebinds `name`, shadowing the per-stage
        # suffix passed in by the caller; the fc parameters therefore end
        # up named fc6bbox_head_conv<N-1>_w/_b for every stage. Kept as-is
        # because renaming would break existing pretrained weights.
        for i in range(self.num_conv):
            name = 'bbox_head_conv' + str(i)
            conv = ConvNorm(
                conv,
                self.conv_dim,
                3,
                act='relu',
                initializer=initializer,
                norm_type=self.norm_type,
                freeze_norm=self.freeze_norm,
                lr_scale=wb_scalar,
                name=name,
                norm_name=name)
        fan = conv.shape[1] * conv.shape[2] * conv.shape[3]
        head_heat = fluid.layers.fc(input=conv,
                                    size=self.mlp_dim,
                                    act='relu',
                                    name='fc6' + name,
                                    param_attr=ParamAttr(
                                        name='fc6%s_w' % name,
                                        initializer=Xavier(fan_out=fan),
                                        learning_rate=wb_scalar),
                                    bias_attr=ParamAttr(
                                        name='fc6%s_b' % name,
                                        regularizer=L2Decay(0.),
                                        learning_rate=wb_scalar * 2))
        return head_heat


@register
class CascadeTwoFCHead(object):
    """
    RCNN head with two Fully Connected layers.

    Args:
        mlp_dim (int): num of filters for the fc layers
    """

    def __init__(self, mlp_dim):
        super(CascadeTwoFCHead, self).__init__()
        self.mlp_dim = mlp_dim

    def __call__(self, roi_feat, wb_scalar=1.0, name=''):
        fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3]
        fc6 = fluid.layers.fc(input=roi_feat,
                              size=self.mlp_dim,
                              act='relu',
                              name='fc6' + name,
                              param_attr=ParamAttr(
                                  name='fc6%s_w' % name,
                                  initializer=Xavier(fan_out=fan),
                                  learning_rate=wb_scalar),
                              bias_attr=ParamAttr(
                                  name='fc6%s_b' % name,
                                  learning_rate=wb_scalar * 2,
                                  regularizer=L2Decay(0.)))
        head_feat = fluid.layers.fc(input=fc6,
                                    size=self.mlp_dim,
                                    act='relu',
                                    name='fc7' + name,
                                    param_attr=ParamAttr(
                                        name='fc7%s_w' % name,
                                        initializer=Xavier(),
                                        learning_rate=wb_scalar),
                                    bias_attr=ParamAttr(
                                        name='fc7%s_b' % name,
                                        learning_rate=wb_scalar * 2,
                                        regularizer=L2Decay(0.)))
        return head_feat
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Xavier
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import MSRA

from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.ops import ConvNorm
from ppdet.modeling.losses import SmoothL1Loss
from ppdet.core.workspace import register

__all__ = ['HTCBBoxHead']


@register
class HTCBBoxHead(object):
    """
    HTC bbox head.

    Args:
        head (object): the head module instance
        nms (object): `MultiClassNMS` instance
        bbox_loss (object): bbox regression loss, `SmoothL1Loss` instance
        num_classes (int): number of output classes
        lr_ratio (float): bias learning-rate multiplier for the fc layers
    """
    __inject__ = ['head', 'nms', 'bbox_loss']
    __shared__ = ['num_classes']

    def __init__(self,
                 head,
                 nms=MultiClassNMS().__dict__,
                 bbox_loss=SmoothL1Loss().__dict__,
                 num_classes=81,
                 lr_ratio=2.0):
        super(HTCBBoxHead, self).__init__()
        self.head = head
        self.nms = nms
        self.bbox_loss = bbox_loss
        self.num_classes = num_classes
        self.lr_ratio = lr_ratio

        # Injected components may arrive as plain config dicts.
        if isinstance(nms, dict):
            self.nms = MultiClassNMS(**nms)
        if isinstance(bbox_loss, dict):
            self.bbox_loss = SmoothL1Loss(**bbox_loss)

    def get_output(self,
                   roi_feat,
                   cls_agnostic_bbox_reg=2,
                   wb_scalar=1.0,
                   name=''):
        """
        Get bbox head output.

        Args:
            roi_feat (Variable): RoI feature from RoIExtractor.
            cls_agnostic_bbox_reg (int): BBox regressor are class agnostic.
            wb_scalar (float): Weights and Bias's learning rate.
            name (str): Layer's name suffix (one per cascade stage).

        Returns:
            cls_score (Variable): cls score.
            bbox_pred (Variable): bbox regression.
        """
        head_feat = self.head(roi_feat, wb_scalar, name)
        cls_score = fluid.layers.fc(input=head_feat,
                                    size=self.num_classes,
                                    act=None,
                                    name='cls_score' + name,
                                    param_attr=ParamAttr(
                                        name='cls_score%s_w' % name,
                                        initializer=Normal(
                                            loc=0.0, scale=0.01),
                                        learning_rate=wb_scalar),
                                    bias_attr=ParamAttr(
                                        name='cls_score%s_b' % name,
                                        learning_rate=wb_scalar * self.lr_ratio,
                                        regularizer=L2Decay(0.)))
        bbox_pred = fluid.layers.fc(input=head_feat,
                                    size=4 * cls_agnostic_bbox_reg,
                                    act=None,
                                    name='bbox_pred' + name,
                                    param_attr=ParamAttr(
                                        name='bbox_pred%s_w' % name,
                                        initializer=Normal(
                                            loc=0.0, scale=0.001),
                                        learning_rate=wb_scalar),
                                    bias_attr=ParamAttr(
                                        name='bbox_pred%s_b' % name,
                                        learning_rate=wb_scalar * self.lr_ratio,
                                        regularizer=L2Decay(0.)))
        return cls_score, bbox_pred

    def get_loss(self, rcnn_pred_list, rcnn_target_list, rcnn_loss_weight_list):
        """
        Get bbox_head loss, one weighted (cls, loc) pair per cascade stage.

        Args:
            rcnn_pred_list (List): Cascade RCNN's head's output including
                bbox_pred and cls_score
            rcnn_target_list (List): Cascade rcnn's bbox and label target
            rcnn_loss_weight_list (List): The weight of location and class loss

        Return:
            Dict of per-stage losses keyed 'loss_cls_%d' / 'loss_loc_%d'.
        """
        loss_dict = {}
        for i, (rcnn_pred, rcnn_target
                ) in enumerate(zip(rcnn_pred_list, rcnn_target_list)):
            labels_int64 = fluid.layers.cast(x=rcnn_target[1], dtype='int64')
            labels_int64.stop_gradient = True

            loss_cls = fluid.layers.softmax_with_cross_entropy(
                logits=rcnn_pred[0],
                label=labels_int64,
                numeric_stable_mode=True, )
            loss_cls = fluid.layers.reduce_mean(
                loss_cls, name='loss_cls_' + str(i)) * rcnn_loss_weight_list[i]

            loss_bbox = self.bbox_loss(
                x=rcnn_pred[1],
                y=rcnn_target[2],
                inside_weight=rcnn_target[3],
                outside_weight=rcnn_target[4])
            loss_bbox = fluid.layers.reduce_mean(
                loss_bbox,
                name='loss_bbox_' + str(i)) * rcnn_loss_weight_list[i]

            loss_dict['loss_cls_%d' % i] = loss_cls
            loss_dict['loss_loc_%d' % i] = loss_bbox

        return loss_dict

    def get_prediction(self,
                       im_info,
                       im_shape,
                       roi_feat_list,
                       rcnn_pred_list,
                       proposal_list,
                       cascade_bbox_reg_weights,
                       cls_agnostic_bbox_reg=2,
                       return_box_score=False):
        """
        Get prediction bounding box in test stage.

        Args:
            im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the
                number of input images, each element consists
                of im_height, im_width, im_scale.
            im_shape (Variable): Actual shape of original image with shape
                [B, 3]. B is the number of images, each element consists of
                original_height, original_width, 1.
            roi_feat_list (List): RoI feature from RoIExtractor.
            rcnn_pred_list (List): Cascade rcnn's head's output
                including bbox_pred and cls_score
            proposal_list (List): RPN proposal boxes.
            cascade_bbox_reg_weights (List): BBox decode var.
            cls_agnostic_bbox_reg (int): BBox regressor are class agnostic

        Returns:
            pred_result (Variable): Prediction result with shape [N, 6]. Each
                row has 6 values: [label, confidence, xmin, ymin, xmax, ymax].
                N is the total number of prediction.
        """
        repeat_num = 3
        # Average class probabilities over all three cascade stages.
        boxes_cls_prob_l = []
        for i in range(repeat_num):
            cls_score = rcnn_pred_list[i][0]
            cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False)
            boxes_cls_prob_l.append(cls_prob)

        boxes_cls_prob_mean = fluid.layers.sum(boxes_cls_prob_l) / float(
            len(boxes_cls_prob_l))

        # Decode boxes from the LAST stage only. The original loop iterated
        # over all stages but did `if i < 2: continue`, so only the final
        # iteration ever produced output; computing it directly is
        # behavior-identical and removes the dead iterations.
        im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
        last = repeat_num - 1
        bbox_reg_w = cascade_bbox_reg_weights[last]
        proposals_boxes = proposal_list[last]
        im_scale_lod = fluid.layers.sequence_expand(im_scale, proposals_boxes)
        proposals_boxes = proposals_boxes / im_scale_lod
        bbox_pred = rcnn_pred_list[last][1]
        bbox_pred_new = fluid.layers.reshape(bbox_pred,
                                             (-1, cls_agnostic_bbox_reg, 4))

        if cls_agnostic_bbox_reg == 2:
            # only use fg box delta to decode box
            bbox_pred_new = fluid.layers.slice(
                bbox_pred_new, axes=[1], starts=[1], ends=[2])
            bbox_pred_new = fluid.layers.expand(bbox_pred_new,
                                                [1, self.num_classes, 1])
        decoded_box = fluid.layers.box_coder(
            prior_box=proposals_boxes,
            prior_box_var=bbox_reg_w,
            target_box=bbox_pred_new,
            code_type='decode_center_size',
            box_normalized=False,
            axis=1)

        box_out = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
        if return_box_score:
            return {'bbox': box_out, 'score': boxes_cls_prob_mean}
        pred_result = self.nms(bboxes=box_out, scores=boxes_cls_prob_mean)
        return {"bbox": pred_result}

    def get_prediction_cls_aware(self,
                                 im_info,
                                 im_shape,
                                 cascade_cls_prob,
                                 cascade_decoded_box,
                                 cascade_bbox_reg_weights,
                                 return_box_score=False):
        """Predict a bbox per class by weight-averaging the three cascade
        stages' class probabilities and decoded boxes."""
        cascade_eval_weight = [0.2, 0.3, 0.5]
        # merge 3 stages results
        sum_cascade_cls_prob = sum([
            prob * cascade_eval_weight[idx]
            for idx, prob in enumerate(cascade_cls_prob)
        ])
        sum_cascade_decoded_box = sum([
            bbox * cascade_eval_weight[idx]
            for idx, bbox in enumerate(cascade_decoded_box)
        ])
        self.im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
        im_scale_lod = fluid.layers.sequence_expand(self.im_scale,
                                                    sum_cascade_decoded_box)

        sum_cascade_decoded_box = sum_cascade_decoded_box / im_scale_lod

        decoded_bbox = sum_cascade_decoded_box
        decoded_bbox = fluid.layers.reshape(
            decoded_bbox, shape=(-1, self.num_classes, 4))

        box_out = fluid.layers.box_clip(input=decoded_bbox, im_info=im_shape)
        if return_box_score:
            return {'bbox': box_out, 'score': sum_cascade_cls_prob}
        pred_result = self.nms(bboxes=box_out, scores=sum_cascade_cls_prob)
        return {"bbox": pred_result}
@register
class HTCMaskHead(object):
    """
    htc mask head
    Args:
        num_convs (int): num of convolutions, 4 for FPN, 1 otherwise
        conv_dim (int): num of channels after first convolution
        resolution (int): size of the output mask
        dilation (int): dilation rate
        num_classes (int): number of output classes
        norm_type (str|None): normalization in the conv head; only 'gn'
            (GroupNorm) is recognized, anything else means plain conv2d
        lr_ratio (float): learning-rate multiplier applied to bias params
        share_mask_conv (bool): if True the conv-head parameter names get no
            per-stage suffix, so cascade stages share the same weights
    """

    __shared__ = ['num_classes']

    def __init__(self,
                 num_convs=0,
                 conv_dim=256,
                 resolution=14,
                 dilation=1,
                 num_classes=81,
                 norm_type=None,
                 lr_ratio=2.0,
                 share_mask_conv=False):
        super(HTCMaskHead, self).__init__()
        self.num_convs = num_convs
        self.conv_dim = conv_dim
        self.resolution = resolution
        self.dilation = dilation
        self.num_classes = num_classes
        self.norm_type = norm_type
        self.lr_ratio = lr_ratio
        self.share_mask_conv = share_mask_conv

    def _mask_conv_head(self,
                        roi_feat,
                        num_convs,
                        norm_type,
                        wb_scalar=1.0,
                        name=''):
        """Apply `num_convs` 3x3 convs (GroupNorm variant when norm_type is
        'gn') to the RoI feature and return the result.

        `name` is the per-stage suffix; it is appended to the layer name only
        when stages do NOT share conv weights, so parameter names decide the
        actual sharing behavior.
        """
        if norm_type == 'gn':
            for i in range(num_convs):
                layer_name = "mask_inter_feat_" + str(i + 1)
                if not self.share_mask_conv:
                    layer_name += name
                # MSRA/He init with fan-in of the 3x3 kernel
                fan = self.conv_dim * 3 * 3
                initializer = MSRA(uniform=False, fan_in=fan)
                roi_feat = ConvNorm(
                    roi_feat,
                    self.conv_dim,
                    3,
                    act='relu',
                    dilation=self.dilation,
                    initializer=initializer,
                    norm_type=self.norm_type,
                    name=layer_name,
                    norm_name=layer_name)
        else:
            for i in range(num_convs):
                layer_name = "mask_inter_feat_" + str(i + 1)
                if not self.share_mask_conv:
                    layer_name += name
                fan = self.conv_dim * 3 * 3
                initializer = MSRA(uniform=False, fan_in=fan)
                roi_feat = fluid.layers.conv2d(
                    input=roi_feat,
                    num_filters=self.conv_dim,
                    filter_size=3,
                    padding=1 * self.dilation,
                    act='relu',
                    stride=1,
                    dilation=self.dilation,
                    name=layer_name,
                    param_attr=ParamAttr(
                        name=layer_name + '_w', initializer=initializer),
                    bias_attr=ParamAttr(
                        name=layer_name + '_b',
                        learning_rate=wb_scalar * self.lr_ratio,
                        regularizer=L2Decay(0.)))
        return roi_feat

    def get_output(self,
                   roi_feat,
                   res_feat=None,
                   return_logits=True,
                   return_feat=False,
                   wb_scalar=1.0,
                   name=''):
        """Run the mask head for one cascade stage.

        Args:
            roi_feat (Variable): RoI feature from the RoI extractor.
            res_feat (Variable|None): residual feature from the previous
                stage; projected by a 1x1 conv and summed into roi_feat.
            return_logits (bool): also compute the per-class mask logits.
            return_feat (bool): also return the conv-head feature.
            wb_scalar (float): extra learning-rate scale for bias params.
            name (str): per-stage name suffix.

        Returns:
            (mask_logits, head_feat), mask_logits, head_feat or None
            depending on the two flags.
        """
        class_num = self.num_classes
        if res_feat is not None:
            res_feat = fluid.layers.conv2d(
                res_feat, roi_feat.shape[1], 1, name='res_net' + name)
            roi_feat = fluid.layers.sum([roi_feat, res_feat])
        # configure the conv number for FPN if necessary
        head_feat = self._mask_conv_head(roi_feat, self.num_convs,
                                         self.norm_type, wb_scalar, name)

        if return_logits:
            # 2x upsample via transposed conv, then 1x1 conv to class logits
            fan0 = roi_feat.shape[1] * 2 * 2
            up_head_feat = fluid.layers.conv2d_transpose(
                input=head_feat,
                num_filters=self.conv_dim,
                filter_size=2,
                stride=2,
                act='relu',
                param_attr=ParamAttr(
                    name='conv5_mask_w' + name,
                    initializer=MSRA(
                        uniform=False, fan_in=fan0)),
                bias_attr=ParamAttr(
                    name='conv5_mask_b' + name,
                    learning_rate=wb_scalar * self.lr_ratio,
                    regularizer=L2Decay(0.)))

            fan = class_num
            mask_logits = fluid.layers.conv2d(
                input=up_head_feat,
                num_filters=class_num,
                filter_size=1,
                act=None,
                param_attr=ParamAttr(
                    name='mask_fcn_logits_w' + name,
                    initializer=MSRA(
                        uniform=False, fan_in=fan)),
                bias_attr=ParamAttr(
                    name="mask_fcn_logits_b" + name,
                    learning_rate=wb_scalar * self.lr_ratio,
                    regularizer=L2Decay(0.)))
            if return_feat:
                return mask_logits, head_feat
            else:
                return mask_logits

        if return_feat:
            return head_feat

    def get_loss(self,
                 mask_logits_list,
                 mask_int32_list,
                 cascade_loss_weights=None):
        """Compute the weighted sigmoid cross-entropy mask loss per stage.

        Args:
            mask_logits_list (list of Variable): mask logits of each stage.
            mask_int32_list (list of Variable): encoded mask targets of each
                stage; -1 entries are ignored by the loss.
            cascade_loss_weights (list of float|None): per-stage loss
                weights; defaults to [1.0, 0.5, 0.25].

        Returns:
            dict: {'loss_mask_<i>': weighted loss of stage i}
        """
        # None sentinel instead of a mutable default argument.
        if cascade_loss_weights is None:
            cascade_loss_weights = [1.0, 0.5, 0.25]
        num_classes = self.num_classes
        resolution = self.resolution
        dim = num_classes * resolution * resolution
        loss_mask_dict = {}
        for i, (mask_logits, mask_int32
                ) in enumerate(zip(mask_logits_list, mask_int32_list)):

            mask_logits = fluid.layers.reshape(mask_logits, (-1, dim))
            mask_label = fluid.layers.cast(x=mask_int32, dtype='float32')
            mask_label.stop_gradient = True
            loss_name = 'loss_mask_' + str(i)
            loss_mask = fluid.layers.sigmoid_cross_entropy_with_logits(
                x=mask_logits,
                label=mask_label,
                ignore_index=-1,
                normalize=True,
                name=loss_name)
            loss_mask = fluid.layers.reduce_sum(
                loss_mask) * cascade_loss_weights[i]
            loss_mask_dict[loss_name] = loss_mask
        return loss_mask_dict

    def get_prediction(self, mask_logits, bbox_pred):
        """
        Get prediction mask in test stage.

        Args:
            mask_logits (Variable): mask head output features.
            bbox_pred (Variable): predicted bbox; its LoD is copied onto the
                mask probabilities so masks stay aligned with boxes per image.

        Returns:
            mask_pred (Variable): Prediction mask with shape
            [N, num_classes, resolution, resolution].
        """
        mask_prob = fluid.layers.sigmoid(mask_logits)
        mask_prob = fluid.layers.lod_reset(mask_prob, bbox_pred)
        return mask_prob
@register
class FusedSemanticHead(object):
    """Semantic segmentation branch (HTC) fusing all FPN levels into one
    feature map, with a conv embedding and a per-pixel class logit output.

    Args:
        semantic_num_class (int): number of semantic segmentation classes.
    """

    def __init__(self, semantic_num_class=183):
        super(FusedSemanticHead, self).__init__()
        self.semantic_num_class = semantic_num_class

    def get_out(self,
                fpn_feats,
                out_c=256,
                num_convs=4,
                fusion_level='fpn_res3_sum'):
        """Fuse all FPN levels at the resolution of `fusion_level`.

        Every other level is bilinearly resized to that resolution, projected
        by a 1x1 conv and summed, then refined by `num_convs` 3x3 convs.

        Returns:
            (semantic_feat, seg_pred): the fused embedding and the semantic
            segmentation logits with `semantic_num_class` channels.
        """
        new_feat = fpn_feats[fusion_level]
        new_feat_list = [new_feat, ]
        target_shape = fluid.layers.shape(new_feat)[2:]
        for k, v in fpn_feats.items():
            if k != fusion_level:
                v = fluid.layers.resize_bilinear(
                    v, target_shape, align_corners=True)
                v = fluid.layers.conv2d(v, out_c, 1)
                new_feat_list.append(v)
        new_feat = fluid.layers.sum(new_feat_list)

        for i in range(num_convs):
            new_feat = fluid.layers.conv2d(new_feat, out_c, 3, padding=1)

        # conv embedding
        semantic_feat = fluid.layers.conv2d(new_feat, out_c, 1)
        # conv logits
        seg_pred = fluid.layers.conv2d(new_feat, self.semantic_num_class, 1)
        return semantic_feat, seg_pred

    def get_loss(self, logit, label, ignore_index=255):
        """Mean softmax cross-entropy over non-ignored pixels.

        Args:
            logit (Variable): NCHW semantic logits.
            label (Variable): per-pixel labels, resized (nearest) to the
                logit resolution.
            ignore_index (int): label value excluded from the loss.

        Returns:
            Variable: scalar average loss over valid pixels.
        """
        label = fluid.layers.resize_nearest(label,
                                            fluid.layers.shape(logit)[2:])
        label = fluid.layers.reshape(label, [-1, 1])
        label = fluid.layers.cast(label, 'int64')

        logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
        logit = fluid.layers.reshape(logit, [-1, self.semantic_num_class])

        loss, _ = fluid.layers.softmax_with_cross_entropy(
            logit,
            label,
            soft_label=False,
            ignore_index=ignore_index,
            return_softmax=True)

        # BUG FIX: the mask must use `ignore_index`, not a hard-coded 255,
        # so a non-default ignore_index is honored in the normalization too.
        # The former `if ignore_mask is not None` check was dead code (the
        # name was always just assigned) and has been removed.
        ignore_mask = (label.astype('int32') != ignore_index).astype('int32')
        ignore_mask = fluid.layers.cast(ignore_mask, 'float32')
        ignore_mask = fluid.layers.reshape(ignore_mask, [-1, 1])
        loss = loss * ignore_mask
        # normalize by the number of valid pixels, not all pixels
        avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask)
        ignore_mask.stop_gradient = True
        label.stop_gradient = True

        return avg_loss
@register
class MaskHead(object):
    """
    RCNN mask head
    Args:
        num_convs (int): num of convolutions, 4 for FPN, 1 otherwise
        conv_dim (int): num of channels after first convolution
        resolution (int): size of the output mask
        dilation (int): dilation rate
        num_classes (int): number of output classes
        norm_type (str|None): normalization in the conv head; only 'gn'
            (GroupNorm) is recognized, anything else means plain conv2d
    """

    __shared__ = ['num_classes']

    def __init__(self,
                 num_convs=0,
                 conv_dim=256,
                 resolution=14,
                 dilation=1,
                 num_classes=81,
                 norm_type=None):
        super(MaskHead, self).__init__()
        self.num_convs = num_convs
        self.conv_dim = conv_dim
        self.resolution = resolution
        self.dilation = dilation
        self.num_classes = num_classes
        self.norm_type = norm_type

    def _mask_conv_head(self, roi_feat, num_convs, norm_type):
        """Apply `num_convs` 3x3 convs (GroupNorm variant when norm_type is
        'gn') to the RoI feature, then upsample 2x with a transposed conv.

        Returns the upsampled feature map with `conv_dim` channels.
        """
        if norm_type == 'gn':
            for i in range(num_convs):
                layer_name = "mask_inter_feat_" + str(i + 1)
                # MSRA/He init with fan-in of the 3x3 kernel
                fan = self.conv_dim * 3 * 3
                initializer = MSRA(uniform=False, fan_in=fan)
                roi_feat = ConvNorm(
                    roi_feat,
                    self.conv_dim,
                    3,
                    act='relu',
                    dilation=self.dilation,
                    initializer=initializer,
                    norm_type=self.norm_type,
                    name=layer_name,
                    norm_name=layer_name)
        else:
            for i in range(num_convs):
                layer_name = "mask_inter_feat_" + str(i + 1)
                fan = self.conv_dim * 3 * 3
                initializer = MSRA(uniform=False, fan_in=fan)
                roi_feat = fluid.layers.conv2d(
                    input=roi_feat,
                    num_filters=self.conv_dim,
                    filter_size=3,
                    padding=1 * self.dilation,
                    act='relu',
                    stride=1,
                    dilation=self.dilation,
                    name=layer_name,
                    param_attr=ParamAttr(
                        name=layer_name + '_w', initializer=initializer),
                    bias_attr=ParamAttr(
                        name=layer_name + '_b',
                        # biases train at 2x the base learning rate, no decay
                        learning_rate=2.,
                        regularizer=L2Decay(0.)))
        # 2x upsample; fan-in counts the 2x2 transposed-conv kernel
        fan = roi_feat.shape[1] * 2 * 2
        feat = fluid.layers.conv2d_transpose(
            input=roi_feat,
            num_filters=self.conv_dim,
            filter_size=2,
            stride=2,
            act='relu',
            param_attr=ParamAttr(
                name='conv5_mask_w',
                initializer=MSRA(
                    uniform=False, fan_in=fan)),
            bias_attr=ParamAttr(
                name='conv5_mask_b', learning_rate=2., regularizer=L2Decay(0.)))
        return feat

    def _get_output(self, roi_feat):
        """Run the conv head and a final 1x1 conv producing one mask logit
        channel per class (shape [N, num_classes, resolution, resolution]).
        """
        class_num = self.num_classes
        # configure the conv number for FPN if necessary
        head_feat = self._mask_conv_head(roi_feat, self.num_convs,
                                         self.norm_type)
        fan = class_num
        mask_logits = fluid.layers.conv2d(
            input=head_feat,
            num_filters=class_num,
            filter_size=1,
            act=None,
            param_attr=ParamAttr(
                name='mask_fcn_logits_w',
                initializer=MSRA(
                    uniform=False, fan_in=fan)),
            bias_attr=ParamAttr(
                name="mask_fcn_logits_b",
                learning_rate=2.,
                regularizer=L2Decay(0.)))
        return mask_logits

    def get_loss(self, roi_feat, mask_int32):
        """Sigmoid cross-entropy mask loss.

        Args:
            roi_feat (Variable): RoI feature from RoIExtractor.
            mask_int32 (Variable): encoded mask targets; -1 entries are
                ignored by the loss.

        Returns:
            dict: {'loss_mask': summed, normalized mask loss}
        """
        mask_logits = self._get_output(roi_feat)
        num_classes = self.num_classes
        resolution = self.resolution
        # flatten per-RoI logits to [N, num_classes * res * res]
        dim = num_classes * resolution * resolution
        mask_logits = fluid.layers.reshape(mask_logits, (-1, dim))

        mask_label = fluid.layers.cast(x=mask_int32, dtype='float32')
        mask_label.stop_gradient = True
        loss_mask = fluid.layers.sigmoid_cross_entropy_with_logits(
            x=mask_logits, label=mask_label, ignore_index=-1, normalize=True)
        loss_mask = fluid.layers.reduce_sum(loss_mask, name='loss_mask')
        return {'loss_mask': loss_mask}

    def get_prediction(self, roi_feat, bbox_pred):
        """
        Get prediction mask in test stage.

        Args:
            roi_feat (Variable): RoI feature from RoIExtractor.
            bbox_pred (Variable): predicted bbox; its LoD is copied onto the
                mask probabilities so masks stay aligned with boxes per image.

        Returns:
            mask_pred (Variable): Prediction mask with shape
                [N, num_classes, resolution, resolution].
        """
        mask_logits = self._get_output(roi_feat)
        mask_prob = fluid.layers.sigmoid(mask_logits)
        mask_prob = fluid.layers.lod_reset(mask_prob, bbox_pred)
        return mask_prob
@register
class CascadeBBoxAssigner(object):
    """Proposal-to-target assigner for cascade R-CNN.

    Samples RoIs against ground truth with per-stage foreground/background
    thresholds and per-stage bbox regression weights.
    """

    __shared__ = ['num_classes']

    def __init__(self,
                 batch_size_per_im=512,
                 fg_fraction=.25,
                 fg_thresh=[0.5, 0.6, 0.7],
                 bg_thresh_hi=[0.5, 0.6, 0.7],
                 bg_thresh_lo=[0., 0., 0.],
                 bbox_reg_weights=[10, 20, 30],
                 shuffle_before_sample=True,
                 num_classes=81,
                 class_aware=False):
        super(CascadeBBoxAssigner, self).__init__()
        self.batch_size_per_im = batch_size_per_im
        self.fg_fraction = fg_fraction
        self.fg_thresh = fg_thresh
        self.bg_thresh_hi = bg_thresh_hi
        self.bg_thresh_lo = bg_thresh_lo
        self.bbox_reg_weights = bbox_reg_weights
        self.class_nums = num_classes
        self.use_random = shuffle_before_sample
        self.class_aware = class_aware

    def __call__(self, input_rois, feed_vars, curr_stage, max_overlap=None):
        """Assign labels/targets to `input_rois` for cascade stage
        `curr_stage` and return the generate_proposal_labels outputs
        (including max_overlap, which later stages feed back in).
        """
        # Per-stage decode weights: (1, 1, 2, 2) scaled by this stage's
        # bbox regression weight.
        stage_weight = self.bbox_reg_weights[curr_stage]
        curr_bbox_reg_w = [s / stage_weight for s in (1., 1., 2., 2.)]

        # Stages after the first reuse refined proposals, so the cascade
        # flag is only set there (and never in class-aware mode).
        cascade_mode = curr_stage > 0 and not self.class_aware

        targets = fluid.layers.generate_proposal_labels(
            rpn_rois=input_rois,
            gt_classes=feed_vars['gt_class'],
            is_crowd=feed_vars['is_crowd'],
            gt_boxes=feed_vars['gt_bbox'],
            im_info=feed_vars['im_info'],
            batch_size_per_im=self.batch_size_per_im,
            fg_thresh=self.fg_thresh[curr_stage],
            bg_thresh_hi=self.bg_thresh_hi[curr_stage],
            bg_thresh_lo=self.bg_thresh_lo[curr_stage],
            bbox_reg_weights=curr_bbox_reg_w,
            use_random=self.use_random,
            class_nums=self.class_nums if self.class_aware else 2,
            is_cls_agnostic=not self.class_aware,
            is_cascade_rcnn=cascade_mode,
            max_overlap=max_overlap,
            return_max_overlap=True)
        return targets
diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/tests/decorator_helper.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/tests/decorator_helper.py new file mode 100755 index 000000000..894833ce1 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/tests/decorator_helper.py @@ -0,0 +1,33 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid + +__all__ = ['prog_scope'] + + +def prog_scope(): + def __impl__(fn): + def __fn__(*args, **kwargs): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + with fluid.unique_name.guard(): + fn(*args, **kwargs) + + return __fn__ + + return __impl__ diff --git a/VisualFL/depends/PaddleDetection/ppdet/modeling/tests/test_architectures.py b/VisualFL/depends/PaddleDetection/ppdet/modeling/tests/test_architectures.py new file mode 100755 index 000000000..23c91fbc5 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/modeling/tests/test_architectures.py @@ -0,0 +1,96 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import numpy as np

import paddle
import paddle.fluid as fluid
import os
import sys
# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

from ppdet.modeling.tests.decorator_helper import prog_scope
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import enable_static_mode


class TestFasterRCNN(unittest.TestCase):
    """Smoke-test: build train and eval graphs for one architecture.

    Subclasses only override `set_config` to point at a different YAML; the
    architecture name is read from the config's 'architecture' key.
    """

    def setUp(self):
        self.set_config()
        self.cfg = load_config(self.cfg_file)
        self.detector_type = self.cfg['architecture']

    def set_config(self):
        # Overridden by subclasses to pick the architecture config file.
        self.cfg_file = 'configs/faster_rcnn_r50_1x.yml'

    @prog_scope()
    def test_train(self):
        """Graph-construction check for the training program."""
        model = create(self.detector_type)
        inputs_def = self.cfg['TrainReader']['inputs_def']
        # variable input resolution: channels fixed, H/W dynamic
        inputs_def['image_shape'] = [3, None, None]
        feed_vars, _ = model.build_inputs(**inputs_def)
        train_fetches = model.train(feed_vars)

    @prog_scope()
    def test_test(self):
        """Graph-construction check for the eval program."""
        inputs_def = self.cfg['EvalReader']['inputs_def']
        inputs_def['image_shape'] = [3, None, None]
        model = create(self.detector_type)
        feed_vars, _ = model.build_inputs(**inputs_def)
        test_fetches = model.eval(feed_vars)


class TestMaskRCNN(TestFasterRCNN):
    def set_config(self):
        self.cfg_file = 'configs/mask_rcnn_r50_1x.yml'


@unittest.skip(
    reason="It should be fixed to adapt https://github.com/PaddlePaddle/Paddle/pull/23797"
)
class TestCascadeRCNN(TestFasterRCNN):
    def set_config(self):
        self.cfg_file = 'configs/cascade_rcnn_r50_fpn_1x.yml'


@unittest.skipIf(
    paddle.version.full_version < "1.8.4",
    "Paddle 2.0 should be used for YOLOv3 takes scale_x_y as inputs, "
    "disable this unittest for Paddle major version < 2")
class TestYolov3(TestFasterRCNN):
    def set_config(self):
        self.cfg_file = 'configs/yolov3_darknet.yml'


class TestRetinaNet(TestFasterRCNN):
    def set_config(self):
        self.cfg_file = 'configs/retinanet_r50_fpn_1x.yml'


class TestSSD(TestFasterRCNN):
    def set_config(self):
        self.cfg_file = 'configs/ssd/ssd_mobilenet_v1_voc.yml'


if __name__ == '__main__':
    enable_static_mode()
    unittest.main()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import logging

from paddle import fluid

import paddle.fluid.optimizer as optimizer
import paddle.fluid.regularizer as regularizer
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.layers.ops import cos

from ppdet.core.workspace import register, serializable

__all__ = ['LearningRate', 'OptimizerBuilder']

logger = logging.getLogger(__name__)


@serializable
class PiecewiseDecay(object):
    """
    Multi step learning rate decay

    Args:
        gamma (float | list): decay factor
        milestones (list): steps at which to decay learning rate
        values (list|None): explicit LR values; when given they override
            base_lr/gamma entirely
    """

    def __init__(self, gamma=[0.1, 0.1], milestones=[60000, 80000],
                 values=None):
        super(PiecewiseDecay, self).__init__()
        # A scalar gamma is expanded to [gamma, gamma/10, gamma/100, ...],
        # one factor per milestone.
        if type(gamma) is not list:
            self.gamma = []
            for i in range(len(milestones)):
                self.gamma.append(gamma / 10**i)
        else:
            self.gamma = gamma
        self.milestones = milestones
        self.values = values

    def __call__(self, base_lr=None, learning_rate=None):
        if self.values is not None:
            return fluid.layers.piecewise_decay(self.milestones, self.values)
        assert base_lr is not None, "either base LR or values should be provided"
        values = [base_lr]
        for g in self.gamma:
            new_lr = base_lr * g
            values.append(new_lr)
        return fluid.layers.piecewise_decay(self.milestones, values)


@serializable
class PolynomialDecay(object):
    """
    Applies polynomial decay to the initial learning rate.
    Args:
        max_iter (int): The learning rate decay steps.
        end_lr (float): End learning rate.
        power (float): Polynomial attenuation coefficient
    """

    def __init__(self, max_iter=180000, end_lr=0.0001, power=1.0):
        # BUG FIX: was `super(PolynomialDecay).__init__()` — the
        # one-argument super() form made the call a silent no-op.
        super(PolynomialDecay, self).__init__()
        self.max_iter = max_iter
        self.end_lr = end_lr
        self.power = power

    def __call__(self, base_lr=None, learning_rate=None):
        assert base_lr is not None, "either base LR or values should be provided"
        lr = fluid.layers.polynomial_decay(base_lr, self.max_iter, self.end_lr,
                                           self.power)
        return lr


@serializable
class ExponentialDecay(object):
    """
    Applies exponential decay to the learning rate.
    Args:
        max_iter (int): The learning rate decay steps.
        decay_rate (float): The learning rate decay rate.
    """

    def __init__(self, max_iter, decay_rate):
        # BUG FIX: was `super(ExponentialDecay).__init__()` — same missing
        # `self` as PolynomialDecay.
        super(ExponentialDecay, self).__init__()
        self.max_iter = max_iter
        self.decay_rate = decay_rate

    def __call__(self, base_lr=None, learning_rate=None):
        assert base_lr is not None, "either base LR or values should be provided"
        lr = fluid.layers.exponential_decay(base_lr, self.max_iter,
                                            self.decay_rate)
        return lr


@serializable
class CosineDecay(object):
    """
    Cosine learning rate decay

    Args:
        max_iters (float): max iterations for the training process.
            if you commbine cosine decay with warmup, it is recommended that
            the max_iter is much larger than the warmup iter
    """

    def __init__(self, max_iters=180000):
        self.max_iters = max_iters

    def __call__(self, base_lr=None, learning_rate=None):
        assert base_lr is not None, "either base LR or values should be provided"
        lr = fluid.layers.cosine_decay(base_lr, 1, self.max_iters)
        return lr


@serializable
class CosineDecayWithSkip(object):
    """
    Cosine decay, with explicit support for warm up

    Args:
        total_steps (int): total steps over which to apply the decay
        skip_steps (int): skip some steps at the beginning, e.g., warm up
    """

    def __init__(self, total_steps, skip_steps=None):
        super(CosineDecayWithSkip, self).__init__()
        assert (not skip_steps or skip_steps > 0), \
            "skip steps must be greater than zero"
        assert total_steps > 0, "total step must be greater than zero"
        assert (not skip_steps or skip_steps < total_steps), \
            "skip steps must be smaller than total steps"
        self.total_steps = total_steps
        self.skip_steps = skip_steps

    def __call__(self, base_lr=None, learning_rate=None):
        steps = _decay_step_counter()
        total = self.total_steps
        if self.skip_steps is not None:
            # decay only over the post-warmup portion of training
            total -= self.skip_steps

        lr = fluid.layers.tensor.create_global_var(
            shape=[1],
            value=base_lr,
            dtype='float32',
            persistable=True,
            name="learning_rate")

        def decay():
            # standard half-cosine schedule from base_lr down to 0
            cos_lr = base_lr * .5 * (cos(steps * (math.pi / total)) + 1)
            fluid.layers.tensor.assign(input=cos_lr, output=lr)

        if self.skip_steps is None:
            decay()
        else:
            # hold base_lr until skip_steps, then start the cosine decay
            skipped = steps >= self.skip_steps
            fluid.layers.cond(skipped, decay)
        return lr


@serializable
class LinearWarmup(object):
    """
    Warm up learning rate linearly

    Args:
        steps (int): warm up steps
        start_factor (float): initial learning rate factor
    """

    def __init__(self, steps=500, start_factor=1. / 3):
        super(LinearWarmup, self).__init__()
        self.steps = steps
        self.start_factor = start_factor

    def __call__(self, base_lr, learning_rate):
        start_lr = base_lr * self.start_factor

        return fluid.layers.linear_lr_warmup(
            learning_rate=learning_rate,
            warmup_steps=self.steps,
            start_lr=start_lr,
            end_lr=base_lr)


@register
class LearningRate(object):
    """
    Learning Rate configuration

    Args:
        base_lr (float): base learning rate
        schedulers (list): learning rate schedulers, chained in order
            (each receives the previous scheduler's output)
    """
    __category__ = 'optim'

    def __init__(self,
                 base_lr=0.01,
                 schedulers=[PiecewiseDecay(), LinearWarmup()]):
        # NOTE: the mutable default is kept intentionally — ppdet's config
        # system introspects __init__ defaults to build the config schema.
        super(LearningRate, self).__init__()
        self.base_lr = base_lr
        self.schedulers = schedulers

    def __call__(self):
        lr = None
        for sched in self.schedulers:
            lr = sched(self.base_lr, lr)
        return lr


@register
class OptimizerBuilder():
    """
    Build optimizer handles

    Args:
        clip_grad_by_norm (float|None): global-norm gradient clip threshold
        regularizer (object): an `Regularizer` instance
        optimizer (object): an `Optimizer` instance
    """
    __category__ = 'optim'

    def __init__(self,
                 clip_grad_by_norm=None,
                 regularizer={'type': 'L2',
                              'factor': .0001},
                 optimizer={'type': 'Momentum',
                            'momentum': .9}):
        # NOTE: dict defaults kept for the same config-introspection reason.
        self.clip_grad_by_norm = clip_grad_by_norm
        self.regularizer = regularizer
        self.optimizer = optimizer

    def __call__(self, learning_rate):
        """Instantiate the configured paddle optimizer for `learning_rate`."""
        if self.clip_grad_by_norm is not None:
            fluid.clip.set_gradient_clip(
                clip=fluid.clip.GradientClipByGlobalNorm(
                    clip_norm=self.clip_grad_by_norm))
        if self.regularizer:
            # 'L2' + 'Decay' -> regularizer.L2Decay, etc.
            reg_type = self.regularizer['type'] + 'Decay'
            reg_factor = self.regularizer['factor']
            regularization = getattr(regularizer, reg_type)(reg_factor)
        else:
            regularization = None
        # copy so the registered config dict is not mutated by `del`
        optim_args = self.optimizer.copy()
        optim_type = optim_args['type']
        del optim_args['type']
        op = getattr(optimizer, optim_type)
        return op(learning_rate=learning_rate,
                  regularization=regularization,
                  **optim_args)
b/VisualFL/depends/PaddleDetection/ppdet/utils/__init__.py new file mode 100755 index 000000000..d0c32e260 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/bbox_utils.py b/VisualFL/depends/PaddleDetection/ppdet/utils/bbox_utils.py new file mode 100755 index 000000000..ff16e8b9d --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/bbox_utils.py @@ -0,0 +1,83 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import numpy as np + +import paddle.fluid as fluid + +__all__ = ["bbox_overlaps", "box_to_delta"] + +logger = logging.getLogger(__name__) + + +def bbox_overlaps(boxes_1, boxes_2): + ''' + bbox_overlaps + boxes_1: x1, y, x2, y2 + boxes_2: x1, y, x2, y2 + ''' + assert boxes_1.shape[1] == 4 and boxes_2.shape[1] == 4 + + num_1 = boxes_1.shape[0] + num_2 = boxes_2.shape[0] + + x1_1 = boxes_1[:, 0:1] + y1_1 = boxes_1[:, 1:2] + x2_1 = boxes_1[:, 2:3] + y2_1 = boxes_1[:, 3:4] + area_1 = (x2_1 - x1_1 + 1) * (y2_1 - y1_1 + 1) + + x1_2 = boxes_2[:, 0].transpose() + y1_2 = boxes_2[:, 1].transpose() + x2_2 = boxes_2[:, 2].transpose() + y2_2 = boxes_2[:, 3].transpose() + area_2 = (x2_2 - x1_2 + 1) * (y2_2 - y1_2 + 1) + + xx1 = np.maximum(x1_1, x1_2) + yy1 = np.maximum(y1_1, y1_2) + xx2 = np.minimum(x2_1, x2_2) + yy2 = np.minimum(y2_1, y2_2) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + + ovr = inter / (area_1 + area_2 - inter) + return ovr + + +def box_to_delta(ex_boxes, gt_boxes, weights): + """ box_to_delta """ + ex_w = ex_boxes[:, 2] - ex_boxes[:, 0] + 1 + ex_h = ex_boxes[:, 3] - ex_boxes[:, 1] + 1 + ex_ctr_x = ex_boxes[:, 0] + 0.5 * ex_w + ex_ctr_y = ex_boxes[:, 1] + 0.5 * ex_h + + gt_w = gt_boxes[:, 2] - gt_boxes[:, 0] + 1 + gt_h = gt_boxes[:, 3] - gt_boxes[:, 1] + 1 + gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_w + gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_h + + dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0] + dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1] + dw = (np.log(gt_w / ex_w)) / weights[2] + dh = (np.log(gt_h / ex_h)) / weights[3] + + targets = np.vstack([dx, dy, dw, dh]).transpose() + return targets diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/check.py b/VisualFL/depends/PaddleDetection/ppdet/utils/check.py new file mode 100755 index 000000000..960a03903 --- /dev/null +++ 
b/VisualFL/depends/PaddleDetection/ppdet/utils/check.py @@ -0,0 +1,136 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import paddle +import paddle.fluid as fluid + +import logging +import six +import paddle.version as fluid_version +logger = logging.getLogger(__name__) + +__all__ = [ + 'check_gpu', + 'check_version', + 'check_config', + 'check_py_func', +] + + +def check_gpu(use_gpu): + """ + Log error and exit when set use_gpu=true in paddlepaddle + cpu version. + """ + err = "Config use_gpu cannot be set as true while you are " \ + "using paddlepaddle cpu version ! \nPlease try: \n" \ + "\t1. Install paddlepaddle-gpu to run model on GPU \n" \ + "\t2. Set use_gpu as false in config file to run " \ + "model on CPU" + + try: + if use_gpu and not fluid.is_compiled_with_cuda(): + logger.error(err) + sys.exit(1) + except Exception as e: + pass + + +def check_version(version='1.7.0'): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version {} or higher is required, " \ + "or a suitable develop version is satisfied as well. 
\n" \ + "Please make sure the version is good with your code.".format(version) + + version_installed = [ + fluid_version.major, fluid_version.minor, fluid_version.patch, + fluid_version.rc + ] + if version_installed == ['0', '0', '0', '0']: + return + version_split = version.split('.') + + length = min(len(version_installed), len(version_split)) + for i in six.moves.range(length): + if version_installed[i] > version_split[i]: + return + if version_installed[i] < version_split[i]: + raise Exception(err) + + +def check_config(cfg): + """ + Check the correctness of the configuration file. Log error and exit + when Config is not compliant. + """ + err = "'{}' not specified in config file. Please set it in config file." + check_list = ['architecture', 'num_classes'] + try: + for var in check_list: + if not var in cfg: + logger.error(err.format(var)) + sys.exit(1) + except Exception as e: + pass + + if 'log_iter' not in cfg: + cfg.log_iter = 20 + + train_dataset = cfg['TrainReader']['dataset'] + eval_dataset = cfg['EvalReader']['dataset'] + test_dataset = cfg['TestReader']['dataset'] + assert train_dataset.with_background == eval_dataset.with_background, \ + "'with_background' of TrainReader is not equal to EvalReader." + assert train_dataset.with_background == test_dataset.with_background, \ + "'with_background' of TrainReader is not equal to TestReader." + + actual_num_classes = int(cfg.num_classes) - int( + train_dataset.with_background) + logger.debug("The 'num_classes'(number of classes) you set is {}, " \ + "and 'with_background' in 'dataset' sets {}.\n" \ + "So please note the actual number of categories is {}." + .format(cfg.num_classes, train_dataset.with_background, + actual_num_classes)) + + return cfg + + +def check_py_func(program): + for block in program.blocks: + for op in block.ops: + if op.type == 'py_func': + input_arg = op.input_arg_names + output_arg = op.output_arg_names + err = "The program contains py_func with input: {}, "\ + "output: {}. 
It is not supported in Paddle inference "\ + "engine. please replace it by paddle ops. For example, "\ + "if you use MultiClassSoftNMS, better to replace it by "\ + "MultiClassNMS.".format(input_arg, output_arg) + raise Exception(err) + + +def enable_static_mode(): + try: + paddle.enable_static() + except: + pass diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/checkpoint.py b/VisualFL/depends/PaddleDetection/ppdet/utils/checkpoint.py new file mode 100755 index 000000000..9461be8a3 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/checkpoint.py @@ -0,0 +1,304 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import errno +import os +import shutil +import tempfile +import time +import numpy as np +import re +import paddle.fluid as fluid + +from .download import get_weights_path + +import logging +logger = logging.getLogger(__name__) + +__all__ = [ + 'load_checkpoint', + 'load_and_fusebn', + 'load_params', + 'save', +] + + +def is_url(path): + """ + Whether path is URL. + Args: + path (string): URL string or not. 
+ """ + return path.startswith('http://') or path.startswith('https://') + + +def _get_weight_path(path): + env = os.environ + if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env: + trainer_id = int(env['PADDLE_TRAINER_ID']) + num_trainers = int(env['PADDLE_TRAINERS_NUM']) + if num_trainers <= 1: + path = get_weights_path(path) + else: + from ppdet.utils.download import map_path, WEIGHTS_HOME + weight_path = map_path(path, WEIGHTS_HOME) + lock_path = weight_path + '.lock' + if not os.path.exists(weight_path): + try: + os.makedirs(os.path.dirname(weight_path)) + except OSError as e: + if e.errno != errno.EEXIST: + raise + with open(lock_path, 'w'): # touch + os.utime(lock_path, None) + if trainer_id == 0: + get_weights_path(path) + os.remove(lock_path) + else: + while os.path.exists(lock_path): + time.sleep(1) + path = weight_path + else: + path = get_weights_path(path) + return path + + +def _load_state(path): + if os.path.exists(path + '.pdopt'): + # XXX another hack to ignore the optimizer state + tmp = tempfile.mkdtemp() + dst = os.path.join(tmp, os.path.basename(os.path.normpath(path))) + shutil.copy(path + '.pdparams', dst + '.pdparams') + state = fluid.io.load_program_state(dst) + shutil.rmtree(tmp) + else: + state = fluid.io.load_program_state(path) + return state + + +def _strip_postfix(path): + path, ext = os.path.splitext(path) + assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ + "Unknown postfix {} from weights".format(ext) + return path + + +def load_params(exe, prog, path, ignore_params=[]): + """ + Load model from the given path. + Args: + exe (fluid.Executor): The fluid.Executor object. + prog (fluid.Program): load weight to which Program object. + path (string): URL string or loca model path. + ignore_params (list): ignore variable to load when finetuning. 
+ It can be specified by finetune_exclude_pretrained_params + and the usage can refer to docs/advanced_tutorials/TRANSFER_LEARNING.md + """ + + if is_url(path): + path = _get_weight_path(path) + + path = _strip_postfix(path) + if not (os.path.isdir(path) or os.path.isfile(path) or + os.path.exists(path + '.pdparams')): + raise ValueError("Model pretrain path {} does not " + "exists.".format(path)) + + logger.debug('Loading parameters from {}...'.format(path)) + + ignore_set = set() + state = _load_state(path) + + # ignore the parameter which mismatch the shape + # between the model and pretrain weight. + all_var_shape = {} + for block in prog.blocks: + for param in block.all_parameters(): + all_var_shape[param.name] = param.shape + ignore_set.update([ + name for name, shape in all_var_shape.items() + if name in state and shape != state[name].shape + ]) + + if ignore_params: + all_var_names = [var.name for var in prog.list_vars()] + ignore_list = filter( + lambda var: any([re.match(name, var) for name in ignore_params]), + all_var_names) + ignore_set.update(list(ignore_list)) + + if len(ignore_set) > 0: + for k in ignore_set: + if k in state: + logger.warning('variable {} not used'.format(k)) + del state[k] + fluid.io.set_program_state(prog, state) + + +def load_checkpoint(exe, prog, path): + """ + Load model from the given path. + Args: + exe (fluid.Executor): The fluid.Executor object. + prog (fluid.Program): load weight to which Program object. + path (string): URL string or loca model path. + """ + if is_url(path): + path = _get_weight_path(path) + + path = _strip_postfix(path) + if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): + raise ValueError("Model pretrain path {} does not " + "exists.".format(path)) + fluid.load(prog, path, executor=exe) + + +def global_step(scope=None): + """ + Load global step in scope. + Args: + scope (fluid.Scope): load global step from which scope. If None, + from default global_scope(). 
+ + Returns: + global step: int. + """ + if scope is None: + scope = fluid.global_scope() + v = scope.find_var('@LR_DECAY_COUNTER@') + step = np.array(v.get_tensor())[0] if v else 0 + return step + + +def save(exe, prog, path): + """ + Load model from the given path. + Args: + exe (fluid.Executor): The fluid.Executor object. + prog (fluid.Program): save weight from which Program object. + path (string): the path to save model. + """ + if os.path.isdir(path): + shutil.rmtree(path) + logger.info('Save model to {}.'.format(path)) + fluid.save(prog, path) + + +def load_and_fusebn(exe, prog, path): + """ + Fuse params of batch norm to scale and bias. + + Args: + exe (fluid.Executor): The fluid.Executor object. + prog (fluid.Program): save weight from which Program object. + path (string): the path to save model. + """ + logger.debug('Load model and fuse batch norm if have from {}...'.format( + path)) + + if is_url(path): + path = _get_weight_path(path) + + if not os.path.exists(path): + raise ValueError("Model path {} does not exists.".format(path)) + + # Since the program uses affine-channel, there is no running mean and var + # in the program, here append running mean and var. 
+ # NOTE, the params of batch norm should be like: + # x_scale + # x_offset + # x_mean + # x_variance + # x is any prefix + mean_variances = set() + bn_vars = [] + state = _load_state(path) + + def check_mean_and_bias(prefix): + m = prefix + 'mean' + v = prefix + 'variance' + return v in state and m in state + + has_mean_bias = True + + with fluid.program_guard(prog, fluid.Program()): + for block in prog.blocks: + ops = list(block.ops) + if not has_mean_bias: + break + for op in ops: + if op.type == 'affine_channel': + # remove 'scale' as prefix + scale_name = op.input('Scale')[0] # _scale + bias_name = op.input('Bias')[0] # _offset + prefix = scale_name[:-5] + mean_name = prefix + 'mean' + variance_name = prefix + 'variance' + if not check_mean_and_bias(prefix): + has_mean_bias = False + break + + bias = block.var(bias_name) + + mean_vb = block.create_var( + name=mean_name, + type=bias.type, + shape=bias.shape, + dtype=bias.dtype) + variance_vb = block.create_var( + name=variance_name, + type=bias.type, + shape=bias.shape, + dtype=bias.dtype) + + mean_variances.add(mean_vb) + mean_variances.add(variance_vb) + + bn_vars.append( + [scale_name, bias_name, mean_name, variance_name]) + + if not has_mean_bias: + fluid.io.set_program_state(prog, state) + logger.warning( + "There is no paramters of batch norm in model {}. " + "Skip to fuse batch norm. 
And load paramters done.".format(path)) + return + + fluid.load(prog, path, exe) + eps = 1e-5 + for names in bn_vars: + scale_name, bias_name, mean_name, var_name = names + + scale = fluid.global_scope().find_var(scale_name).get_tensor() + bias = fluid.global_scope().find_var(bias_name).get_tensor() + mean = fluid.global_scope().find_var(mean_name).get_tensor() + var = fluid.global_scope().find_var(var_name).get_tensor() + + scale_arr = np.array(scale) + bias_arr = np.array(bias) + mean_arr = np.array(mean) + var_arr = np.array(var) + + bn_std = np.sqrt(np.add(var_arr, eps)) + new_scale = np.float32(np.divide(scale_arr, bn_std)) + new_bias = bias_arr - mean_arr * new_scale + + # fuse to scale and bias in affine_channel + scale.set(new_scale, exe.place) + bias.set(new_bias, exe.place) diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/cli.py b/VisualFL/depends/PaddleDetection/ppdet/utils/cli.py new file mode 100755 index 000000000..b8ba59d78 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/cli.py @@ -0,0 +1,151 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from argparse import ArgumentParser, RawDescriptionHelpFormatter + +import yaml +import re +from ppdet.core.workspace import get_registered_modules, dump_value + +__all__ = ['ColorTTY', 'ArgsParser'] + + +class ColorTTY(object): + def __init__(self): + super(ColorTTY, self).__init__() + self.colors = ['red', 'green', 'yellow', 'blue', 'magenta', 'cyan'] + + def __getattr__(self, attr): + if attr in self.colors: + color = self.colors.index(attr) + 31 + + def color_message(message): + return "[{}m{}".format(color, message) + + setattr(self, attr, color_message) + return color_message + + def bold(self, message): + return self.with_code('01', message) + + def with_code(self, code, message): + return "[{}m{}".format(code, message) + + +class ArgsParser(ArgumentParser): + def __init__(self): + super(ArgsParser, self).__init__( + formatter_class=RawDescriptionHelpFormatter) + self.add_argument("-c", "--config", help="configuration file to use") + self.add_argument( + "-o", "--opt", nargs='*', help="set configuration options") + + def parse_args(self, argv=None): + args = super(ArgsParser, self).parse_args(argv) + assert args.config is not None, \ + "Please specify --config=configure_file_path." + args.opt = self._parse_opt(args.opt) + return args + + def _parse_opt(self, opts): + config = {} + if not opts: + return config + for s in opts: + s = s.strip() + k, v = s.split('=', 1) + if '.' 
not in k: + config[k] = yaml.load(v, Loader=yaml.Loader) + else: + keys = k.split('.') + if keys[0] not in config: + config[keys[0]] = {} + cur = config[keys[0]] + for idx, key in enumerate(keys[1:]): + if idx == len(keys) - 2: + cur[key] = yaml.load(v, Loader=yaml.Loader) + else: + cur[key] = {} + cur = cur[key] + return config + + +def print_total_cfg(config): + modules = get_registered_modules() + color_tty = ColorTTY() + green = '___{}___'.format(color_tty.colors.index('green') + 31) + + styled = {} + for key in config.keys(): + if not config[key]: # empty schema + continue + + if key not in modules and not hasattr(config[key], '__dict__'): + styled[key] = config[key] + continue + elif key in modules: + module = modules[key] + else: + type_name = type(config[key]).__name__ + if type_name in modules: + module = modules[type_name].copy() + module.update({ + k: v + for k, v in config[key].__dict__.items() + if k in module.schema + }) + key += " ({})".format(type_name) + default = module.find_default_keys() + missing = module.find_missing_keys() + mismatch = module.find_mismatch_keys() + extra = module.find_extra_keys() + dep_missing = [] + for dep in module.inject: + if isinstance(module[dep], str) and module[dep] != '': + if module[dep] not in modules: # not a valid module + dep_missing.append(dep) + else: + dep_mod = modules[module[dep]] + # empty dict but mandatory + if not dep_mod and dep_mod.mandatory(): + dep_missing.append(dep) + override = list( + set(module.keys()) - set(default) - set(extra) - set(dep_missing)) + replacement = {} + for name in set(override + default + extra + mismatch + missing): + new_name = name + if name in missing: + value = "" + else: + value = module[name] + + if name in extra: + value = dump_value(value) + " " + elif name in mismatch: + value = dump_value(value) + " " + elif name in dep_missing: + value = dump_value(value) + " " + elif name in override and value != '': + mark = green + new_name = mark + name + 
replacement[new_name] = value + styled[key] = replacement + buffer = yaml.dump(styled, default_flow_style=False, default_style='') + buffer = (re.sub(r"", r"[31m[0m", buffer)) + buffer = (re.sub(r"", r"[33m[0m", buffer)) + buffer = (re.sub(r"", r"[31m[0m", buffer)) + buffer = (re.sub(r"", + r"[31m[0m", buffer)) + buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2[0m:", buffer) + print(buffer) diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/coco_eval.py b/VisualFL/depends/PaddleDetection/ppdet/utils/coco_eval.py new file mode 100755 index 000000000..a4df01b3c --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/coco_eval.py @@ -0,0 +1,706 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import sys +import json +import cv2 +import numpy as np + +import logging +logger = logging.getLogger(__name__) + +__all__ = [ + 'bbox_eval', + 'mask_eval', + 'bbox2out', + 'mask2out', + 'get_category_info', + 'proposal_eval', + 'cocoapi_eval', +] + + +def clip_bbox(bbox, im_size=None): + h = 1. if im_size is None else im_size[0] + w = 1. if im_size is None else im_size[1] + xmin = max(min(bbox[0], w), 0.) + ymin = max(min(bbox[1], h), 0.) + xmax = max(min(bbox[2], w), 0.) + ymax = max(min(bbox[3], h), 0.) 
+ return xmin, ymin, xmax, ymax + + +def proposal_eval(results, anno_file, outfile, max_dets=(100, 300, 1000)): + assert 'proposal' in results[0] + assert outfile.endswith('.json') + + xywh_results = proposal2out(results) + assert len( + xywh_results) > 0, "The number of valid proposal detected is zero.\n \ + Please use reasonable model and check input data." + + with open(outfile, 'w') as f: + json.dump(xywh_results, f) + + cocoapi_eval(outfile, 'proposal', anno_file=anno_file, max_dets=max_dets) + # flush coco evaluation result + sys.stdout.flush() + + +def bbox_eval(results, + anno_file, + outfile, + with_background=True, + is_bbox_normalized=False, + save_only=False): + assert 'bbox' in results[0] + assert outfile.endswith('.json') + from pycocotools.coco import COCO + + coco_gt = COCO(anno_file) + cat_ids = coco_gt.getCatIds() + + # when with_background = True, mapping category to classid, like: + # background:0, first_class:1, second_class:2, ... + clsid2catid = dict( + {i + int(with_background): catid + for i, catid in enumerate(cat_ids)}) + + xywh_results = bbox2out( + results, clsid2catid, is_bbox_normalized=is_bbox_normalized) + + if len(xywh_results) == 0: + logger.warning("The number of valid bbox detected is zero.\n \ + Please use reasonable model and check input data.\n \ + stop eval!") + return [0.0] + with open(outfile, 'w') as f: + json.dump(xywh_results, f) + + if save_only: + logger.info('The bbox result is saved to {} and do not ' + 'evaluate the mAP.'.format(outfile)) + return + + map_stats = cocoapi_eval(outfile, 'bbox', coco_gt=coco_gt) + # flush coco evaluation result + sys.stdout.flush() + return map_stats + + +def mask_eval(results, + anno_file, + outfile, + resolution, + thresh_binarize=0.5, + save_only=False): + """ + Format the output of mask and get mask ap by coco api evaluation. + It will be used in Mask-RCNN. 
+ """ + assert 'mask' in results[0] + assert outfile.endswith('.json') + from pycocotools.coco import COCO + + coco_gt = COCO(anno_file) + clsid2catid = {i + 1: v for i, v in enumerate(coco_gt.getCatIds())} + + segm_results = [] + for t in results: + im_ids = np.array(t['im_id'][0]) + bboxes = t['bbox'][0] + lengths = t['bbox'][1][0] + masks = t['mask'] + if bboxes.shape == (1, 1) or bboxes is None: + continue + if len(bboxes.tolist()) == 0: + continue + s = 0 + for i in range(len(lengths)): + num = lengths[i] + im_id = int(im_ids[i][0]) + clsid_scores = bboxes[s:s + num][:, 0:2] + mask = masks[s:s + num] + s += num + for j in range(num): + clsid, score = clsid_scores[j].tolist() + catid = int(clsid2catid[clsid]) + segm = mask[j] + segm['counts'] = segm['counts'].decode('utf8') + coco_res = { + 'image_id': im_id, + 'category_id': int(catid), + 'segmentation': segm, + 'score': score + } + segm_results.append(coco_res) + + if len(segm_results) == 0: + logger.warning("The number of valid mask detected is zero.\n \ + Please use reasonable model and check input data.") + return + + with open(outfile, 'w') as f: + json.dump(segm_results, f) + + if save_only: + logger.info('The mask result is saved to {} and do not ' + 'evaluate the mAP.'.format(outfile)) + return + + cocoapi_eval(outfile, 'segm', coco_gt=coco_gt) + + +def segm_eval(results, anno_file, outfile, save_only=False): + """ + Format the output of segmentation, category_id and score in mask.josn, and + get mask ap by coco api evaluation. It will be used in instance segmentation + networks, such as: SOLOv2. 
+ """ + assert 'segm' in results[0] + assert outfile.endswith('.json') + from pycocotools.coco import COCO + coco_gt = COCO(anno_file) + clsid2catid = {i: v for i, v in enumerate(coco_gt.getCatIds())} + segm_results = [] + for t in results: + im_id = int(t['im_id'][0][0]) + segs = t['segm'] + for mask in segs: + catid = int(clsid2catid[mask[0]]) + masks = mask[1] + mask_score = masks[1] + segm = masks[0] + segm['counts'] = segm['counts'].decode('utf8') + coco_res = { + 'image_id': im_id, + 'category_id': catid, + 'segmentation': segm, + 'score': mask_score + } + segm_results.append(coco_res) + + if len(segm_results) == 0: + logger.warning("The number of valid mask detected is zero.\n \ + Please use reasonable model and check input data.") + return + + with open(outfile, 'w') as f: + json.dump(segm_results, f) + + if save_only: + logger.info('The mask result is saved to {} and do not ' + 'evaluate the mAP.'.format(outfile)) + return + + map_stats = cocoapi_eval(outfile, 'segm', coco_gt=coco_gt) + return map_stats + + +def cocoapi_eval(jsonfile, + style, + coco_gt=None, + anno_file=None, + max_dets=(100, 300, 1000)): + """ + Args: + jsonfile: Evaluation json file, eg: bbox.json, mask.json. + style: COCOeval style, can be `bbox` , `segm` and `proposal`. + coco_gt: Whether to load COCOAPI through anno_file, + eg: coco_gt = COCO(anno_file) + anno_file: COCO annotations file. + max_dets: COCO evaluation maxDets. 
+ """ + assert coco_gt != None or anno_file != None + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + + if coco_gt == None: + coco_gt = COCO(anno_file) + logger.info("Start evaluate...") + coco_dt = coco_gt.loadRes(jsonfile) + if style == 'proposal': + coco_eval = COCOeval(coco_gt, coco_dt, 'bbox') + coco_eval.params.useCats = 0 + coco_eval.params.maxDets = list(max_dets) + else: + coco_eval = COCOeval(coco_gt, coco_dt, style) + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + return coco_eval.stats + + +def proposal2out(results, is_bbox_normalized=False): + xywh_res = [] + for t in results: + bboxes = t['proposal'][0] + lengths = t['proposal'][1][0] + im_ids = np.array(t['im_id'][0]).flatten() + assert len(lengths) == im_ids.size + if bboxes.shape == (1, 1) or bboxes is None: + continue + + k = 0 + for i in range(len(lengths)): + num = lengths[i] + im_id = int(im_ids[i]) + for j in range(num): + dt = bboxes[k] + xmin, ymin, xmax, ymax = dt.tolist() + + if is_bbox_normalized: + xmin, ymin, xmax, ymax = \ + clip_bbox([xmin, ymin, xmax, ymax]) + w = xmax - xmin + h = ymax - ymin + else: + w = xmax - xmin + 1 + h = ymax - ymin + 1 + + bbox = [xmin, ymin, w, h] + coco_res = { + 'image_id': im_id, + 'category_id': 1, + 'bbox': bbox, + 'score': 1.0 + } + xywh_res.append(coco_res) + k += 1 + return xywh_res + + +def bbox2out(results, clsid2catid, is_bbox_normalized=False): + """ + Args: + results: request a dict, should include: `bbox`, `im_id`, + if is_bbox_normalized=True, also need `im_shape`. + clsid2catid: class id to category id map of COCO2017 dataset. + is_bbox_normalized: whether or not bbox is normalized. 
+ """ + xywh_res = [] + for t in results: + bboxes = t['bbox'][0] + if len(t['bbox'][1]) == 0: continue + lengths = t['bbox'][1][0] + im_ids = np.array(t['im_id'][0]).flatten() + if bboxes.shape == (1, 1) or bboxes is None or len(bboxes) == 0: + continue + + k = 0 + for i in range(len(lengths)): + num = lengths[i] + im_id = int(im_ids[i]) + for j in range(num): + dt = bboxes[k] + clsid, score, xmin, ymin, xmax, ymax = dt.tolist() + if clsid < 0: + k += 1 #to continue to handle the subsequent bbox of current image + continue + catid = (clsid2catid[int(clsid)]) + + if is_bbox_normalized: + xmin, ymin, xmax, ymax = \ + clip_bbox([xmin, ymin, xmax, ymax]) + w = xmax - xmin + h = ymax - ymin + im_shape = t['im_shape'][0][i].tolist() + im_height, im_width = int(im_shape[0]), int(im_shape[1]) + xmin *= im_width + ymin *= im_height + w *= im_width + h *= im_height + else: + # for yolov4 + # w = xmax - xmin + # h = ymax - ymin + w = xmax - xmin + 1 + h = ymax - ymin + 1 + + bbox = [xmin, ymin, w, h] + coco_res = { + 'image_id': im_id, + 'category_id': catid, + 'bbox': bbox, + 'score': score + } + xywh_res.append(coco_res) + k += 1 + return xywh_res + + +def mask2out(results, clsid2catid, resolution, thresh_binarize=0.5): + import pycocotools.mask as mask_util + scale = (resolution + 2.0) / resolution + + segm_res = [] + + # for each batch + for t in results: + bboxes = t['bbox'][0] + + lengths = t['bbox'][1][0] + im_ids = np.array(t['im_id'][0]) + if bboxes.shape == (1, 1) or bboxes is None: + continue + if len(bboxes.tolist()) == 0: + continue + + masks = t['mask'][0] + + s = 0 + # for each sample + for i in range(len(lengths)): + num = lengths[i] + im_id = int(im_ids[i][0]) + im_shape = t['im_shape'][0][i] + + bbox = bboxes[s:s + num][:, 2:] + clsid_scores = bboxes[s:s + num][:, 0:2] + mask = masks[s:s + num] + s += num + + im_h = int(im_shape[0]) + im_w = int(im_shape[1]) + + expand_bbox = expand_boxes(bbox, scale) + expand_bbox = expand_bbox.astype(np.int32) + + 
padded_mask = np.zeros( + (resolution + 2, resolution + 2), dtype=np.float32) + + for j in range(num): + xmin, ymin, xmax, ymax = expand_bbox[j].tolist() + clsid, score = clsid_scores[j].tolist() + clsid = int(clsid) + padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :] + + catid = clsid2catid[clsid] + + w = xmax - xmin + 1 + h = ymax - ymin + 1 + w = np.maximum(w, 1) + h = np.maximum(h, 1) + + resized_mask = cv2.resize(padded_mask, (w, h)) + resized_mask = np.array( + resized_mask > thresh_binarize, dtype=np.uint8) + im_mask = np.zeros((im_h, im_w), dtype=np.uint8) + + x0 = min(max(xmin, 0), im_w) + x1 = min(max(xmax + 1, 0), im_w) + y0 = min(max(ymin, 0), im_h) + y1 = min(max(ymax + 1, 0), im_h) + + im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), ( + x0 - xmin):(x1 - xmin)] + segm = mask_util.encode( + np.array( + im_mask[:, :, np.newaxis], order='F'))[0] + catid = clsid2catid[clsid] + segm['counts'] = segm['counts'].decode('utf8') + coco_res = { + 'image_id': im_id, + 'category_id': catid, + 'segmentation': segm, + 'score': score + } + segm_res.append(coco_res) + return segm_res + + +def segm2out(results, clsid2catid, thresh_binarize=0.5): + import pycocotools.mask as mask_util + segm_res = [] + + # for each batch + for t in results: + segms = t['segm'][0].astype(np.uint8) + clsid_labels = t['cate_label'][0] + clsid_scores = t['cate_score'][0] + lengths = segms.shape[0] + im_id = int(t['im_id'][0][0]) + im_shape = t['im_shape'][0][0] + if lengths == 0 or segms is None: + continue + # for each sample + for i in range(lengths - 1): + im_h = int(im_shape[0]) + im_w = int(im_shape[1]) + + clsid = int(clsid_labels[i]) + 1 + catid = clsid2catid[clsid] + score = clsid_scores[i] + mask = segms[i] + segm = mask_util.encode( + np.array( + mask[:, :, np.newaxis], order='F'))[0] + segm['counts'] = segm['counts'].decode('utf8') + coco_res = { + 'image_id': im_id, + 'category_id': catid, + 'segmentation': segm, + 'score': score + } + segm_res.append(coco_res) + 
return segm_res + + +def expand_boxes(boxes, scale): + """ + Expand an array of boxes by a given scale. + """ + w_half = (boxes[:, 2] - boxes[:, 0]) * .5 + h_half = (boxes[:, 3] - boxes[:, 1]) * .5 + x_c = (boxes[:, 2] + boxes[:, 0]) * .5 + y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + + w_half *= scale + h_half *= scale + + boxes_exp = np.zeros(boxes.shape) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half + + return boxes_exp + + +def get_category_info(anno_file=None, + with_background=True, + use_default_label=False): + if use_default_label or anno_file is None \ + or not os.path.exists(anno_file): + logger.info("Not found annotation file {}, load " + "coco17 categories.".format(anno_file)) + return coco17_category_info(with_background) + else: + logger.info("Load categories from {}".format(anno_file)) + return get_category_info_from_anno(anno_file, with_background) + + +def get_category_info_from_anno(anno_file, with_background=True): + """ + Get class id to category id map and category id + to category name map from annotation file. + + Args: + anno_file (str): annotation file path + with_background (bool, default True): + whether load background as class 0. + """ + from pycocotools.coco import COCO + coco = COCO(anno_file) + cats = coco.loadCats(coco.getCatIds()) + clsid2catid = { + i + int(with_background): cat['id'] + for i, cat in enumerate(cats) + } + catid2name = {cat['id']: cat['name'] for cat in cats} + if with_background: + clsid2catid.update({0: 0}) + catid2name.update({0: 'background'}) + return clsid2catid, catid2name + + +def coco17_category_info(with_background=True): + """ + Get class id to category id map and category id + to category name map of COCO2017 dataset + + Args: + with_background (bool, default True): + whether load background as class 0. 
+ """ + clsid2catid = { + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 7, + 8: 8, + 9: 9, + 10: 10, + 11: 11, + 12: 13, + 13: 14, + 14: 15, + 15: 16, + 16: 17, + 17: 18, + 18: 19, + 19: 20, + 20: 21, + 21: 22, + 22: 23, + 23: 24, + 24: 25, + 25: 27, + 26: 28, + 27: 31, + 28: 32, + 29: 33, + 30: 34, + 31: 35, + 32: 36, + 33: 37, + 34: 38, + 35: 39, + 36: 40, + 37: 41, + 38: 42, + 39: 43, + 40: 44, + 41: 46, + 42: 47, + 43: 48, + 44: 49, + 45: 50, + 46: 51, + 47: 52, + 48: 53, + 49: 54, + 50: 55, + 51: 56, + 52: 57, + 53: 58, + 54: 59, + 55: 60, + 56: 61, + 57: 62, + 58: 63, + 59: 64, + 60: 65, + 61: 67, + 62: 70, + 63: 72, + 64: 73, + 65: 74, + 66: 75, + 67: 76, + 68: 77, + 69: 78, + 70: 79, + 71: 80, + 72: 81, + 73: 82, + 74: 84, + 75: 85, + 76: 86, + 77: 87, + 78: 88, + 79: 89, + 80: 90 + } + + catid2name = { + 0: 'background', + 1: 'person', + 2: 'bicycle', + 3: 'car', + 4: 'motorcycle', + 5: 'airplane', + 6: 'bus', + 7: 'train', + 8: 'truck', + 9: 'boat', + 10: 'traffic light', + 11: 'fire hydrant', + 13: 'stop sign', + 14: 'parking meter', + 15: 'bench', + 16: 'bird', + 17: 'cat', + 18: 'dog', + 19: 'horse', + 20: 'sheep', + 21: 'cow', + 22: 'elephant', + 23: 'bear', + 24: 'zebra', + 25: 'giraffe', + 27: 'backpack', + 28: 'umbrella', + 31: 'handbag', + 32: 'tie', + 33: 'suitcase', + 34: 'frisbee', + 35: 'skis', + 36: 'snowboard', + 37: 'sports ball', + 38: 'kite', + 39: 'baseball bat', + 40: 'baseball glove', + 41: 'skateboard', + 42: 'surfboard', + 43: 'tennis racket', + 44: 'bottle', + 46: 'wine glass', + 47: 'cup', + 48: 'fork', + 49: 'knife', + 50: 'spoon', + 51: 'bowl', + 52: 'banana', + 53: 'apple', + 54: 'sandwich', + 55: 'orange', + 56: 'broccoli', + 57: 'carrot', + 58: 'hot dog', + 59: 'pizza', + 60: 'donut', + 61: 'cake', + 62: 'chair', + 63: 'couch', + 64: 'potted plant', + 65: 'bed', + 67: 'dining table', + 70: 'toilet', + 72: 'tv', + 73: 'laptop', + 74: 'mouse', + 75: 'remote', + 76: 'keyboard', + 77: 'cell phone', + 78: 'microwave', + 79: 
'oven', + 80: 'toaster', + 81: 'sink', + 82: 'refrigerator', + 84: 'book', + 85: 'clock', + 86: 'vase', + 87: 'scissors', + 88: 'teddy bear', + 89: 'hair drier', + 90: 'toothbrush' + } + + if not with_background: + clsid2catid = {k - 1: v for k, v in clsid2catid.items()} + catid2name.pop(0) + else: + clsid2catid.update({0: 0}) + + return clsid2catid, catid2name diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/colormap.py b/VisualFL/depends/PaddleDetection/ppdet/utils/colormap.py new file mode 100755 index 000000000..566185ef9 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/colormap.py @@ -0,0 +1,56 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np + + +def colormap(rgb=False): + """ + Get colormap + """ + color_list = np.array([ + 0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494, + 0.184, 0.556, 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078, + 0.184, 0.300, 0.300, 0.300, 0.600, 0.600, 0.600, 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, 0.749, 0.749, 0.000, 0.000, 1.000, 0.000, 0.000, + 0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, 0.333, 0.667, + 0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000, + 1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000, + 0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, 0.667, 0.000, 0.500, 0.667, 0.333, 0.500, 0.667, + 0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, 1.000, 0.333, + 0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333, + 0.333, 1.000, 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000, + 1.000, 0.667, 0.333, 1.000, 0.667, 0.667, 1.000, 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, 1.000, 0.333, 1.000, 1.000, 0.667, 1.000, 0.167, + 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.286, + 0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, 0.714, 0.714, + 0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000 + ]).astype(np.float32) + 
color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/dist_utils.py b/VisualFL/depends/PaddleDetection/ppdet/utils/dist_utils.py new file mode 100755 index 000000000..32eead4a7 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/dist_utils.py @@ -0,0 +1,41 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def nccl2_prepare(trainer_id, startup_prog, main_prog):
    """
    Transpile the given programs for NCCL2 distributed training.

    Trainer endpoints and the current endpoint are read from the
    PADDLE_TRAINER_ENDPOINTS / PADDLE_CURRENT_ENDPOINT environment
    variables, as set by the Paddle distributed launcher.
    """
    transpiler_config = fluid.DistributeTranspilerConfig()
    transpiler_config.mode = "nccl2"
    transpiler = fluid.DistributeTranspiler(config=transpiler_config)
    transpiler.transpile(
        trainer_id,
        trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'),
        current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'),
        startup_program=startup_prog,
        program=main_prog)


def prepare_for_multi_process(exe, build_strategy, startup_prog, main_prog):
    """
    Configure `build_strategy` and transpile programs for multi-trainer
    runs. A no-op when fewer than two trainers are configured via the
    PADDLE_TRAINERS_NUM environment variable.
    """
    trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
    trainer_count = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    if trainer_count < 2:
        # single-process run: nothing to set up
        return
    build_strategy.num_trainers = trainer_count
    build_strategy.trainer_id = trainer_id
    nccl2_prepare(trainer_id, startup_prog, main_prog)
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import os.path as osp +import shutil +import requests +import tqdm +import hashlib +import binascii +import base64 +import tarfile +import zipfile + +from .voc_utils import create_list + +import logging +logger = logging.getLogger(__name__) + +__all__ = [ + 'get_weights_path', 'get_dataset_path', 'download_dataset', + 'create_voc_list' +] + +WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/weights") +DATASET_HOME = osp.expanduser("~/.cache/paddle/dataset") + +# dict of {dataset_name: (download_info, sub_dirs)} +# download info: [(url, md5sum)] +DATASETS = { + 'coco': ([ + ( + 'http://images.cocodataset.org/zips/train2017.zip', + 'cced6f7f71b7629ddf16f17bbcfab6b2', ), + ( + 'http://images.cocodataset.org/zips/val2017.zip', + '442b8da7639aecaf257c1dceb8ba8c80', ), + ( + 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip', + 'f4bbac642086de4f52a3fdda2de5fa2c', ), + ], ["annotations", "train2017", "val2017"]), + 'voc': ([ + ( + 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', + '6cd6e144f989b92b3379bac3b3de84fd', ), + ( + 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', + 'c52e279531787c972589f7e41ab4ae64', ), + ( + 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', + 'b6e924de25625d8de591ea690078ad9f', ), + ], ["VOCdevkit/VOC2012", "VOCdevkit/VOC2007"]), + 'wider_face': ([ + ( + 'https://dataset.bj.bcebos.com/wider_face/WIDER_train.zip', + '3fedf70df600953d25982bcd13d91ba2', ), + ( + 'https://dataset.bj.bcebos.com/wider_face/WIDER_val.zip', + 'dfa7d7e790efa35df3788964cf0bbaea', ), + ( + 'https://dataset.bj.bcebos.com/wider_face/wider_face_split.zip', + 'a4a898d6193db4b9ef3260a68bad0dc7', ), + ], ["WIDER_train", "WIDER_val", "wider_face_split"]), + 'fruit': ([( + 'https://dataset.bj.bcebos.com/PaddleDetection_demo/fruit.tar', + 
'baa8806617a54ccf3685fa7153388ae6', ), ], + ['Annotations', 'JPEGImages']), + 'roadsign_voc': ([( + 'https://paddlemodels.bj.bcebos.com/object_detection/roadsign_voc.tar', + '8d629c0f880dd8b48de9aeff44bf1f3e', ), ], ['annotations', 'images']), + 'roadsign_coco': ([( + 'https://paddlemodels.bj.bcebos.com/object_detection/roadsign_coco.tar', + '49ce5a9b5ad0d6266163cd01de4b018e', ), ], ['annotations', 'images']), + 'objects365': (), +} + +DOWNLOAD_RETRY_LIMIT = 3 + + +def get_weights_path(url): + """Get weights path from WEIGHT_HOME, if not exists, + download it from url. + """ + path, _ = get_path(url, WEIGHTS_HOME) + return path + + +def get_dataset_path(path, annotation, image_dir): + """ + If path exists, return path. + Otherwise, get dataset path from DATASET_HOME, if not exists, + download it. + """ + if _dataset_exists(path, annotation, image_dir): + return path + + logger.info("Dataset {} is not valid for reason above, try searching {} or " + "downloading dataset...".format( + osp.realpath(path), DATASET_HOME)) + + data_name = os.path.split(path.strip().lower())[-1] + for name, dataset in DATASETS.items(): + if data_name == name: + logger.debug("Parse dataset_dir {} as dataset " + "{}".format(path, name)) + if name == 'objects365': + raise NotImplementedError( + "Dataset {} is not valid for download automatically. 
" + "Please apply and download the dataset from " + "https://www.objects365.org/download.html".format(name)) + data_dir = osp.join(DATASET_HOME, name) + # For VOC-style datasets, only check subdirs + if name in ['voc', 'fruit', 'roadsign_voc']: + exists = True + for sub_dir in dataset[1]: + check_dir = osp.join(data_dir, sub_dir) + if osp.exists(check_dir): + logger.info("Found {}".format(check_dir)) + else: + exists = False + if exists: + return data_dir + + # voc exist is checked above, voc is not exist here + check_exist = name != 'voc' and name != 'fruit' and name != 'roadsign_voc' + for url, md5sum in dataset[0]: + get_path(url, data_dir, md5sum, check_exist) + + # voc should create list after download + if name == 'voc': + create_voc_list(data_dir) + return data_dir + + # not match any dataset in DATASETS + raise ValueError( + "Dataset {} is not valid and cannot parse dataset type " + "'{}' for automaticly downloading, which only supports " + "'voc' , 'coco', 'wider_face', 'fruit' and 'roadsign_voc' currently". + format(path, osp.split(path)[-1])) + + +def create_voc_list(data_dir, devkit_subdir='VOCdevkit'): + logger.debug("Create voc file list...") + devkit_dir = osp.join(data_dir, devkit_subdir) + year_dirs = [osp.join(devkit_dir, x) for x in os.listdir(devkit_dir)] + + # NOTE: since using auto download VOC + # dataset, VOC default label list should be used, + # do not generate label_list.txt here. For default + # label, see ../data/source/voc.py + create_list(year_dirs, data_dir) + logger.debug("Create voc file list finished") + + +def map_path(url, root_dir): + # parse path after download to decompress under root_dir + fname = osp.split(url)[-1] + zip_formats = ['.zip', '.tar', '.gz'] + fpath = fname + for zip_format in zip_formats: + fpath = fpath.replace(zip_format, '') + return osp.join(root_dir, fpath) + + +def get_path(url, root_dir, md5sum=None, check_exist=True): + """ Download from given url to root_dir. 
+ if file or directory specified by url is exists under + root_dir, return the path directly, otherwise download + from url and decompress it, return the path. + + url (str): download url + root_dir (str): root dir for downloading, it should be + WEIGHTS_HOME or DATASET_HOME + md5sum (str): md5 sum of download package + """ + # parse path after download to decompress under root_dir + fullpath = map_path(url, root_dir) + + # For same zip file, decompressed directory name different + # from zip file name, rename by following map + decompress_name_map = { + "VOCtrainval_11-May-2012": "VOCdevkit/VOC2012", + "VOCtrainval_06-Nov-2007": "VOCdevkit/VOC2007", + "VOCtest_06-Nov-2007": "VOCdevkit/VOC2007", + "annotations_trainval": "annotations" + } + for k, v in decompress_name_map.items(): + if fullpath.find(k) >= 0: + fullpath = osp.join(osp.split(fullpath)[0], v) + + if osp.exists(fullpath) and check_exist: + # If fullpath is a directory, it has been decompressed + # checking MD5 is impossible, so we skip checking when + # fullpath is a directory here + if osp.isdir(fullpath) or \ + _md5check_from_req(fullpath, + requests.get(url, stream=True)): + logger.debug("Found {}".format(fullpath)) + return fullpath, True + else: + if osp.isdir(fullpath): + shutil.rmtree(fullpath) + else: + os.remove(fullpath) + + fullname = _download(url, root_dir, md5sum) + + # new weights format whose postfix is 'pdparams', + # which is not need to decompress + if osp.splitext(fullname)[-1] != '.pdparams': + _decompress(fullname) + + return fullpath, False + + +def download_dataset(path, dataset=None): + if dataset not in DATASETS.keys(): + logger.error("Unknown dataset {}, it should be " + "{}".format(dataset, DATASETS.keys())) + return + dataset_info = DATASETS[dataset][0] + for info in dataset_info: + get_path(info[0], path, info[1], False) + logger.debug("Download dataset {} finished.".format(dataset)) + + +def _dataset_exists(path, annotation, image_dir): + """ + Check if user define 
dataset exists + """ + if not osp.exists(path): + logger.debug("Config dataset_dir {} is not exits, " + "dataset config is not valid".format(path)) + return False + + if annotation: + annotation_path = osp.join(path, annotation) + if not osp.exists(annotation_path): + logger.error("Config dataset_dir {} is not exits!".format(path)) + + if not osp.isfile(annotation_path): + logger.warning("Config annotation {} is not a " + "file, dataset config is not " + "valid".format(annotation_path)) + return False + if image_dir: + image_path = osp.join(path, image_dir) + if not osp.exists(image_path): + logger.warning("Config dataset_dir {} is not exits!".format(path)) + + if not osp.isdir(image_path): + logger.warning("Config image_dir {} is not a " + "directory, dataset config is not " + "valid".format(image_path)) + return False + return True + + +def _download(url, path, md5sum=None): + """ + Download from url, save to path. + + url (str): download url + path (str): download to given path + """ + if not osp.exists(path): + os.makedirs(path) + + fname = osp.split(url)[-1] + fullname = osp.join(path, fname) + retry_cnt = 0 + + while not (osp.exists(fullname) and _md5check(fullname, md5sum)): + if retry_cnt < DOWNLOAD_RETRY_LIMIT: + retry_cnt += 1 + else: + raise RuntimeError("Download from {} failed. 
" + "Retry limit reached".format(url)) + + logger.info("Downloading {} from {}".format(fname, url)) + + req = requests.get(url, stream=True) + if req.status_code != 200: + raise RuntimeError("Downloading from {} failed with code " + "{}!".format(url, req.status_code)) + + # For protecting download interupted, download to + # tmp_fullname firstly, move tmp_fullname to fullname + # after download finished + tmp_fullname = fullname + "_tmp" + total_size = req.headers.get('content-length') + with open(tmp_fullname, 'wb') as f: + if total_size: + for chunk in tqdm.tqdm( + req.iter_content(chunk_size=1024), + total=(int(total_size) + 1023) // 1024, + unit='KB'): + f.write(chunk) + else: + for chunk in req.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + + # check md5 after download in Content-MD5 in req.headers + if _md5check_from_req(tmp_fullname, req): + shutil.move(tmp_fullname, fullname) + return fullname + else: + logger.warn( + "Download from url imcomplete, try downloading again...") + os.remove(tmp_fullname) + continue + + +def _md5check_from_req(weights_path, req): + # For weights in bcebos URLs, MD5 value is contained + # in request header as 'content_md5' + content_md5 = req.headers.get('content-md5') + if not content_md5 or _md5check( + weights_path, + binascii.hexlify(base64.b64decode(content_md5.strip('"'))).decode( + )): + return True + else: + return False + + +def _md5check(fullname, md5sum=None): + if md5sum is None: + return True + + logger.debug("File {} md5 checking...".format(fullname)) + md5 = hashlib.md5() + with open(fullname, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5.update(chunk) + calc_md5sum = md5.hexdigest() + + if calc_md5sum != md5sum: + logger.warning("File {} md5 check failed, {}(calc) != " + "{}(base)".format(fullname, calc_md5sum, md5sum)) + return False + return True + + +def _decompress(fname): + """ + Decompress for zip and tar file + """ + logger.info("Decompressing {}...".format(fname)) + + # 
For protecting decompressing interupted, + # decompress to fpath_tmp directory firstly, if decompress + # successed, move decompress files to fpath and delete + # fpath_tmp and remove download compress file. + fpath = osp.split(fname)[0] + fpath_tmp = osp.join(fpath, 'tmp') + if osp.isdir(fpath_tmp): + shutil.rmtree(fpath_tmp) + os.makedirs(fpath_tmp) + + if fname.find('tar') >= 0: + with tarfile.open(fname) as tf: + tf.extractall(path=fpath_tmp) + elif fname.find('zip') >= 0: + with zipfile.ZipFile(fname) as zf: + zf.extractall(path=fpath_tmp) + else: + raise TypeError("Unsupport compress file type {}".format(fname)) + + for f in os.listdir(fpath_tmp): + src_dir = osp.join(fpath_tmp, f) + dst_dir = osp.join(fpath, f) + _move_and_merge_tree(src_dir, dst_dir) + + shutil.rmtree(fpath_tmp) + os.remove(fname) + + +def _move_and_merge_tree(src, dst): + """ + Move src directory to dst, if dst is already exists, + merge src to dst + """ + if not osp.exists(dst): + shutil.move(src, dst) + elif osp.isfile(src): + shutil.move(src, dst) + else: + for fp in os.listdir(src): + src_fp = osp.join(src, fp) + dst_fp = osp.join(dst, fp) + if osp.isdir(src_fp): + if osp.isdir(dst_fp): + _move_and_merge_tree(src_fp, dst_fp) + else: + shutil.move(src_fp, dst_fp) + elif osp.isfile(src_fp) and \ + not osp.isfile(dst_fp): + shutil.move(src_fp, dst_fp) diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/eval_utils.py b/VisualFL/depends/PaddleDetection/ppdet/utils/eval_utils.py new file mode 100755 index 000000000..d6769bbcf --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/eval_utils.py @@ -0,0 +1,292 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import numpy as np
import os
import time

import paddle.fluid as fluid

from .voc_eval import bbox_eval as voc_bbox_eval
from .post_process import mstest_box_post_process, mstest_mask_post_process, box_flip

__all__ = ['parse_fetches', 'eval_run', 'eval_results', 'json_eval_results']

logger = logging.getLogger(__name__)


def parse_fetches(fetches, prog=None, extra_keys=None):
    """
    Parse fetch variable infos from model fetches,
    values for fetch_list and keys for stat

    Returns:
        (keys, values, cls): fetch names, variable names to run, and any
        non-variable entries (e.g. metric objects) found in `fetches`.
    """
    keys, values = [], []
    cls = []
    for k, v in fetches.items():
        if hasattr(v, 'name'):
            keys.append(k)
            #v.persistable = True
            values.append(v.name)
        else:
            # entries without a `.name` are treated as metric/class objects
            cls.append(v)

    if prog is not None and extra_keys is not None:
        # pull additional named variables straight out of the program
        for k in extra_keys:
            try:
                v = fluid.framework._get_var(k, prog)
                keys.append(k)
                values.append(v.name)
            except Exception:
                # variable not present in this program: silently skip
                pass

    return keys, values, cls


def length2lod(length_lod):
    # Convert a per-sequence length list into Paddle offset-based LoD
    # (cumulative sums starting at 0), wrapped in one LoD level.
    offset_lod = [0]
    for i in length_lod:
        offset_lod.append(offset_lod[-1] + i)
    return [offset_lod]


def get_sub_feed(input, place):
    # Build a feed dict of LoDTensors for the mask sub-program from a
    # previously fetched result dict.
    new_dict = {}
    res_feed = {}
    key_name = ['bbox', 'im_info', 'im_id', 'im_shape', 'bbox_flip']
    for k in key_name:
        if k in input.keys():
            new_dict[k] = input[k]
    for k in input.keys():
        if 'image' in k:
            new_dict[k] = input[k]
    for k, v in new_dict.items():
        data_t = fluid.LoDTensor()
        data_t.set(v[0], place)
        if 'bbox' in k:
            # bbox tensors carry sequence lengths; restore them as LoD
            lod = length2lod(v[1][0])
            data_t.set_lod(lod)
        res_feed[k] = data_t
    return res_feed


def clean_res(result, keep_name_list):
    # Keep only the listed keys; NOTE the input dict is cleared in place.
    clean_result = {}
    for k in result.keys():
        if k in keep_name_list:
            clean_result[k] = result[k]
    result.clear()
    return clean_result


def get_masks(result):
    # Convert SOLOv2-style 'segm' output into [label, (rle, score)] pairs.
    # NOTE(review): iterates range(num_ins - 1), skipping the last
    # instance — mirrors segm2out; confirm intended.
    import pycocotools.mask as mask_util
    if result is None:
        return {}
    seg_pred = result['segm'][0].astype(np.uint8)
    cate_label = result['cate_label'][0].astype(np.int)
    cate_score = result['cate_score'][0].astype(np.float)
    num_ins = seg_pred.shape[0]
    masks = []
    for idx in range(num_ins - 1):
        cur_mask = seg_pred[idx, ...]
        rle = mask_util.encode(
            np.array(
                cur_mask[:, :, np.newaxis], order='F'))[0]
        rst = (rle, cate_score[idx])
        masks.append([cate_label[idx], rst])
    return masks


def eval_run(exe,
             compile_program,
             loader,
             keys,
             values,
             cls,
             cfg=None,
             sub_prog=None,
             sub_keys=None,
             sub_values=None,
             resolution=None):
    """
    Run evaluation program, return program outputs.

    Iterates the data loader until exhaustion, running the compiled
    program per batch and applying architecture-specific post-processing
    (multi-scale test, mask encoding, corner/TTFNet/SOLOv2 handling).
    """
    iter_id = 0
    results = []
    if len(cls) != 0:
        # metric objects supplied: fetch their accumulated mAP variables
        # instead of the raw model outputs
        values = []
        for i in range(len(cls)):
            _, accum_map = cls[i].get_map_var()
            cls[i].reset(exe)
            values.append(accum_map)

    images_num = 0
    start_time = time.time()
    has_bbox = 'bbox' in keys

    try:
        loader.start()
        while True:
            outs = exe.run(compile_program,
                           fetch_list=values,
                           return_numpy=False)
            # pair each fetched LoDTensor with its sequence lengths
            res = {
                k: (np.array(v), v.recursive_sequence_lengths())
                for k, v in zip(keys, outs)
            }
            multi_scale_test = getattr(cfg, 'MultiScaleTEST', None)
            mask_multi_scale_test = multi_scale_test and 'Mask' in cfg.architecture

            if multi_scale_test:
                post_res = mstest_box_post_process(res, multi_scale_test,
                                                   cfg.num_classes)
                res.update(post_res)
            if mask_multi_scale_test:
                # run the mask sub-program on the merged box results
                place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
                sub_feed = get_sub_feed(res, place)
                sub_prog_outs = exe.run(sub_prog,
                                        feed=sub_feed,
                                        fetch_list=sub_values,
                                        return_numpy=False)
                sub_prog_res = {
                    k: (np.array(v), v.recursive_sequence_lengths())
                    for k, v in zip(sub_keys, sub_prog_outs)
                }
                post_res = mstest_mask_post_process(sub_prog_res, cfg)
                res.update(post_res)
            if multi_scale_test:
                # drop intermediate per-scale entries
                res = clean_res(
                    res, ['im_info', 'bbox', 'im_id', 'im_shape', 'mask'])
            if 'mask' in res:
                from ppdet.utils.post_process import mask_encode
                res['mask'] = mask_encode(res, resolution)
            post_config = getattr(cfg, 'PostProcess', None)
            if 'Corner' in cfg.architecture and post_config is not None:
                from ppdet.utils.post_process import corner_post_process
                corner_post_process(res, post_config, cfg.num_classes)
            if 'TTFNet' in cfg.architecture:
                # TTFNet emits no LoD; synthesize one covering all boxes
                res['bbox'][1].append([len(res['bbox'][0])])
            if 'segm' in res:
                res['segm'] = get_masks(res)
            results.append(res)
            if iter_id % 100 == 0:
                logger.info('Test iter {}'.format(iter_id))
            iter_id += 1
            if 'bbox' not in res or len(res['bbox'][1]) == 0:
                has_bbox = False
            # with bboxes, count images via the LoD; otherwise count batches
            images_num += len(res['bbox'][1][0]) if has_bbox else 1
    except (StopIteration, fluid.core.EOFException):
        # normal end of the data loader
        loader.reset()
        logger.info('Test finish iter {}'.format(iter_id))

    end_time = time.time()
    fps = images_num / (end_time - start_time)
    if has_bbox:
        logger.info('Total number of images: {}, inference time: {} fps.'.
                    format(images_num, fps))
    else:
        logger.info('Total iteration: {}, inference time: {} batch/s.'.format(
            images_num, fps))

    return results


def eval_results(results,
                 metric,
                 num_classes,
                 resolution=None,
                 is_bbox_normalized=False,
                 output_directory=None,
                 map_type='11point',
                 dataset=None,
                 save_only=False):
    """Evaluation for evaluation program results"""
    box_ap_stats = []
    if metric == 'COCO':
        from ppdet.utils.coco_eval import proposal_eval, bbox_eval, mask_eval, segm_eval
        anno_file = dataset.get_anno()
        with_background = dataset.with_background
        if 'proposal' in results[0]:
            output = 'proposal.json'
            if output_directory:
                output = os.path.join(output_directory, 'proposal.json')
            proposal_eval(results, anno_file, output)
        if 'bbox' in results[0]:
            output = 'bbox.json'
            if output_directory:
                output = os.path.join(output_directory, 'bbox.json')

            box_ap_stats = bbox_eval(
                results,
                anno_file,
                output,
                with_background,
                is_bbox_normalized=is_bbox_normalized,
                save_only=save_only)

        if 'mask' in results[0]:
            output = 'mask.json'
            if output_directory:
                output = os.path.join(output_directory, 'mask.json')
            mask_eval(
                results, anno_file, output, resolution, save_only=save_only)
        if 'segm' in results[0]:
            output = 'segm.json'
            if output_directory:
                output = os.path.join(output_directory, output)
            mask_ap_stats = segm_eval(
                results, anno_file, output, save_only=save_only)
            # when no bbox branch ran, report segm AP instead
            if len(box_ap_stats) == 0:
                box_ap_stats = mask_ap_stats
    else:
        # VOC-style evaluation
        if 'accum_map' in results[-1]:
            res = np.mean(results[-1]['accum_map'][0])
            logger.info('mAP: {:.2f}'.format(res * 100.))
            box_ap_stats.append(res * 100.)
        elif 'bbox' in results[0]:
            box_ap = voc_bbox_eval(
                results,
                num_classes,
                is_bbox_normalized=is_bbox_normalized,
                map_type=map_type)
            box_ap_stats.append(box_ap)
    return box_ap_stats


def json_eval_results(metric, json_directory=None, dataset=None):
    """
    cocoapi eval with already exists proposal.json, bbox.json or mask.json
    """
    assert metric == 'COCO'
    from ppdet.utils.coco_eval import cocoapi_eval
    anno_file = dataset.get_anno()
    json_file_list = ['proposal.json', 'bbox.json', 'mask.json']
    if json_directory:
        assert os.path.exists(
            json_directory), "The json directory:{} does not exist".format(
                json_directory)
        for k, v in enumerate(json_file_list):
            json_file_list[k] = os.path.join(str(json_directory), v)

    coco_eval_style = ['proposal', 'bbox', 'segm']
    for i, v_json in enumerate(json_file_list):
        if os.path.exists(v_json):
            cocoapi_eval(v_json, coco_eval_style[i], anno_file=anno_file)
        else:
            logger.info("{} not exists!".format(v_json))
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import yaml
import numpy as np
from collections import OrderedDict

import logging
logger = logging.getLogger(__name__)

import paddle.fluid as fluid

__all__ = ['dump_infer_config', 'save_infer_model']

# Global dictionary
# Minimum TensorRT subgraph size per supported architecture.
TRT_MIN_SUBGRAPH = {
    'YOLO': 3,
    'SSD': 3,
    'RCNN': 40,
    'RetinaNet': 40,
    'EfficientDet': 40,
    'Face': 3,
    'TTFNet': 3,
    'FCOS': 3,
    'SOLOv2': 60,
}
# Architectures whose Resize uses the short side as target scale.
RESIZE_SCALE_SET = {
    'RCNN',
    'RetinaNet',
    'FCOS',
    'SOLOv2',
}


def parse_reader(reader_cfg, metric, arch):
    """
    Translate a TestReader config into the inference-time preprocessing
    list, plus the with_background flag and category label list.
    """
    preprocess_list = []

    image_shape = reader_cfg['inputs_def'].get('image_shape', [3, None, None])
    # a fully-specified shape (no None entries) lets us pin resize params
    has_shape_def = not None in image_shape

    dataset = reader_cfg['dataset']
    anno_file = dataset.get_anno()
    with_background = dataset.with_background
    use_default_label = dataset.use_default_label

    if metric == 'COCO':
        from ppdet.utils.coco_eval import get_category_info
    elif metric == "VOC":
        from ppdet.utils.voc_eval import get_category_info
    elif metric == "WIDERFACE":
        from ppdet.utils.widerface_eval_utils import get_category_info
    else:
        raise ValueError(
            "metric only supports COCO, VOC, WIDERFACE, but received {}".format(
                metric))
    clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                use_default_label)

    label_list = [str(cat) for cat in catid2name.values()]

    # sample_transforms[0] is assumed to be the decode op and is skipped
    sample_transforms = reader_cfg['sample_transforms']
    for st in sample_transforms[1:]:
        method = st.__class__.__name__
        p = {'type': method.replace('Image', '')}
        params = st.__dict__
        params.pop('_id')
        if p['type'] == 'Resize' and has_shape_def:
            # derive fixed resize parameters from the declared input shape
            params['target_size'] = min(image_shape[
                1:]) if arch in RESIZE_SCALE_SET else image_shape[1]
            params['max_size'] = max(image_shape[
                1:]) if arch in RESIZE_SCALE_SET else 0
            params['image_shape'] = image_shape[1:]
            if 'target_dim' in params:
                params.pop('target_dim')
        if p['type'] == 'ResizeAndPad':
            # exported config expresses ResizeAndPad as a plain Resize
            assert has_shape_def, "missing input shape"
            p['type'] = 'Resize'
            p['target_size'] = params['target_dim']
            p['max_size'] = params['target_dim']
            p['interp'] = params['interp']
            p['image_shape'] = image_shape[1:]
            preprocess_list.append(p)
            continue
        p.update(params)
        preprocess_list.append(p)
    batch_transforms = reader_cfg.get('batch_transforms', None)
    if batch_transforms:
        methods = [bt.__class__.__name__ for bt in batch_transforms]
        for bt in batch_transforms:
            method = bt.__class__.__name__
            if method == 'PadBatch':
                # PadBatch at train time maps to PadStride at inference
                preprocess_list.append({'type': 'PadStride'})
                params = bt.__dict__
                preprocess_list[-1].update({'stride': params['pad_to_stride']})
                break

    return with_background, preprocess_list, label_list


def dump_infer_config(FLAGS, config):
    """
    Write infer_cfg.yml (architecture, preprocessing, labels) next to the
    exported model so standalone inference can reproduce preprocessing.
    """
    arch_state = 0
    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(FLAGS.output_dir, cfg_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    from ppdet.core.config.yaml_helpers import setup_orderdict
    setup_orderdict()
    infer_cfg = OrderedDict({
        'use_python_inference': False,
        'mode': 'fluid',
        'draw_threshold': 0.5,
        'metric': config['metric']
    })
    infer_arch = config['architecture']

    # match the configured architecture against the supported table
    for arch, min_subgraph_size in TRT_MIN_SUBGRAPH.items():
        if arch in infer_arch:
            infer_cfg['arch'] = arch
            infer_cfg['min_subgraph_size'] = min_subgraph_size
            arch_state = 1
            break
    if not arch_state:
        # NOTE(review): exits the whole process with status 0 on failure —
        # callers cannot detect the error from the exit code.
        logger.error(
            'Architecture: {} is not supported for exporting model now'.format(
                infer_arch))
        os._exit(0)

    # support land mark output
    if 'with_lmk' in config and config['with_lmk'] == True:
        infer_cfg['with_lmk'] = True

    if 'Mask' in config['architecture']:
        infer_cfg['mask_resolution'] = config['MaskHead']['resolution']
    infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[
        'label_list'] = parse_reader(config['TestReader'], config['metric'],
                                     infer_cfg['arch'])

    yaml.dump(infer_cfg, open(os.path.join(save_dir, 'infer_cfg.yml'), 'w'))
    logger.info("Export inference config file to {}".format(
        os.path.join(save_dir, 'infer_cfg.yml')))


def prune_feed_vars(feeded_var_names, target_vars, prog):
    """
    Filter out feed variables which are not in program,
    pruned feed variables are only used in post processing
    on model output, which are not used in program, such
    as im_id to identify image order, im_shape to clip bbox
    in image.
    """
    exist_var_names = []
    prog = prog.clone()
    prog = prog._prune(targets=target_vars)
    global_block = prog.global_block()
    for name in feeded_var_names:
        try:
            v = global_block.var(name)
            exist_var_names.append(str(v.name))
        except Exception:
            logger.info('save_inference_model pruned unused feed '
                        'variables {}'.format(name))
            pass
    return exist_var_names


def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
    """
    Save the inference program plus parameters under
    <output_dir>/<config name>, feeding only variables the pruned
    program actually uses.
    """
    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(FLAGS.output_dir, cfg_name)
    feed_var_names = [var.name for var in feed_vars.values()]
    # sort fetches by key for a deterministic output order
    fetch_list = sorted(test_fetches.items(), key=lambda i: i[0])
    target_vars = [var[1] for var in fetch_list]
    feed_var_names = prune_feed_vars(feed_var_names, target_vars, infer_prog)
    logger.info("Export inference model to {}, input: {}, output: "
                "{}...".format(save_dir, feed_var_names,
                               [str(var.name) for var in target_vars]))
    fluid.io.save_inference_model(
        save_dir,
        feeded_var_names=feed_var_names,
        target_vars=target_vars,
        executor=exe,
        main_program=infer_prog,
        params_filename="__params__")
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import sys +import numpy as np +import logging +logger = logging.getLogger(__name__) + +__all__ = ['bbox_area', 'jaccard_overlap', 'DetectionMAP'] + + +def bbox_area(bbox, is_bbox_normalized): + """ + Calculate area of a bounding box + """ + norm = 1. - float(is_bbox_normalized) + width = bbox[2] - bbox[0] + norm + height = bbox[3] - bbox[1] + norm + return width * height + + +def jaccard_overlap(pred, gt, is_bbox_normalized=False): + """ + Calculate jaccard overlap ratio between two bounding box + """ + if pred[0] >= gt[2] or pred[2] <= gt[0] or \ + pred[1] >= gt[3] or pred[3] <= gt[1]: + return 0. + inter_xmin = max(pred[0], gt[0]) + inter_ymin = max(pred[1], gt[1]) + inter_xmax = min(pred[2], gt[2]) + inter_ymax = min(pred[3], gt[3]) + inter_size = bbox_area([inter_xmin, inter_ymin, inter_xmax, inter_ymax], + is_bbox_normalized) + pred_size = bbox_area(pred, is_bbox_normalized) + gt_size = bbox_area(gt, is_bbox_normalized) + overlap = float(inter_size) / (pred_size + gt_size - inter_size) + return overlap + + +class DetectionMAP(object): + """ + Calculate detection mean average precision. + Currently support two types: 11point and integral + + Args: + class_num (int): the class number. 
+ overlap_thresh (float): The threshold of overlap + ratio between prediction bounding box and + ground truth bounding box for deciding + true/false positive. Default 0.5. + map_type (str): calculation method of mean average + precision, currently support '11point' and + 'integral'. Default '11point'. + is_bbox_normalized (bool): whther bounding boxes + is normalized to range[0, 1]. Default False. + evaluate_difficult (bool): whether to evaluate + difficult bounding boxes. Default False. + """ + + def __init__(self, + class_num, + overlap_thresh=0.5, + map_type='11point', + is_bbox_normalized=False, + evaluate_difficult=False): + self.class_num = class_num + self.overlap_thresh = overlap_thresh + assert map_type in ['11point', 'integral'], \ + "map_type currently only support '11point' "\ + "and 'integral'" + self.map_type = map_type + self.is_bbox_normalized = is_bbox_normalized + self.evaluate_difficult = evaluate_difficult + self.reset() + + def update(self, bbox, gt_box, gt_label, difficult=None): + """ + Update metric statics from given prediction and ground + truth infomations. 
+ """ + if difficult is None: + difficult = np.zeros_like(gt_label) + + # record class gt count + for gtl, diff in zip(gt_label, difficult): + if self.evaluate_difficult or int(diff) == 0: + self.class_gt_counts[int(np.array(gtl))] += 1 + + # record class score positive + visited = [False] * len(gt_label) + for b in bbox: + label, score, xmin, ymin, xmax, ymax = b.tolist() + pred = [xmin, ymin, xmax, ymax] + max_idx = -1 + max_overlap = -1.0 + for i, gl in enumerate(gt_label): + if int(gl) == int(label): + overlap = jaccard_overlap(pred, gt_box[i], + self.is_bbox_normalized) + if overlap > max_overlap: + max_overlap = overlap + max_idx = i + + if max_overlap > self.overlap_thresh: + if self.evaluate_difficult or \ + int(np.array(difficult[max_idx])) == 0: + if not visited[max_idx]: + self.class_score_poss[int(label)].append([score, 1.0]) + visited[max_idx] = True + else: + self.class_score_poss[int(label)].append([score, 0.0]) + else: + self.class_score_poss[int(label)].append([score, 0.0]) + + def reset(self): + """ + Reset metric statics + """ + self.class_score_poss = [[] for _ in range(self.class_num)] + self.class_gt_counts = [0] * self.class_num + self.mAP = None + + def accumulate(self): + """ + Accumulate metric results and calculate mAP + """ + mAP = 0. + valid_cnt = 0 + for score_pos, count in zip(self.class_score_poss, + self.class_gt_counts): + if count == 0: continue + if len(score_pos) == 0: + valid_cnt += 1 + continue + + accum_tp_list, accum_fp_list = \ + self._get_tp_fp_accum(score_pos) + precision = [] + recall = [] + for ac_tp, ac_fp in zip(accum_tp_list, accum_fp_list): + precision.append(float(ac_tp) / (ac_tp + ac_fp)) + recall.append(float(ac_tp) / count) + + if self.map_type == '11point': + max_precisions = [0.] 
* 11 + start_idx = len(precision) - 1 + for j in range(10, -1, -1): + for i in range(start_idx, -1, -1): + if recall[i] < float(j) / 10.: + start_idx = i + if j > 0: + max_precisions[j - 1] = max_precisions[j] + break + else: + if max_precisions[j] < precision[i]: + max_precisions[j] = precision[i] + mAP += sum(max_precisions) / 11. + valid_cnt += 1 + elif self.map_type == 'integral': + import math + ap = 0. + prev_recall = 0. + for i in range(len(precision)): + recall_gap = math.fabs(recall[i] - prev_recall) + if recall_gap > 1e-6: + ap += precision[i] * recall_gap + prev_recall = recall[i] + mAP += ap + valid_cnt += 1 + else: + logger.error("Unspported mAP type {}".format(self.map_type)) + sys.exit(1) + + self.mAP = mAP / float(valid_cnt) if valid_cnt > 0 else mAP + + def get_map(self): + """ + Get mAP result + """ + if self.mAP is None: + logger.error("mAP is not calculated.") + return self.mAP + + def _get_tp_fp_accum(self, score_pos_list): + """ + Calculate accumulating true/false positive results from + [score, pos] records + """ + sorted_list = sorted(score_pos_list, key=lambda s: s[0], reverse=True) + accum_tp = 0 + accum_fp = 0 + accum_tp_list = [] + accum_fp_list = [] + for (score, pos) in sorted_list: + accum_tp += int(pos) + accum_tp_list.append(accum_tp) + accum_fp += 1 - int(pos) + accum_fp_list.append(accum_fp) + return accum_tp_list, accum_fp_list diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/oid_eval.py b/VisualFL/depends/PaddleDetection/ppdet/utils/oid_eval.py new file mode 100755 index 000000000..0a5a0c534 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/oid_eval.py @@ -0,0 +1,543 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import sys +import numpy as np + +from .coco_eval import bbox2out + +import logging +logger = logging.getLogger(__name__) + +__all__ = ['bbox2out', 'get_category_info'] + + +def get_category_info(anno_file=None, + with_background=True, + use_default_label=False): + clsid2catid = {k: k for k in range(1, 501)} + + catid2name = { + 0: "background", + 1: "Infant bed", + 2: "Rose", + 3: "Flag", + 4: "Flashlight", + 5: "Sea turtle", + 6: "Camera", + 7: "Animal", + 8: "Glove", + 9: "Crocodile", + 10: "Cattle", + 11: "House", + 12: "Guacamole", + 13: "Penguin", + 14: "Vehicle registration plate", + 15: "Bench", + 16: "Ladybug", + 17: "Human nose", + 18: "Watermelon", + 19: "Flute", + 20: "Butterfly", + 21: "Washing machine", + 22: "Raccoon", + 23: "Segway", + 24: "Taco", + 25: "Jellyfish", + 26: "Cake", + 27: "Pen", + 28: "Cannon", + 29: "Bread", + 30: "Tree", + 31: "Shellfish", + 32: "Bed", + 33: "Hamster", + 34: "Hat", + 35: "Toaster", + 36: "Sombrero", + 37: "Tiara", + 38: "Bowl", + 39: "Dragonfly", + 40: "Moths and butterflies", + 41: "Antelope", + 42: "Vegetable", + 43: "Torch", + 44: "Building", + 45: "Power plugs and sockets", + 46: "Blender", + 47: "Billiard table", + 48: "Cutting board", + 49: "Bronze sculpture", + 50: "Turtle", + 51: "Broccoli", + 52: "Tiger", + 53: "Mirror", + 54: "Bear", + 55: "Zucchini", + 56: "Dress", + 57: "Volleyball", + 58: "Guitar", + 59: "Reptile", + 60: "Golf 
cart", + 61: "Tart", + 62: "Fedora", + 63: "Carnivore", + 64: "Car", + 65: "Lighthouse", + 66: "Coffeemaker", + 67: "Food processor", + 68: "Truck", + 69: "Bookcase", + 70: "Surfboard", + 71: "Footwear", + 72: "Bench", + 73: "Necklace", + 74: "Flower", + 75: "Radish", + 76: "Marine mammal", + 77: "Frying pan", + 78: "Tap", + 79: "Peach", + 80: "Knife", + 81: "Handbag", + 82: "Laptop", + 83: "Tent", + 84: "Ambulance", + 85: "Christmas tree", + 86: "Eagle", + 87: "Limousine", + 88: "Kitchen & dining room table", + 89: "Polar bear", + 90: "Tower", + 91: "Football", + 92: "Willow", + 93: "Human head", + 94: "Stop sign", + 95: "Banana", + 96: "Mixer", + 97: "Binoculars", + 98: "Dessert", + 99: "Bee", + 100: "Chair", + 101: "Wood-burning stove", + 102: "Flowerpot", + 103: "Beaker", + 104: "Oyster", + 105: "Woodpecker", + 106: "Harp", + 107: "Bathtub", + 108: "Wall clock", + 109: "Sports uniform", + 110: "Rhinoceros", + 111: "Beehive", + 112: "Cupboard", + 113: "Chicken", + 114: "Man", + 115: "Blue jay", + 116: "Cucumber", + 117: "Balloon", + 118: "Kite", + 119: "Fireplace", + 120: "Lantern", + 121: "Missile", + 122: "Book", + 123: "Spoon", + 124: "Grapefruit", + 125: "Squirrel", + 126: "Orange", + 127: "Coat", + 128: "Punching bag", + 129: "Zebra", + 130: "Billboard", + 131: "Bicycle", + 132: "Door handle", + 133: "Mechanical fan", + 134: "Ring binder", + 135: "Table", + 136: "Parrot", + 137: "Sock", + 138: "Vase", + 139: "Weapon", + 140: "Shotgun", + 141: "Glasses", + 142: "Seahorse", + 143: "Belt", + 144: "Watercraft", + 145: "Window", + 146: "Giraffe", + 147: "Lion", + 148: "Tire", + 149: "Vehicle", + 150: "Canoe", + 151: "Tie", + 152: "Shelf", + 153: "Picture frame", + 154: "Printer", + 155: "Human leg", + 156: "Boat", + 157: "Slow cooker", + 158: "Croissant", + 159: "Candle", + 160: "Pancake", + 161: "Pillow", + 162: "Coin", + 163: "Stretcher", + 164: "Sandal", + 165: "Woman", + 166: "Stairs", + 167: "Harpsichord", + 168: "Stool", + 169: "Bus", + 170: "Suitcase", + 
171: "Human mouth", + 172: "Juice", + 173: "Skull", + 174: "Door", + 175: "Violin", + 176: "Chopsticks", + 177: "Digital clock", + 178: "Sunflower", + 179: "Leopard", + 180: "Bell pepper", + 181: "Harbor seal", + 182: "Snake", + 183: "Sewing machine", + 184: "Goose", + 185: "Helicopter", + 186: "Seat belt", + 187: "Coffee cup", + 188: "Microwave oven", + 189: "Hot dog", + 190: "Countertop", + 191: "Serving tray", + 192: "Dog bed", + 193: "Beer", + 194: "Sunglasses", + 195: "Golf ball", + 196: "Waffle", + 197: "Palm tree", + 198: "Trumpet", + 199: "Ruler", + 200: "Helmet", + 201: "Ladder", + 202: "Office building", + 203: "Tablet computer", + 204: "Toilet paper", + 205: "Pomegranate", + 206: "Skirt", + 207: "Gas stove", + 208: "Cookie", + 209: "Cart", + 210: "Raven", + 211: "Egg", + 212: "Burrito", + 213: "Goat", + 214: "Kitchen knife", + 215: "Skateboard", + 216: "Salt and pepper shakers", + 217: "Lynx", + 218: "Boot", + 219: "Platter", + 220: "Ski", + 221: "Swimwear", + 222: "Swimming pool", + 223: "Drinking straw", + 224: "Wrench", + 225: "Drum", + 226: "Ant", + 227: "Human ear", + 228: "Headphones", + 229: "Fountain", + 230: "Bird", + 231: "Jeans", + 232: "Television", + 233: "Crab", + 234: "Microphone", + 235: "Home appliance", + 236: "Snowplow", + 237: "Beetle", + 238: "Artichoke", + 239: "Jet ski", + 240: "Stationary bicycle", + 241: "Human hair", + 242: "Brown bear", + 243: "Starfish", + 244: "Fork", + 245: "Lobster", + 246: "Corded phone", + 247: "Drink", + 248: "Saucer", + 249: "Carrot", + 250: "Insect", + 251: "Clock", + 252: "Castle", + 253: "Tennis racket", + 254: "Ceiling fan", + 255: "Asparagus", + 256: "Jaguar", + 257: "Musical instrument", + 258: "Train", + 259: "Cat", + 260: "Rifle", + 261: "Dumbbell", + 262: "Mobile phone", + 263: "Taxi", + 264: "Shower", + 265: "Pitcher", + 266: "Lemon", + 267: "Invertebrate", + 268: "Turkey", + 269: "High heels", + 270: "Bust", + 271: "Elephant", + 272: "Scarf", + 273: "Barrel", + 274: "Trombone", + 275: 
"Pumpkin", + 276: "Box", + 277: "Tomato", + 278: "Frog", + 279: "Bidet", + 280: "Human face", + 281: "Houseplant", + 282: "Van", + 283: "Shark", + 284: "Ice cream", + 285: "Swim cap", + 286: "Falcon", + 287: "Ostrich", + 288: "Handgun", + 289: "Whiteboard", + 290: "Lizard", + 291: "Pasta", + 292: "Snowmobile", + 293: "Light bulb", + 294: "Window blind", + 295: "Muffin", + 296: "Pretzel", + 297: "Computer monitor", + 298: "Horn", + 299: "Furniture", + 300: "Sandwich", + 301: "Fox", + 302: "Convenience store", + 303: "Fish", + 304: "Fruit", + 305: "Earrings", + 306: "Curtain", + 307: "Grape", + 308: "Sofa bed", + 309: "Horse", + 310: "Luggage and bags", + 311: "Desk", + 312: "Crutch", + 313: "Bicycle helmet", + 314: "Tick", + 315: "Airplane", + 316: "Canary", + 317: "Spatula", + 318: "Watch", + 319: "Lily", + 320: "Kitchen appliance", + 321: "Filing cabinet", + 322: "Aircraft", + 323: "Cake stand", + 324: "Candy", + 325: "Sink", + 326: "Mouse", + 327: "Wine", + 328: "Wheelchair", + 329: "Goldfish", + 330: "Refrigerator", + 331: "French fries", + 332: "Drawer", + 333: "Treadmill", + 334: "Picnic basket", + 335: "Dice", + 336: "Cabbage", + 337: "Football helmet", + 338: "Pig", + 339: "Person", + 340: "Shorts", + 341: "Gondola", + 342: "Honeycomb", + 343: "Doughnut", + 344: "Chest of drawers", + 345: "Land vehicle", + 346: "Bat", + 347: "Monkey", + 348: "Dagger", + 349: "Tableware", + 350: "Human foot", + 351: "Mug", + 352: "Alarm clock", + 353: "Pressure cooker", + 354: "Human hand", + 355: "Tortoise", + 356: "Baseball glove", + 357: "Sword", + 358: "Pear", + 359: "Miniskirt", + 360: "Traffic sign", + 361: "Girl", + 362: "Roller skates", + 363: "Dinosaur", + 364: "Porch", + 365: "Human beard", + 366: "Submarine sandwich", + 367: "Screwdriver", + 368: "Strawberry", + 369: "Wine glass", + 370: "Seafood", + 371: "Racket", + 372: "Wheel", + 373: "Sea lion", + 374: "Toy", + 375: "Tea", + 376: "Tennis ball", + 377: "Waste container", + 378: "Mule", + 379: "Cricket ball", + 
380: "Pineapple", + 381: "Coconut", + 382: "Doll", + 383: "Coffee table", + 384: "Snowman", + 385: "Lavender", + 386: "Shrimp", + 387: "Maple", + 388: "Cowboy hat", + 389: "Goggles", + 390: "Rugby ball", + 391: "Caterpillar", + 392: "Poster", + 393: "Rocket", + 394: "Organ", + 395: "Saxophone", + 396: "Traffic light", + 397: "Cocktail", + 398: "Plastic bag", + 399: "Squash", + 400: "Mushroom", + 401: "Hamburger", + 402: "Light switch", + 403: "Parachute", + 404: "Teddy bear", + 405: "Winter melon", + 406: "Deer", + 407: "Musical keyboard", + 408: "Plumbing fixture", + 409: "Scoreboard", + 410: "Baseball bat", + 411: "Envelope", + 412: "Adhesive tape", + 413: "Briefcase", + 414: "Paddle", + 415: "Bow and arrow", + 416: "Telephone", + 417: "Sheep", + 418: "Jacket", + 419: "Boy", + 420: "Pizza", + 421: "Otter", + 422: "Office supplies", + 423: "Couch", + 424: "Cello", + 425: "Bull", + 426: "Camel", + 427: "Ball", + 428: "Duck", + 429: "Whale", + 430: "Shirt", + 431: "Tank", + 432: "Motorcycle", + 433: "Accordion", + 434: "Owl", + 435: "Porcupine", + 436: "Sun hat", + 437: "Nail", + 438: "Scissors", + 439: "Swan", + 440: "Lamp", + 441: "Crown", + 442: "Piano", + 443: "Sculpture", + 444: "Cheetah", + 445: "Oboe", + 446: "Tin can", + 447: "Mango", + 448: "Tripod", + 449: "Oven", + 450: "Mouse", + 451: "Barge", + 452: "Coffee", + 453: "Snowboard", + 454: "Common fig", + 455: "Salad", + 456: "Marine invertebrates", + 457: "Umbrella", + 458: "Kangaroo", + 459: "Human arm", + 460: "Measuring cup", + 461: "Snail", + 462: "Loveseat", + 463: "Suit", + 464: "Teapot", + 465: "Bottle", + 466: "Alpaca", + 467: "Kettle", + 468: "Trousers", + 469: "Popcorn", + 470: "Centipede", + 471: "Spider", + 472: "Sparrow", + 473: "Plate", + 474: "Bagel", + 475: "Personal care", + 476: "Apple", + 477: "Brassiere", + 478: "Bathroom cabinet", + 479: "studio couch", + 480: "Computer keyboard", + 481: "Table tennis racket", + 482: "Sushi", + 483: "Cabinetry", + 484: "Street light", + 485: "Towel", + 
486: "Nightstand", + 487: "Rabbit", + 488: "Dolphin", + 489: "Dog", + 490: "Jug", + 491: "Wok", + 492: "Fire hydrant", + 493: "Human eye", + 494: "Skyscraper", + 495: "Backpack", + 496: "Potato", + 497: "Paper towel", + 498: "Lifejacket", + 499: "Bicycle wheel", + 500: "Toilet", + } + + if not with_background: + clsid2catid = {k - 1: v for k, v in clsid2catid.items()} + return clsid2catid, catid2name diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/post_process.py b/VisualFL/depends/PaddleDetection/ppdet/utils/post_process.py new file mode 100755 index 000000000..cf2519983 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/post_process.py @@ -0,0 +1,327 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import numpy as np +import cv2 +import paddle.fluid as fluid + +__all__ = ['nms'] + +logger = logging.getLogger(__name__) + + +def box_flip(boxes, im_shape): + im_width = im_shape[0][1] + flipped_boxes = boxes.copy() + + flipped_boxes[:, 0::4] = im_width - boxes[:, 2::4] - 1 + flipped_boxes[:, 2::4] = im_width - boxes[:, 0::4] - 1 + return flipped_boxes + + +def nms(dets, thresh): + """Apply classic DPM-style greedy NMS.""" + if dets.shape[0] == 0: + return dets[[], :] + scores = dets[:, 0] + x1 = dets[:, 1] + y1 = dets[:, 2] + x2 = dets[:, 3] + y2 = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + ndets = dets.shape[0] + suppressed = np.zeros((ndets), dtype=np.int) + + # nominal indices + # _i, _j + # sorted indices + # i, j + # temp variables for box i's (the box currently under consideration) + # ix1, iy1, ix2, iy2, iarea + + # variables for computing overlap with box j (lower scoring box) + # xx1, yy1, xx2, yy2 + # w, h + # inter, ovr + + for _i in range(ndets): + i = order[_i] + if suppressed[i] == 1: + continue + ix1 = x1[i] + iy1 = y1[i] + ix2 = x2[i] + iy2 = y2[i] + iarea = areas[i] + for _j in range(_i + 1, ndets): + j = order[_j] + if suppressed[j] == 1: + continue + xx1 = max(ix1, x1[j]) + yy1 = max(iy1, y1[j]) + xx2 = min(ix2, x2[j]) + yy2 = min(iy2, y2[j]) + w = max(0.0, xx2 - xx1 + 1) + h = max(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (iarea + areas[j] - inter) + if ovr >= thresh: + suppressed[j] = 1 + keep = np.where(suppressed == 0)[0] + dets = dets[keep, :] + return dets + + +def soft_nms(dets, sigma, thres): + dets_final = [] + while len(dets) > 0: + maxpos = np.argmax(dets[:, 0]) + dets_final.append(dets[maxpos].copy()) + ts, tx1, ty1, tx2, ty2 = dets[maxpos] + scores = dets[:, 0] + # force remove bbox at maxpos + scores[maxpos] = -1 + x1 = dets[:, 1] + y1 = 
dets[:, 2] + x2 = dets[:, 3] + y2 = dets[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + xx1 = np.maximum(tx1, x1) + yy1 = np.maximum(ty1, y1) + xx2 = np.minimum(tx2, x2) + yy2 = np.minimum(ty2, y2) + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas + areas[maxpos] - inter) + weight = np.exp(-(ovr * ovr) / sigma) + scores = scores * weight + idx_keep = np.where(scores >= thres) + dets[:, 0] = scores + dets = dets[idx_keep] + dets_final = np.array(dets_final).reshape(-1, 5) + return dets_final + + +def bbox_area(box): + w = box[2] - box[0] + 1 + h = box[3] - box[1] + 1 + return w * h + + +def bbox_overlaps(x, y): + N = x.shape[0] + K = y.shape[0] + overlaps = np.zeros((N, K), dtype=np.float32) + for k in range(K): + y_area = bbox_area(y[k]) + for n in range(N): + iw = min(x[n, 2], y[k, 2]) - max(x[n, 0], y[k, 0]) + 1 + if iw > 0: + ih = min(x[n, 3], y[k, 3]) - max(x[n, 1], y[k, 1]) + 1 + if ih > 0: + x_area = bbox_area(x[n]) + ua = x_area + y_area - iw * ih + overlaps[n, k] = iw * ih / ua + return overlaps + + +def box_voting(nms_dets, dets, vote_thresh): + top_dets = nms_dets.copy() + top_boxes = nms_dets[:, 1:] + all_boxes = dets[:, 1:] + all_scores = dets[:, 0] + top_to_all_overlaps = bbox_overlaps(top_boxes, all_boxes) + for k in range(nms_dets.shape[0]): + inds_to_vote = np.where(top_to_all_overlaps[k] >= vote_thresh)[0] + boxes_to_vote = all_boxes[inds_to_vote, :] + ws = all_scores[inds_to_vote] + top_dets[k, 1:] = np.average(boxes_to_vote, axis=0, weights=ws) + + return top_dets + + +def get_nms_result(boxes, + scores, + config, + num_classes, + background_label=0, + labels=None): + has_labels = labels is not None + cls_boxes = [[] for _ in range(num_classes)] + start_idx = 1 if background_label == 0 else 0 + for j in range(start_idx, num_classes): + inds = np.where(labels == j)[0] if has_labels else np.where( + scores[:, j] > config['score_thresh'])[0] + scores_j = scores[inds] if has_labels else 
scores[inds, j] + boxes_j = boxes[inds, :] if has_labels else boxes[inds, j * 4:(j + 1) * + 4] + dets_j = np.hstack((scores_j[:, np.newaxis], boxes_j)).astype( + np.float32, copy=False) + if config.get('use_soft_nms', False): + nms_dets = soft_nms(dets_j, config['sigma'], config['nms_thresh']) + else: + nms_dets = nms(dets_j, config['nms_thresh']) + if config.get('enable_voting', False): + nms_dets = box_voting(nms_dets, dets_j, config['vote_thresh']) + #add labels + label = np.array([j for _ in range(len(nms_dets))]) + nms_dets = np.hstack((label[:, np.newaxis], nms_dets)).astype( + np.float32, copy=False) + cls_boxes[j] = nms_dets + # Limit to max_per_image detections **over all classes** + image_scores = np.hstack( + [cls_boxes[j][:, 1] for j in range(start_idx, num_classes)]) + if len(image_scores) > config['detections_per_im']: + image_thresh = np.sort(image_scores)[-config['detections_per_im']] + for j in range(start_idx, num_classes): + keep = np.where(cls_boxes[j][:, 1] >= image_thresh)[0] + cls_boxes[j] = cls_boxes[j][keep, :] + + im_results = np.vstack( + [cls_boxes[j] for j in range(start_idx, num_classes)]) + return im_results + + +def mstest_box_post_process(result, config, num_classes): + """ + Multi-scale Test + Only available for batch_size=1 now. 
+ """ + post_bbox = {} + use_flip = False + ms_boxes = [] + ms_scores = [] + im_shape = result['im_shape'][0] + for k in result.keys(): + if 'bbox' in k: + boxes = result[k][0] + boxes = np.reshape(boxes, (-1, 4 * num_classes)) + scores = result['score' + k[4:]][0] + if 'flip' in k: + boxes = box_flip(boxes, im_shape) + use_flip = True + ms_boxes.append(boxes) + ms_scores.append(scores) + + ms_boxes = np.concatenate(ms_boxes) + ms_scores = np.concatenate(ms_scores) + bbox_pred = get_nms_result(ms_boxes, ms_scores, config, num_classes) + post_bbox.update({'bbox': (bbox_pred, [[len(bbox_pred)]])}) + if use_flip: + bbox = bbox_pred[:, 2:] + bbox_flip = np.append( + bbox_pred[:, :2], box_flip(bbox, im_shape), axis=1) + post_bbox.update({'bbox_flip': (bbox_flip, [[len(bbox_flip)]])}) + return post_bbox + + +def mstest_mask_post_process(result, cfg): + mask_list = [] + im_shape = result['im_shape'][0] + M = cfg.FPNRoIAlign['mask_resolution'] + for k in result.keys(): + if 'mask' in k: + masks = result[k][0] + if len(masks.shape) != 4: + masks = np.zeros((0, M, M)) + mask_list.append(masks) + continue + if 'flip' in k: + masks = masks[:, :, :, ::-1] + mask_list.append(masks) + + mask_pred = np.mean(mask_list, axis=0) + return {'mask': (mask_pred, [[len(mask_pred)]])} + + +def mask_encode(results, resolution, thresh_binarize=0.5): + import pycocotools.mask as mask_util + from ppdet.utils.coco_eval import expand_boxes + scale = (resolution + 2.0) / resolution + bboxes = results['bbox'][0] + masks = results['mask'][0] + lengths = results['mask'][1][0] + im_shapes = results['im_shape'][0] + segms = [] + if bboxes.shape == (1, 1) or bboxes is None: + return segms + if len(bboxes.tolist()) == 0: + return segms + + s = 0 + # for each sample + for i in range(len(lengths)): + num = lengths[i] + im_shape = im_shapes[i] + + bbox = bboxes[s:s + num][:, 2:] + clsid_scores = bboxes[s:s + num][:, 0:2] + mask = masks[s:s + num] + s += num + + im_h = int(im_shape[0]) + im_w = 
int(im_shape[1]) + expand_bbox = expand_boxes(bbox, scale) + expand_bbox = expand_bbox.astype(np.int32) + padded_mask = np.zeros( + (resolution + 2, resolution + 2), dtype=np.float32) + + for j in range(num): + xmin, ymin, xmax, ymax = expand_bbox[j].tolist() + clsid, score = clsid_scores[j].tolist() + clsid = int(clsid) + padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :] + + w = xmax - xmin + 1 + h = ymax - ymin + 1 + w = np.maximum(w, 1) + h = np.maximum(h, 1) + resized_mask = cv2.resize(padded_mask, (w, h)) + resized_mask = np.array( + resized_mask > thresh_binarize, dtype=np.uint8) + im_mask = np.zeros((im_h, im_w), dtype=np.uint8) + + x0 = min(max(xmin, 0), im_w) + x1 = min(max(xmax + 1, 0), im_w) + y0 = min(max(ymin, 0), im_h) + y1 = min(max(ymax + 1, 0), im_h) + + im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), ( + x0 - xmin):(x1 - xmin)] + segm = mask_util.encode( + np.array( + im_mask[:, :, np.newaxis], order='F'))[0] + segms.append(segm) + return segms + + +def corner_post_process(results, config, num_classes): + detections = results['bbox'][0] + keep_inds = (detections[:, 1] > -1) + detections = detections[keep_inds] + labels = detections[:, 0] + scores = detections[:, 1] + boxes = detections[:, 2:6] + cls_boxes = get_nms_result( + boxes, scores, config, num_classes, background_label=-1, labels=labels) + results.update({'bbox': (cls_boxes, [[len(cls_boxes)]])}) diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/stats.py b/VisualFL/depends/PaddleDetection/ppdet/utils/stats.py new file mode 100755 index 000000000..4d7e36bab --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/stats.py @@ -0,0 +1,65 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import numpy as np
import datetime

__all__ = ['TrainingStats', 'Time']


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size):
        # The deque drops the oldest value automatically once full.
        self.deque = collections.deque(maxlen=window_size)

    def add_value(self, value):
        """Append one observation to the window."""
        self.deque.append(value)

    def get_median_value(self):
        """Median of the values currently in the window."""
        return np.median(self.deque)


def Time():
    """Current local time as 'YYYY-MM-DD HH:MM:SS.ffffff'."""
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')


class TrainingStats(object):
    """Maintain windowed medians for a fixed set of loss/metric keys."""

    def __init__(self, window_size, stats_keys):
        self.smoothed_losses_and_metrics = {
            name: SmoothedValue(window_size)
            for name in stats_keys
        }

    def update(self, stats):
        """Push the latest value of every tracked key from ``stats``."""
        for name, smoother in self.smoothed_losses_and_metrics.items():
            smoother.add_value(stats[name])

    def get(self, extras=None):
        """Return an OrderedDict of ``extras`` entries followed by the
        tracked keys formatted as '%.6f' median strings."""
        summary = collections.OrderedDict()
        if extras:
            for name, value in extras.items():
                summary[name] = value
        for name, smoother in self.smoothed_losses_and_metrics.items():
            summary[name] = format(smoother.get_median_value(), '.6f')

        return summary

    def log(self, extras=None):
        """Render get() as a comma-separated "key: value" string."""
        summary = self.get(extras)
        return ', '.join(
            str(dict({name: value})).strip('{}')
            for name, value in summary.items())

# --- new file: VisualFL/depends/PaddleDetection/ppdet/utils/visualizer.py ---
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np +from PIL import Image, ImageDraw +from scipy import ndimage +import cv2 + +from .colormap import colormap + +__all__ = ['visualize_results'] + + +def visualize_results(image, + im_id, + catid2name, + threshold=0.5, + bbox_results=None, + mask_results=None, + segm_results=None, + lmk_results=None): + """ + Visualize bbox and mask results + """ + if mask_results: + image = draw_mask(image, im_id, mask_results, threshold) + if bbox_results: + image = draw_bbox(image, im_id, catid2name, bbox_results, threshold) + if lmk_results: + image = draw_lmk(image, im_id, lmk_results, threshold) + if segm_results: + image = draw_segm(image, im_id, catid2name, segm_results, threshold) + return image + + +def draw_mask(image, im_id, segms, threshold, alpha=0.7): + """ + Draw mask on image + """ + mask_color_id = 0 + w_ratio = .4 + color_list = colormap(rgb=True) + img_array = np.array(image).astype('float32') + for dt in np.array(segms): + if im_id != dt['image_id']: + continue + segm, score = dt['segmentation'], dt['score'] + if score < threshold: + continue + import pycocotools.mask as mask_util + mask = mask_util.decode(segm) * 255 + color_mask = color_list[mask_color_id % len(color_list), 0:3] + mask_color_id += 1 + for c in range(3): + color_mask[c] = 
color_mask[c] * (1 - w_ratio) + w_ratio * 255 + idx = np.nonzero(mask) + img_array[idx[0], idx[1], :] *= 1.0 - alpha + img_array[idx[0], idx[1], :] += alpha * color_mask + return Image.fromarray(img_array.astype('uint8')) + + +def draw_segm(image, + im_id, + catid2name, + segms, + threshold, + alpha=0.7, + draw_box=True): + """ + Draw segmentation on image + """ + mask_color_id = 0 + w_ratio = .4 + color_list = colormap(rgb=True) + img_array = np.array(image).astype('float32') + for dt in np.array(segms): + if im_id != dt['image_id']: + continue + segm, score, catid = dt['segmentation'], dt['score'], dt['category_id'] + if score < threshold: + continue + import pycocotools.mask as mask_util + mask = mask_util.decode(segm) * 255 + color_mask = color_list[mask_color_id % len(color_list), 0:3] + mask_color_id += 1 + for c in range(3): + color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255 + idx = np.nonzero(mask) + img_array[idx[0], idx[1], :] *= 1.0 - alpha + img_array[idx[0], idx[1], :] += alpha * color_mask + + if not draw_box: + center_y, center_x = ndimage.measurements.center_of_mass(mask) + label_text = "{}".format(catid2name[catid]) + vis_pos = (max(int(center_x) - 10, 0), int(center_y)) + cv2.putText(img_array, label_text, vis_pos, + cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255)) + else: + mask = mask_util.decode(segm) * 255 + sum_x = np.sum(mask, axis=0) + x = np.where(sum_x > 0.5)[0] + sum_y = np.sum(mask, axis=1) + y = np.where(sum_y > 0.5)[0] + x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1] + cv2.rectangle(img_array, (x0, y0), (x1, y1), + tuple(color_mask.astype('int32').tolist()), 1) + bbox_text = '%s %.2f' % (catid2name[catid], score) + t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0] + cv2.rectangle(img_array, (x0, y0), (x0 + t_size[0], + y0 - t_size[1] - 3), + tuple(color_mask.astype('int32').tolist()), -1) + cv2.putText( + img_array, + bbox_text, (x0, y0 - 2), + cv2.FONT_HERSHEY_SIMPLEX, + 0.3, (0, 0, 0), + 1, + lineType=cv2.LINE_AA) + 
+ return Image.fromarray(img_array.astype('uint8')) + + +def draw_bbox(image, im_id, catid2name, bboxes, threshold): + """ + Draw bbox on image + """ + draw = ImageDraw.Draw(image) + + catid2color = {} + color_list = colormap(rgb=True)[:40] + for dt in np.array(bboxes): + if im_id != dt['image_id']: + continue + catid, bbox, score = dt['category_id'], dt['bbox'], dt['score'] + if score < threshold: + continue + + xmin, ymin, w, h = bbox + xmax = xmin + w + ymax = ymin + h + + if catid not in catid2color: + idx = np.random.randint(len(color_list)) + catid2color[catid] = color_list[idx] + color = tuple(catid2color[catid]) + + # draw bbox + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=2, + fill=color) + + # draw label + text = "{} {:.2f}".format(catid2name[catid], score) + tw, th = draw.textsize(text) + draw.rectangle( + [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color) + draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) + + return image + + +def draw_lmk(image, im_id, lmk_results, threshold): + draw = ImageDraw.Draw(image) + catid2color = {} + color_list = colormap(rgb=True)[:40] + for dt in np.array(lmk_results): + lmk_decode, score = dt['landmark'], dt['score'] + if im_id != dt['image_id']: + continue + if score < threshold: + continue + for j in range(5): + x1 = int(round(lmk_decode[2 * j])) + y1 = int(round(lmk_decode[2 * j + 1])) + draw.ellipse( + (x1, y1, x1 + 5, y1 + 5), fill='green', outline='green') + return image diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/voc_eval.py b/VisualFL/depends/PaddleDetection/ppdet/utils/voc_eval.py new file mode 100755 index 000000000..4ffd91260 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/voc_eval.py @@ -0,0 +1,184 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import os +import sys +import numpy as np + +from ..data.source.voc import pascalvoc_label +from .map_utils import DetectionMAP +from .coco_eval import bbox2out + +import logging +logger = logging.getLogger(__name__) + +__all__ = ['bbox_eval', 'bbox2out', 'get_category_info'] + + +def bbox_eval(results, + class_num, + overlap_thresh=0.5, + map_type='11point', + is_bbox_normalized=False, + evaluate_difficult=False): + """ + Bounding box evaluation for VOC dataset + + Args: + results (list): prediction bounding box results. + class_num (int): evaluation class number. + overlap_thresh (float): the postive threshold of + bbox overlap + map_type (string): method for mAP calcualtion, + can only be '11point' or 'integral' + is_bbox_normalized (bool): whether bbox is normalized + to range [0, 1]. + evaluate_difficult (bool): whether to evaluate + difficult gt bbox. 
+ """ + assert 'bbox' in results[0] + logger.info("Start evaluate...") + + detection_map = DetectionMAP( + class_num=class_num, + overlap_thresh=overlap_thresh, + map_type=map_type, + is_bbox_normalized=is_bbox_normalized, + evaluate_difficult=evaluate_difficult) + + for t in results: + bboxes = t['bbox'][0] + bbox_lengths = t['bbox'][1][0] + + if bboxes.shape == (1, 1) or bboxes is None: + continue + gt_boxes = t['gt_bbox'][0] + gt_labels = t['gt_class'][0] + difficults = t['is_difficult'][0] if not evaluate_difficult \ + else None + + if len(t['gt_bbox'][1]) == 0: + # gt_bbox, gt_class, difficult read as zero padded Tensor + bbox_idx = 0 + for i in range(len(gt_boxes)): + gt_box = gt_boxes[i] + gt_label = gt_labels[i] + difficult = None if difficults is None \ + else difficults[i] + bbox_num = bbox_lengths[i] + bbox = bboxes[bbox_idx:bbox_idx + bbox_num] + gt_box, gt_label, difficult = prune_zero_padding( + gt_box, gt_label, difficult) + detection_map.update(bbox, gt_box, gt_label, difficult) + bbox_idx += bbox_num + else: + # gt_box, gt_label, difficult read as LoDTensor + gt_box_lengths = t['gt_bbox'][1][0] + bbox_idx = 0 + gt_box_idx = 0 + for i in range(len(bbox_lengths)): + bbox_num = bbox_lengths[i] + gt_box_num = gt_box_lengths[i] + bbox = bboxes[bbox_idx:bbox_idx + bbox_num] + gt_box = gt_boxes[gt_box_idx:gt_box_idx + gt_box_num] + gt_label = gt_labels[gt_box_idx:gt_box_idx + gt_box_num] + difficult = None if difficults is None else \ + difficults[gt_box_idx: gt_box_idx + gt_box_num] + detection_map.update(bbox, gt_box, gt_label, difficult) + bbox_idx += bbox_num + gt_box_idx += gt_box_num + + logger.info("Accumulating evaluatation results...") + detection_map.accumulate() + map_stat = 100. 
* detection_map.get_map() + logger.info("mAP({:.2f}, {}) = {:.2f}%".format(overlap_thresh, map_type, + map_stat)) + return map_stat + + +def prune_zero_padding(gt_box, gt_label, difficult=None): + valid_cnt = 0 + for i in range(len(gt_box)): + if gt_box[i, 0] == 0 and gt_box[i, 1] == 0 and \ + gt_box[i, 2] == 0 and gt_box[i, 3] == 0: + break + valid_cnt += 1 + return (gt_box[:valid_cnt], gt_label[:valid_cnt], difficult[:valid_cnt] + if difficult is not None else None) + + +def get_category_info(anno_file=None, + with_background=True, + use_default_label=False): + if use_default_label or anno_file is None \ + or not os.path.exists(anno_file): + logger.info("Not found annotation file {}, load " + "voc2012 categories.".format(anno_file)) + return vocall_category_info(with_background) + else: + logger.info("Load categories from {}".format(anno_file)) + return get_category_info_from_anno(anno_file, with_background) + + +def get_category_info_from_anno(anno_file, with_background=True): + """ + Get class id to category id map and category id + to category name map from annotation file. + + Args: + anno_file (str): annotation file path + with_background (bool, default True): + whether load background as class 0. + """ + cats = [] + with open(anno_file) as f: + for line in f.readlines(): + cats.append(line.strip()) + + if cats[0] != 'background' and with_background: + cats.insert(0, 'background') + if cats[0] == 'background' and not with_background: + cats = cats[1:] + + clsid2catid = {i: i for i in range(len(cats))} + catid2name = {i: name for i, name in enumerate(cats)} + + return clsid2catid, catid2name + + +def vocall_category_info(with_background=True): + """ + Get class id to category id map and category id + to category name map of mixup voc dataset + + Args: + with_background (bool, default True): + whether load background as class 0. 
+ """ + label_map = pascalvoc_label(with_background) + label_map = sorted(label_map.items(), key=lambda x: x[1]) + cats = [l[0] for l in label_map] + + if with_background: + cats.insert(0, 'background') + + clsid2catid = {i: i for i in range(len(cats))} + catid2name = {i: name for i, name in enumerate(cats)} + + return clsid2catid, catid2name diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/voc_utils.py b/VisualFL/depends/PaddleDetection/ppdet/utils/voc_utils.py new file mode 100755 index 000000000..69fec1c16 --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/voc_utils.py @@ -0,0 +1,87 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import os.path as osp +import re +import random +import shutil + +__all__ = ['create_list'] + + +def create_list(year_dirs, output_dir): + """ + create following list: + 1. trainval.txt + 2. 
test.txt + """ + trainval_list = [] + test_list = [] + for year_dir in year_dirs: + trainval, test = _walk_voc_dir(year_dir, output_dir) + trainval_list.extend(trainval) + test_list.extend(test) + + random.shuffle(trainval_list) + with open(osp.join(output_dir, 'trainval.txt'), 'w') as ftrainval: + for item in trainval_list: + ftrainval.write(item[0] + ' ' + item[1] + '\n') + + with open(osp.join(output_dir, 'test.txt'), 'w') as fval: + ct = 0 + for item in test_list: + ct += 1 + fval.write(item[0] + ' ' + item[1] + '\n') + + +def _walk_voc_dir(year_dir, output_dir): + filelist_dir = osp.join(year_dir, 'ImageSets/Main') + annotation_dir = osp.join(year_dir, 'Annotations') + img_dir = osp.join(year_dir, 'JPEGImages') + trainval_list = [] + test_list = [] + added = set() + + img_dict = {} + for img_file in os.listdir(img_dir): + img_dict[img_file.split('.')[0]] = img_file + + for _, _, files in os.walk(filelist_dir): + for fname in files: + img_ann_list = [] + if re.match('trainval\.txt', fname): + img_ann_list = trainval_list + elif re.match('test\.txt', fname): + img_ann_list = test_list + else: + continue + fpath = osp.join(filelist_dir, fname) + for line in open(fpath): + name_prefix = line.strip().split()[0] + if name_prefix in added: + continue + added.add(name_prefix) + ann_path = osp.join( + osp.relpath(annotation_dir, output_dir), + name_prefix + '.xml') + img_path = osp.join( + osp.relpath(img_dir, output_dir), img_dict[name_prefix]) + img_ann_list.append((img_path, ann_path)) + + return trainval_list, test_list diff --git a/VisualFL/depends/PaddleDetection/ppdet/utils/widerface_eval_utils.py b/VisualFL/depends/PaddleDetection/ppdet/utils/widerface_eval_utils.py new file mode 100755 index 000000000..e7447e8fe --- /dev/null +++ b/VisualFL/depends/PaddleDetection/ppdet/utils/widerface_eval_utils.py @@ -0,0 +1,284 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from ppdet.data.source.widerface import widerface_label +from ppdet.utils.coco_eval import bbox2out + +import logging +logger = logging.getLogger(__name__) + +__all__ = [ + 'get_shrink', 'bbox_vote', 'save_widerface_bboxes', 'save_fddb_bboxes', + 'to_chw_bgr', 'bbox2out', 'get_category_info', 'lmk2out' +] + + +def to_chw_bgr(image): + """ + Transpose image from HWC to CHW and from RBG to BGR. + Args: + image (np.array): an image with HWC and RBG layout. 
+ """ + # HWC to CHW + if len(image.shape) == 3: + image = np.swapaxes(image, 1, 2) + image = np.swapaxes(image, 1, 0) + # RBG to BGR + image = image[[2, 1, 0], :, :] + return image + + +def bbox_vote(det): + order = det[:, 4].ravel().argsort()[::-1] + det = det[order, :] + if det.shape[0] == 0: + dets = np.array([[10, 10, 20, 20, 0.002]]) + det = np.empty(shape=[0, 5]) + while det.shape[0] > 0: + # IOU + area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1) + xx1 = np.maximum(det[0, 0], det[:, 0]) + yy1 = np.maximum(det[0, 1], det[:, 1]) + xx2 = np.minimum(det[0, 2], det[:, 2]) + yy2 = np.minimum(det[0, 3], det[:, 3]) + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + o = inter / (area[0] + area[:] - inter) + + # nms + merge_index = np.where(o >= 0.3)[0] + det_accu = det[merge_index, :] + det = np.delete(det, merge_index, 0) + if merge_index.shape[0] <= 1: + if det.shape[0] == 0: + try: + dets = np.row_stack((dets, det_accu)) + except: + dets = det_accu + continue + det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4)) + max_score = np.max(det_accu[:, 4]) + det_accu_sum = np.zeros((1, 5)) + det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4], + axis=0) / np.sum(det_accu[:, -1:]) + det_accu_sum[:, 4] = max_score + try: + dets = np.row_stack((dets, det_accu_sum)) + except: + dets = det_accu_sum + dets = dets[0:750, :] + # Only keep 0.3 or more + keep_index = np.where(dets[:, 4] >= 0.01)[0] + dets = dets[keep_index, :] + return dets + + +def get_shrink(height, width): + """ + Args: + height (int): image height. + width (int): image width. + """ + # avoid out of memory + max_shrink_v1 = (0x7fffffff / 577.0 / (height * width))**0.5 + max_shrink_v2 = ((678 * 1024 * 2.0 * 2.0) / (height * width))**0.5 + + def get_round(x, loc): + str_x = str(x) + if '.' in str_x: + str_before, str_after = str_x.split('.') + len_after = len(str_after) + if len_after >= 3: + str_final = str_before + '.' 
+ str_after[0:loc] + return float(str_final) + else: + return x + + max_shrink = get_round(min(max_shrink_v1, max_shrink_v2), 2) - 0.3 + if max_shrink >= 1.5 and max_shrink < 2: + max_shrink = max_shrink - 0.1 + elif max_shrink >= 2 and max_shrink < 3: + max_shrink = max_shrink - 0.2 + elif max_shrink >= 3 and max_shrink < 4: + max_shrink = max_shrink - 0.3 + elif max_shrink >= 4 and max_shrink < 5: + max_shrink = max_shrink - 0.4 + elif max_shrink >= 5: + max_shrink = max_shrink - 0.5 + elif max_shrink <= 0.1: + max_shrink = 0.1 + + shrink = max_shrink if max_shrink < 1 else 1 + return shrink, max_shrink + + +def save_widerface_bboxes(image_path, bboxes_scores, output_dir): + image_name = image_path.split('/')[-1] + image_class = image_path.split('/')[-2] + odir = os.path.join(output_dir, image_class) + if not os.path.exists(odir): + os.makedirs(odir) + + ofname = os.path.join(odir, '%s.txt' % (image_name[:-4])) + f = open(ofname, 'w') + f.write('{:s}\n'.format(image_class + '/' + image_name)) + f.write('{:d}\n'.format(bboxes_scores.shape[0])) + for box_score in bboxes_scores: + xmin, ymin, xmax, ymax, score = box_score + f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(xmin, ymin, ( + xmax - xmin + 1), (ymax - ymin + 1), score)) + f.close() + logger.info("The predicted result is saved as {}".format(ofname)) + + +def save_fddb_bboxes(bboxes_scores, + output_dir, + output_fname='pred_fddb_res.txt'): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + predict_file = os.path.join(output_dir, output_fname) + f = open(predict_file, 'w') + for image_path, dets in bboxes_scores.iteritems(): + f.write('{:s}\n'.format(image_path)) + f.write('{:d}\n'.format(dets.shape[0])) + for box_score in dets: + xmin, ymin, xmax, ymax, score = box_score + width, height = xmax - xmin, ymax - ymin + f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n' + .format(xmin, ymin, width, height, score)) + logger.info("The predicted result is saved as {}".format(predict_file)) + return 
predict_file + + +def get_category_info(anno_file=None, + with_background=True, + use_default_label=False): + if use_default_label or anno_file is None \ + or not os.path.exists(anno_file): + logger.info("Not found annotation file {}, load " + "wider-face categories.".format(anno_file)) + return widerfaceall_category_info(with_background) + else: + logger.info("Load categories from {}".format(anno_file)) + return get_category_info_from_anno(anno_file, with_background) + + +def get_category_info_from_anno(anno_file, with_background=True): + """ + Get class id to category id map and category id + to category name map from annotation file. + Args: + anno_file (str): annotation file path + with_background (bool, default True): + whether load background as class 0. + """ + cats = [] + with open(anno_file) as f: + for line in f.readlines(): + cats.append(line.strip()) + + if cats[0] != 'background' and with_background: + cats.insert(0, 'background') + if cats[0] == 'background' and not with_background: + cats = cats[1:] + + clsid2catid = {i: i for i in range(len(cats))} + catid2name = {i: name for i, name in enumerate(cats)} + + return clsid2catid, catid2name + + +def widerfaceall_category_info(with_background=True): + """ + Get class id to category id map and category id + to category name map of mixup wider_face dataset + + Args: + with_background (bool, default True): + whether load background as class 0. + """ + label_map = widerface_label(with_background) + label_map = sorted(label_map.items(), key=lambda x: x[1]) + cats = [l[0] for l in label_map] + + if with_background: + cats.insert(0, 'background') + + clsid2catid = {i: i for i in range(len(cats))} + catid2name = {i: name for i, name in enumerate(cats)} + + return clsid2catid, catid2name + + +def lmk2out(results, is_bbox_normalized=False): + """ + Args: + results: request a dict, should include: `landmark`, `im_id`, + if is_bbox_normalized=True, also need `im_shape`. 
+ is_bbox_normalized: whether or not landmark is normalized. + """ + xywh_res = [] + for t in results: + bboxes = t['bbox'][0] + lengths = t['bbox'][1][0] + im_ids = np.array(t['im_id'][0]).flatten() + if bboxes.shape == (1, 1) or bboxes is None: + continue + face_index = t['face_index'][0] + prior_box = t['prior_boxes'][0] + predict_lmk = t['landmark'][0] + prior = np.reshape(prior_box, (-1, 4)) + predictlmk = np.reshape(predict_lmk, (-1, 10)) + + k = 0 + for a in range(len(lengths)): + num = lengths[a] + im_id = int(im_ids[a]) + for i in range(num): + score = bboxes[k][1] + theindex = face_index[i][0] + me_prior = prior[theindex, :] + lmk_pred = predictlmk[theindex, :] + prior_w = me_prior[2] - me_prior[0] + prior_h = me_prior[3] - me_prior[1] + prior_w_center = (me_prior[2] + me_prior[0]) / 2 + prior_h_center = (me_prior[3] + me_prior[1]) / 2 + lmk_decode = np.zeros((10)) + for j in [0, 2, 4, 6, 8]: + lmk_decode[j] = lmk_pred[j] * 0.1 * prior_w + prior_w_center + for j in [1, 3, 5, 7, 9]: + lmk_decode[j] = lmk_pred[j] * 0.1 * prior_h + prior_h_center + im_shape = t['im_shape'][0][a].tolist() + image_h, image_w = int(im_shape[0]), int(im_shape[1]) + if is_bbox_normalized: + lmk_decode = lmk_decode * np.array([ + image_w, image_h, image_w, image_h, image_w, image_h, + image_w, image_h, image_w, image_h + ]) + lmk_res = { + 'image_id': im_id, + 'landmark': lmk_decode, + 'score': score, + } + xywh_res.append(lmk_res) + k += 1 + return xywh_res diff --git a/VisualFL/depends/PaddleFL/AUTHORS.md b/VisualFL/depends/PaddleFL/AUTHORS.md new file mode 100755 index 000000000..f22dad93d --- /dev/null +++ b/VisualFL/depends/PaddleFL/AUTHORS.md @@ -0,0 +1,5 @@ +| Github account | name | +|---|---| +| guru4elephant | Daxiang Dong | +| frankwhzhang | Wenhui Zhang | +| qjing666 | Qinghe Jing | diff --git a/VisualFL/depends/PaddleFL/LICENSE b/VisualFL/depends/PaddleFL/LICENSE new file mode 100755 index 000000000..261eeb9e9 --- /dev/null +++ b/VisualFL/depends/PaddleFL/LICENSE @@ 
-0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/__init__.py b/VisualFL/depends/PaddleFL/python/paddle_fl/__init__.py new file mode 100755 index 000000000..9d0531501 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/__init__.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/__init__.py new file mode 100755 index 000000000..9d0531501 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/__init__.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/__init__.py new file mode 100755 index 000000000..33ed0ecf1 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/master/__init__.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/master/__init__.py new file mode 100755 index 000000000..33ed0ecf1 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/master/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/master/fl_job.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/master/fl_job.py new file mode 100755 index 000000000..a76eaccba --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/master/fl_job.py @@ -0,0 +1,173 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import paddle.fluid as fluid + + +class FLJobBase(object): + """ + FLJobBase is fl job base class, responsible for save and load + a federated learning job + """ + + def _save_str_list(self, items, output): + with open(output, "w") as fout: + for item in items: + fout.write(item + "\n") + + def _load_str_list(self, input_file): + res = [] + with open(input_file, "r") as fin: + for line in fin: + res.append(line.strip()) + return res + + def _save_strategy(self, strategy, output_file): + import pickle + + pickle.dump(strategy, open(output_file, "wb")) + + def _save_endpoints(self, endpoints, output_file): + with open(output_file, "w") as fout: + for ep in endpoints: + fout.write(str(ep) + "\n") + + def _load_endpoints(self, input_file): + ep_list = [] + with open(input_file, "r") as fin: + for line in fin: + ep_list.append(line.strip()) + return ep_list + + def _save_program(self, program, output_file): + with open(output_file, "wb") as fout: + fout.write(program.desc.serialize_to_string()) + + def _save_readable_program(self, program, output_file): + with open(output_file, "w") as fout: + fout.write(str(program)) + + def _load_program(self, input_file): + with open(input_file, "rb") as fin: + program_desc_str = fin.read() + return fluid.Program.parse_from_string(program_desc_str) + return None + + +class FLCompileTimeJob(FLJobBase): + """ + FLCompileTimeJob is a container for compile time job in federated learning. + trainer startup programs, trainer main programs and other trainer programs + are in FLCompileTimeJob. Also, server main programs and server startup programs + are in this class. 
FLCompileTimeJob has server endpoints for debugging as well + """ + + def __init__(self): + self._trainer_startup_programs = [] + self._trainer_recv_programs = [] + self._trainer_main_programs = [] + self._trainer_send_programs = [] + self._server_startup_programs = [] + self._server_main_programs = [] + self._server_endpoints = [] + + def set_strategy(self, strategy): + self._strategy = strategy + + def set_server_endpoints(self, ps_endpoints): + self._server_endpoints = ps_endpoints + + def set_feed_names(self, names): + self._feed_names = names + + def set_feeds(self, feeds): + self._feeds = feeds + + def set_target_names(self, names): + self._target_names = names + + def save(self, folder=None): + server_num = len(self._server_startup_programs) + trainer_num = len(self._trainer_startup_programs) + send_prog_num = len(self._trainer_send_programs) + for i in range(server_num): + server_folder = "%s/server%d" % (folder, i) + os.system("mkdir -p %s" % server_folder) + server_startup = self._server_startup_programs[i] + server_main = self._server_main_programs[i] + self._save_program( + server_startup, "%s/server.startup.program" % server_folder + ) + self._save_program(server_main, "%s/server.main.program" % server_folder) + self._save_readable_program( + server_startup, "%s/server.startup.program.txt" % server_folder + ) + self._save_readable_program( + server_main, "%s/server.main.program.txt" % server_folder + ) + self._save_str_list(self._feed_names, "%s/feed_names" % server_folder) + self._save_str_list(self._target_names, "%s/target_names" % server_folder) + self._save_endpoints(self._server_endpoints, "%s/endpoints" % server_folder) + self._save_strategy(self._strategy, "%s/strategy.pkl" % server_folder) + + for i in range(trainer_num): + trainer_folder = "%s/trainer%d" % (folder, i) + os.system("mkdir -p %s" % trainer_folder) + trainer_startup = self._trainer_startup_programs[i] + trainer_main = self._trainer_main_programs[i] + self._save_program( + 
trainer_startup, "%s/trainer.startup.program" % trainer_folder + ) + self._save_program(trainer_main, "%s/trainer.main.program" % trainer_folder) + self._save_readable_program( + trainer_startup, "%s/trainer.startup.program.txt" % trainer_folder + ) + self._save_readable_program( + trainer_main, "%s/trainer.main.program.txt" % trainer_folder + ) + self._save_str_list(self._feed_names, "%s/feed_names" % trainer_folder) + self._save_str_list(self._target_names, "%s/target_names" % trainer_folder) + self._save_endpoints( + self._server_endpoints, "%s/endpoints" % trainer_folder + ) + self._save_strategy(self._strategy, "%s/strategy.pkl" % trainer_folder) + + # save feed_variable + import pickle + + with open(f"{trainer_folder}/feeds.pkl", "wb") as f: + pickle.dump(len(self._feeds), f) + for feed in self._feeds: + pickle.dump( + obj={ + "name": feed.name, + "shape": feed.shape, + "dtype": feed.dtype, + "lod_level": feed.lod_level, + }, + file=f, + ) + + for i in range(send_prog_num): + trainer_folder = "%s/trainer%d" % (folder, i) + trainer_send = self._trainer_send_programs[i] + trainer_recv = self._trainer_recv_programs[i] + self._save_program(trainer_send, "%s/trainer.send.program" % trainer_folder) + self._save_program(trainer_recv, "%s/trainer.recv.program" % trainer_folder) + self._save_readable_program( + trainer_send, "%s/trainer.send.program.txt" % trainer_folder + ) + self._save_readable_program( + trainer_recv, "%s/trainer.recv.program.txt" % trainer_folder + ) diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/master/job_generator.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/master/job_generator.py new file mode 100755 index 000000000..df7867da3 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/master/job_generator.py @@ -0,0 +1,332 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import paddle.fluid as fluid +from .fl_job import FLCompileTimeJob + + +class JobGenerator(object): + """ + A JobGenerator is responsible for generating distributed federated + learning configs. Before federated learning job starts, organizations + need to define a deep learning model together to do horizontal federated + learning. + """ + + def __init__(self): + # worker num for federated learning + self._worker_num = 0 + # startup program + self._startup_prog = None + # inner optimizer + self._optimizer = fluid.optimizer.SGD(learning_rate=0.001) + self._feed_names = [] + self._target_names = [] + + def set_optimizer(self, optimizer): + """ + Set optimizer of current job + """ + self._optimizer = optimizer + + def set_losses(self, losses): + """ + Set losses of current job + losses can be a list of loss so that we can do + optimization on multiple losses + """ + self._losses = losses + + def set_startup_program(self, startup=None): + """ + set startup program for user defined program + """ + if startup == None: + startup = fluid.default_startup_program() + self._startup_prog = startup + + def set_infer_feed_and_target_names(self, feed_names, target_names): + if not isinstance(feed_names, list) or not isinstance(target_names, list): + raise ValueError("input should be list in set_infer_feed_and_target_names") + """ + print(feed_names) + print(target_names) + for item in feed_names: + if type(item) != str: + 
raise ValueError("item in feed_names should be string") + for item in target_names: + if type(item) != str: + raise ValueError("item in target_names should be string") + """ + self._feed_names = feed_names + self._target_names = target_names + + def set_feeds(self, feeds): + self._feeds = feeds + + def generate_fl_job( + self, fl_strategy, server_endpoints=[], worker_num=1, output=None + ): + """ + Generate Federated Learning Job, based on user defined configs + + Args: + fl_strategy(FLStrategyBase): federated learning strategy defined by current federated users + server_endpoints(List(str)): endpoints for federated server nodes + worker_num(int): number of training nodes + output(str): output directory of generated fl job + + Returns: + None + + Examples: + import paddle.fluid as fluid + import paddle_fl as fl + from paddle_fl.core.master.job_generator import JobGenerator + from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory + + input_x = fluid.layers.data(name="input_x", shape=[10], dtype="float32") + label = fluid.layers.data(name="label", shape[1], dtype="int64") + fc0 = fluid.layers.fc(input=input_x, size=2, act='sigmoid') + cost = fluid.layers.cross_entropy(input=fc0, label=label) + loss = fluid.layers.reduce_mean(cost) + + job_generator = JobGenerator() + optimizer = fluid.optimizer.SGD(learning_rate=0.1) + job_generator.set_optimizer(optimizer) + job_generator.set_losses([loss]) + server_endpoints = [127.0.0.1:8181] + worker_num = 10 + build_strategy = FLStrategyFactor() + build_strategy.fed_avg = True + strategy = build_strategy.create_fl_strategy() + job_output_dir = "fl_job_config" + job_generator.generate_fl_job(strategy, + server_endpoints=server_endpoints, + worker_num=1, + output=output) + + """ + local_job = FLCompileTimeJob() + assert len(self._losses) > 0 + assert self._startup_prog != None + assert fl_strategy != None + assert output != None + fl_strategy.minimize(self._optimizer, self._losses) + + # strategy can generate 
startup and main program + # of a single worker and servers + for trainer_id in range(worker_num): + startup_program = self._startup_prog.clone() + main_program = self._losses[0].block.program.clone() + fl_strategy._build_trainer_program_for_job( + trainer_id, + program=main_program, + ps_endpoints=server_endpoints, + trainers=worker_num, + sync_mode=True, + startup_program=startup_program, + job=local_job, + ) + + startup_program = self._startup_prog.clone() + main_program = self._losses[0].block.program.clone() + fl_strategy._build_server_programs_for_job( + program=main_program, + ps_endpoints=server_endpoints, + trainers=worker_num, + sync_mode=True, + startup_program=startup_program, + job=local_job, + ) + + local_job.set_feed_names(self._feed_names) + local_job.set_target_names(self._target_names) + local_job.set_feeds(self._feeds) + local_job.set_strategy(fl_strategy) + local_job.save(output) + + def generate_fl_job_for_k8s( + self, + fl_strategy, + server_pod_endpoints=[], + server_service_endpoints=[], + worker_num=1, + output=None, + ): + + local_job = FLCompileTimeJob() + assert len(self._losses) > 0 + assert self._startup_prog != None + assert fl_strategy != None + assert output != None + fl_strategy.minimize(self._optimizer, self._losses) + + # strategy can generate startup and main program + # of a single worker and servers + for trainer_id in range(worker_num): + startup_program = self._startup_prog.clone() + main_program = self._losses[0].block.program.clone() + fl_strategy._build_trainer_program_for_job( + trainer_id, + program=main_program, + ps_endpoints=server_service_endpoints, + trainers=worker_num, + sync_mode=True, + startup_program=startup_program, + job=local_job, + ) + + startup_program = self._startup_prog.clone() + main_program = self._losses[0].block.program.clone() + fl_strategy._build_server_programs_for_job( + program=main_program, + ps_endpoints=server_pod_endpoints, + trainers=worker_num, + sync_mode=True, + 
startup_program=startup_program, + job=local_job, + ) + + local_job.set_feed_names(self._feed_names) + local_job.set_target_names(self._target_names) + local_job.set_strategy(fl_strategy) + local_job.save(output) + + def save_program( + self, + main_prog, + startup_prog, + program_path, + input_list, + hidden_vars, + loss, + learning_rate=None, + ): + if not os.path.exists(program_path): + os.makedirs(program_path) + main_program_str = main_prog.desc.serialize_to_string() + startup_program_str = startup_prog.desc.serialize_to_string() + params = main_prog.global_block().all_parameters() + para_info = [] + for pa in params: + para_info.append(pa.name) + with open(program_path + "/input_names", "w") as fout: + for input in input_list: + fout.write("%s\n" % input) + if hidden_vars != None: + with open(program_path + "/hidden_vars", "w") as fout: + for var in hidden_vars: + fout.write("%s:%s\n" % (var[0], var[1].name)) + with open(program_path + "/para_info", "w") as fout: + for item in para_info: + fout.write("%s\n" % item) + with open(program_path + "/startup_program", "wb") as fout: + fout.write(startup_program_str) + with open(program_path + "/main_program", "wb") as fout: + fout.write(main_program_str) + with open(program_path + "/loss_name", "w") as fout: + fout.write(loss.name) + if type(learning_rate) == fluid.Variable: + with open(program_path + "/lr_name", "w") as fout: + fout.write(learning_rate.name) + + def generate_fl_job_from_program( + self, strategy, endpoints, worker_num, program_input, output + ): + local_job = FLCompileTimeJob() + with open(program_input + "/startup_program", "rb") as fin: + program_desc_str = fin.read() + new_startup = fluid.Program.parse_from_string(program_desc_str) + + with open(program_input + "/main_program", "rb") as fin: + program_desc_str = fin.read() + new_main = fluid.Program.parse_from_string(program_desc_str) + + para_list = [] + with open(program_input + "/para_info", "r") as fin: + for line in fin: + current_para = 
line[:-1] + para_list.append(current_para) + + input_list = [] + with open(program_input + "/input_names", "r") as fin: + for line in fin: + current_input = line[:-1] + input_list.append(current_input) + + with open(program_input + "/loss_name", "r") as fin: + loss_name = fin.read() + + if os.path.exists(program_input + "/lr_name"): + with open(program_input + "/lr_name", "r") as fin: + lr_name = fin.read() + else: + lr_name = None + + for item in para_list: + para = new_main.global_block().var(item) + para.regularizer = None + para.optimize_attr = {"learning_rate": 1.0} + para.trainable = True + exe = fluid.Executor(fluid.CPUPlace()) + loss = None + for var in new_main.list_vars(): + if var.name == loss_name: + loss = var + if lr_name != None: + if var.name == lr_name: + lr = var + + with fluid.program_guard(new_main, new_startup): + if lr_name != None: + optimizer = fluid.optimizer.MomentumOptimizer( + learning_rate=lr, momentum=0.9, parameter_list=para_list + ) + else: + optimizer = fluid.optimizer.MomentumOptimizer( + learning_rate=0.00001, momentum=0.9, parameter_list=para_list + ) + + exe.run(new_startup) + strategy.minimize(optimizer, loss) + + for trainer_id in range(worker_num): + startup_program = new_startup.clone() + main_program = loss.block.program.clone() + strategy._build_trainer_program_for_job( + trainer_id, + program=main_program, + ps_endpoints=endpoints, + trainers=worker_num, + sync_mode=True, + startup_program=startup_program, + job=local_job, + ) + + startup_program = new_startup.clone() + main_program = loss.block.program.clone() + strategy._build_server_programs_for_job( + program=main_program, + ps_endpoints=endpoints, + trainers=worker_num, + sync_mode=True, + startup_program=startup_program, + job=local_job, + ) + + local_job.set_feed_names(input_list) + local_job.set_target_names([loss.name]) + local_job.set_strategy(strategy) + local_job.save(output) diff --git 
a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/__init__.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/__init__.py new file mode 100755 index 000000000..33ed0ecf1 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/__init__.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/__init__.py new file mode 100755 index 000000000..82d0d336e --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +from .program_utils import * +from .ufind import * +from .checkport import * +from .vars_distributed import * diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/checkport.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/checkport.py new file mode 100755 index 000000000..76f6216e5 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/checkport.py @@ -0,0 +1,54 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import time +import socket +from contextlib import closing +from six import string_types + + +def wait_server_ready(endpoints): + """ + Wait until parameter servers are ready, use connext_ex to detect + port readiness. + + Args: + endpoints (list): endpoints string list, like: + ["127.0.0.1:8080", "127.0.0.1:8081"] + + Examples: + .. 
code-block:: python + + wait_server_ready(["127.0.0.1:8080", "127.0.0.1:8081"]) + """ + assert not isinstance(endpoints, string_types) + while True: + all_ok = True + not_ready_endpoints = [] + for ep in endpoints: + ip_port = ep.split(":") + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: + sock.settimeout(2) + result = sock.connect_ex((ip_port[0], int(ip_port[1]))) + if result != 0: + all_ok = False + not_ready_endpoints.append(ep) + if not all_ok: + sys.stderr.write("server not ready, wait 3 sec to retry...\n") + sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) + "\n") + sys.stderr.flush() + time.sleep(3) + else: + break diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/program_utils.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/program_utils.py new file mode 100755 index 000000000..16391cb42 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/program_utils.py @@ -0,0 +1,214 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import six + +from paddle.fluid import core +import paddle + + +def delete_ops(block, ops): + for op in ops: + try: + idx = list(block.ops).index(op) + block._remove_op(idx) + except Exception as e: + print(e) + + +def find_op_by_input_arg(block, arg_name): + for index, op in enumerate(block.ops): + if arg_name in op.input_arg_names: + return index + return -1 + + +def find_op_by_output_arg(block, arg_name, reverse=False): + if reverse: + pos = len(block.ops) - 1 + while pos >= 0: + op = block.ops[pos] + if arg_name in op.output_arg_names: + return pos + pos -= 1 + else: + for index, op in enumerate(block.ops): + if arg_name in op.output_arg_names: + return index + return -1 + + +def get_indent_space(indent, space_num=4): + ret = "" + for i in range(0, indent * space_num): + ret += " " + + return ret + + +def variable_to_code(var): + """ + Get readable codes of fluid variable. + + Args: + var: A fluid operator. + + Returns: + string: The formatted string. + """ + if ( + var.type == core.VarDesc.VarType.SELECTED_ROWS + or var.type == core.VarDesc.VarType.LOD_TENSOR + ): + var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})".format( + i="{", e="}", name=var.name, type=var.type, shape=var.shape, dtype=var.dtype + ) + else: + var_str = "{name} : fluid.{type})".format( + i="{", e="}", name=var.name, type=var.type + ) + + if type(var) == paddle.fluid.framework.Parameter: + if var.trainable: + var_str = "trainable parameter " + var_str + else: + var_str = "parameter " + var_str + else: + var_str = "var " + var_str + + if var.persistable: + var_str = "persist " + var_str + + return var_str + + +def op_to_code(op, skip_op_callstack=True): + """ + Get readable codes of fluid operator. + + Args: + op: A fluid operator. + + Returns: + string: The foramtted string. 
+ """ + + outputs_str = "{" + for i in range(0, len(op.output_names)): + outputs_str += "{name}=".format(name=op.output_names[i]) + o = op.output(op.output_names[i]) + outputs_str += "{value}".format(value=o) + if i != len(op.output_names) - 1: + outputs_str += ", " + outputs_str += "}" + + inputs_str = "{" + for i in range(0, len(op.input_names)): + inputs_str += "{name}=".format(name=op.input_names[i]) + o = op.input(op.input_names[i]) + inputs_str += "{value}".format(value=o) + + if i != len(op.input_names) - 1: + inputs_str += ", " + inputs_str += "}" + + attr_names = sorted(op.attr_names) + attrs_str = "" + for i in range(0, len(attr_names)): + name = attr_names[i] + if skip_op_callstack and name == "op_callstack": + continue + + attr_type = op.desc.attr_type(name) + if attr_type == core.AttrType.BLOCK: + a = "{name} = block[{value}]".format( + name=name, type=attr_type, value=op._block_attr_id(name) + ) + attrs_str += a + if i != len(attr_names) - 1: + attrs_str += ", " + continue + + if attr_type == core.AttrType.BLOCKS: + a = "{name} = blocks{value}".format( + name=name, type=attr_type, value=op._blocks_attr_ids(name) + ) + attrs_str += a + if i != len(attr_names) - 1: + attrs_str += ", " + continue + + a = "{name} = {value}".format( + name=name, type=attr_type, value=op.desc.attr(name) + ) + attrs_str += a + if i != len(attr_names) - 1: + attrs_str += ", " + + if outputs_str != "{}": + op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".format( + outputs=outputs_str, op_type=op.type, inputs=inputs_str, attrs=attrs_str + ) + else: + op_str = "{op_type}(inputs={inputs}, {attrs})".format( + op_type=op.type, inputs=inputs_str, attrs=attrs_str + ) + return op_str + + +def block_to_code(block, block_idx, fout=None, skip_op_callstack=False): + indent = 0 + + print( + "{0}{1} // block {2}".format(get_indent_space(indent), "{", block_idx), + file=fout, + ) + + indent += 1 + # sort all vars + all_vars = sorted(six.iteritems(block.vars), key=lambda x: x[0]) + 
for var in all_vars: + print( + "{}{}".format(get_indent_space(indent), variable_to_code(var[1])), file=fout + ) + + if len(all_vars) > 0: + print("", file=fout) + + for op in block.ops: + print( + "{}{}".format(get_indent_space(indent), op_to_code(op, skip_op_callstack)), + file=fout, + ) + indent -= 1 + + print("{0}{1}".format(get_indent_space(indent), "}"), file=fout) + + +def program_to_code(prog, fout=None, skip_op_callstack=True): + """ + Print readable codes of fluid program. + + Args: + prog : A fluid program. + + An example result like bellow: + https://github.com/PaddlePaddle/Paddle/pull/12673 + """ + block_idx = 0 + for block in prog.blocks: + block_to_code(block, block_idx, fout, skip_op_callstack) + block_idx += 1 diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/ps_dispatcher.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/ps_dispatcher.py new file mode 100755 index 000000000..a04f6c2c7 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/ps_dispatcher.py @@ -0,0 +1,110 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + + +class PSDispatcher(object): + """ + PSDispatcher is the base class for dispatching vars + into different pserver instance. + You need to implement the `dispatch` inferface. 
+ """ + + def __init__(self, pserver_endpoints): + self._eps = pserver_endpoints + self._step = 0 + + @property + def eps(self): + return self._eps + + def reset(self): + self._step = 0 + + def dispatch(self, varlist): + """ + Args: + varlist(list): a list of Variables + Returns: + a map of pserver endpoint -> varname + """ + AssertionError("Interface has not been implemented.") + + +class HashName(PSDispatcher): + """ + Hash variable names to several endpoints using python + "hash()" function. + + Args: + pserver_endpoints (list): list of endpoint(ip:port). + + Examples: + .. code-block:: python + + pserver_endpoints = ["127.0.0.1:6007", "127.0.0.1:6008"] + vars = ["var1","var2","var3","var4","var5"] + + rr = RoundRobin(pserver_endpoints) + rr.dispatch(vars) + + """ + + def __init__(self, pserver_endpoints): + super(self.__class__, self).__init__(pserver_endpoints) + + def _hash_block(self, block_str, total): + return hash(block_str) % total + + def dispatch(self, varlist): + eplist = [] + for var in varlist: + server_id = self._hash_block(var.name(), len(self._eps)) + server_for_param = self._eps[server_id] + eplist.append(server_for_param) + return eplist + + +class RoundRobin(PSDispatcher): + """ + Distribute variables to serveral endpoints using + RondRobin method. + + Args: + pserver_endpoints (list): list of endpoint(ip:port). + + Examples: + .. 
code-block:: python + + pserver_endpoints = ["127.0.0.1:6007", "127.0.0.1:6008"] + vars = ["var1","var2","var3","var4","var5"] + + rr = RoundRobin(pserver_endpoints) + rr.dispatch(vars) + + """ + + def __init__(self, pserver_endpoints): + super(self.__class__, self).__init__(pserver_endpoints) + + def dispatch(self, varlist): + eplist = [] + for var in varlist: + server_for_param = self._eps[self._step] + eplist.append(server_for_param) + self._step += 1 + if self._step >= len(self._eps): + self._step = 0 + return eplist diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/ufind.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/ufind.py new file mode 100755 index 000000000..9eae516a7 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/ufind.py @@ -0,0 +1,66 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + + +class UnionFind(object): + """Union-find data structure. + + Union-find is a data structure that keeps track of a set of elements partitioned + into a number of disjoint (non-overlapping) subsets. + + Reference: + https://en.wikipedia.org/wiki/Disjoint-set_data_structure + + Args: + elements(list): The initialize element list. 
+ """ + + def __init__(self, elementes=None): + self._parents = [] # index -> parent index + self._index = {} # element -> index + self._curr_idx = 0 + if not elementes: + elementes = [] + for ele in elementes: + self._parents.append(self._curr_idx) + self._index.update({ele: self._curr_idx}) + self._curr_idx += 1 + + def find(self, x): + # Find the root index of given element x, + # execute the path compress while findind the root index + if not x in self._index: + return -1 + idx = self._index[x] + while idx != self._parents[idx]: + t = self._parents[idx] + self._parents[idx] = self._parents[t] + idx = t + return idx + + def union(self, x, y): + # Union two given element + x_root = self.find(x) + y_root = self.find(y) + + if x_root == y_root: + return + self._parents[x_root] = y_root + + def is_connected(self, x, y): + # If two given elements have the same root index, + # then they are connected. + return self.find(x) == self.find(y) diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/vars_distributed.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/vars_distributed.py new file mode 100755 index 000000000..f1849cbef --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/details/vars_distributed.py @@ -0,0 +1,298 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import print_function
from paddle.fluid.framework import Variable


class VarStruct(object):
    """
    record part properties of a Variable in python.
    """

    def __init__(self, name, shape, dtype, type, lod_level, persistable):
        # Plain-data snapshot of a framework Variable; "type" deliberately
        # mirrors the framework attribute name even though it shadows the
        # builtin.
        self.name = name
        self.shape = shape
        self.dtype = dtype
        self.type = type
        self.lod_level = lod_level
        self.persistable = persistable


class VarDistributed(object):
    """
    a class to record the var distributed on parameter servers.
    the class will record the relationship between origin var and slice var.
    the slice var's properties, such as type/shape/offset/endpoint.
    """

    def __init__(
        self,
        origin_var,
        slice_var,
        is_slice=None,
        block_id=None,
        offset=None,
        vtype=None,
        endpoint=None,
    ):
        """
        Args:
            origin_var(Variable|VarStruct): origin var properties
            slice_var(Variable|VarStruct): slice var properties
            is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard.
            block_id(int|None): the number about the slice var.
            offset(int|None): if the slice var is sliced, offset is the numel before the var.
            vtype(str|None): a tag, such as Optimizer/Param/RemotePrefetch.
            endpoint(str|None): which parameter server the slice var is on, such as "127.0.0.1:1001"
        """

        # Normalize both vars to VarStruct snapshots so later comparisons
        # never touch live framework objects.
        if isinstance(origin_var, Variable):
            self.origin = self.__create_var_struct(origin_var)
        else:
            self.origin = origin_var

        if isinstance(slice_var, Variable):
            self.slice = self.__create_var_struct(slice_var)
        else:
            self.slice = slice_var

        # Inferred defaults: identical origin/slice means "not sliced";
        # block_id/offset start at 0 either way and may be overridden below.
        if self.equal(self.origin, self.slice):
            self.is_slice = False
            self.block_id = 0
            self.offset = 0
        else:
            self.is_slice = True
            self.block_id = 0
            self.offset = 0

        # Explicit constructor arguments win over the inferred defaults.
        if is_slice is not None:
            self.is_slice = is_slice
        if block_id is not None:
            self.block_id = block_id
        if offset is not None:
            self.offset = offset

        self.vtype = vtype
        self.endpoint = endpoint

    @staticmethod
    def __create_var_struct(var):
        # Snapshot the subset of Variable properties this class cares about.
        return VarStruct(
            var.name, var.shape, var.dtype, var.type, var.lod_level, var.persistable
        )

    @staticmethod
    def equal(var1, var2):
        """
        the two var is equal or not.
        Returns:
            bool: equal will return True else False
        """
        assert isinstance(var1, VarStruct) and isinstance(var2, VarStruct)

        return (
            var1.name == var2.name
            and var1.type == var2.type
            and var1.shape == var2.shape
            and var1.dtype == var2.dtype
            and var1.lod_level == var2.lod_level
            and var1.persistable == var2.persistable
        )

    def __str__(self):
        # NOTE(review): the i="{" / e="}" kwargs below are never referenced
        # by the format strings; harmless but dead.
        origin_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})".format(
            i="{",
            e="}",
            name=self.origin.name,
            type=self.origin.type,
            shape=self.origin.shape,
            dtype=self.origin.dtype,
        )

        slice_var_str = (
            "{name} : fluid.{type}.shape{shape}.astype({dtype})"
            ".slice({is_slice}).block({block_id}).offset({offset})".format(
                i="{",
                e="}",
                name=self.slice.name,
                type=self.slice.type,
                shape=self.slice.shape,
                dtype=self.slice.dtype,
                is_slice=self.is_slice,
                block_id=self.block_id,
                offset=self.offset,
            )
        )

        return "var owned: {}, origin var: ( {} ), slice var: ( {} ), endpoint: {} ".format(
            self.vtype, origin_var_str,
            slice_var_str, self.endpoint
        )


class VarsDistributed(object):
    """
    a gather about VarDistributed with many methods to find distributed vars.
    through the class, we can get overview about the distributed parameters on parameter servers.
    this class may centralized and convenient for developer to manage and get variable's distribute.
    other module can also use this to find variables such io.py.
    """

    def __init__(self):
        # Flat list of VarDistributed records; every lookup below scans it
        # linearly.
        self.distributed_vars = []

    def add_distributed_var(
        self,
        origin_var,
        slice_var,
        is_slice=None,
        block_id=None,
        offset=None,
        vtype=None,
        endpoint=None,
    ):
        """
        add distributed var in this.

        Args:
            origin_var(Variable|VarStruct): origin var properties
            slice_var(Variable|VarStruct): slice var properties
            is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard.
            block_id(int|None): the number about the slice var.
            offset(int|None): if the slice var is sliced, offset is the numel before the var.
            vtype(str|None): a tag, such as Optimizer/Param/RemotePrefetch.
            endpoint(str|None): which parameter server the slice var is on, such as "127.0.0.1:1001"
        Returns:
            None
        """
        self.distributed_vars.append(
            VarDistributed(
                origin_var, slice_var, is_slice, block_id, offset, vtype, endpoint
            )
        )

    def get_distributed_var_by_slice(self, var_name):
        """
        get distributed var by conditions.

        Args:
            var_name(str): slice var name, such as "w.traier0.block1"
        Returns:
            VarDistributed: distributed var, or None when not found.
        """
        for dist_var in self.distributed_vars:
            if dist_var.slice.name == var_name:
                return dist_var
        return None

    @staticmethod
    def equal(var1, var2):
        """
        the two var is equal or not.
        Returns:
            bool: equal will return True else False
        """
        # Unlike VarDistributed.equal, this variant does not assert
        # VarStruct inputs; it duck-types on the compared attributes.
        return (
            var1.name == var2.name
            and var1.type == var2.type
            and var1.shape == var2.shape
            and var1.dtype == var2.dtype
            and var1.lod_level == var2.lod_level
            and var1.persistable == var2.persistable
        )

    def get_distributed_var_by_origin_and_ep(self, origin_var_name, endpoint):
        """
        get distributed var by conditions.

        Args:
            origin_var_name(str): name of the origin (un-sliced) var.
            endpoint(str): the parameter endpoint, such as "127.0.0.1:1001"
        Returns:
            VarDistributed: distributed var, or None when not found.
        """
        for dist_var in self.distributed_vars:
            if (
                dist_var.origin.name == origin_var_name
                and dist_var.endpoint == endpoint
            ):
                return dist_var
        return None

    def get_distributed_vars_by_vtypes(self, vtypes, groupby=False):
        """
        get distributed vars by conditions.

        Args:
            vtypes(list): distributed var's vtype tags, such as "Optimizer", "RemotePrefetch"
            groupby(bool|False): group by origin var or not.

        Returns:
            list: distributed var list.
            dict: distributed var map when groupby=True
        """
        vtype_vars = []
        for var in self.distributed_vars:
            if var.vtype in vtypes:
                vtype_vars.append(var)
        if not groupby:
            return vtype_vars

        # Group the matches by their origin var name.
        params_map = {}
        for var in vtype_vars:
            origin_var_name = var.origin.name

            if origin_var_name in params_map.keys():
                optimizers = params_map.get(origin_var_name)
            else:
                optimizers = []
            optimizers.append(var)
            params_map[origin_var_name] = optimizers
        return params_map

    def get_distributed_vars_by_ep(self, endpoint, vtype=None):
        """
        get distributed vars by conditions.

        Args:
            endpoint(str): the parameter server endpoint, such as "127.0.0.1:2001"
            vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch"

        Returns:
            list: distributed var list.
+ """ + endpoint_vars = [] + for var in self.distributed_vars: + if var.endpoint == endpoint: + endpoint_vars.append(var) + if not vtype: + return endpoint_vars + + vtype_vars = [] + for var in endpoint_vars: + if var.vtype == vtype: + vtype_vars.append(var) + return vtype_vars + + def overview(self): + """ + get the overview string about all params on all parameter servers. + + Returns: + Str: overview string. + + """ + vars_str = [] + for var in self.distributed_vars: + vars_str.append(str(var)) + return "\n".join(vars_str) diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/fl_distribute_transpiler.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/fl_distribute_transpiler.py new file mode 100755 index 000000000..420421bad --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/fl_distribute_transpiler.py @@ -0,0 +1,882 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

from __future__ import print_function

import collections
import sys

import six
from paddle.fluid import core, framework
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.framework import (
    Program,
    default_main_program,
    default_startup_program,
    Parameter,
)
from paddle.fluid.transpiler.distribute_transpiler import (
    DistributeTranspilerConfig,
    slice_variable,
)

from .details import UnionFind, VarsDistributed
from .details import delete_ops
from .details.ps_dispatcher import RoundRobin, PSDispatcher

# Op types used to recognize distributed lookup tables.
LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
# Framework attribute names/values used to tag the RPC ops inserted below.
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = (
    op_role_attr_name
) = core.op_proto_and_checker_maker.kOpRoleAttrName()
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
# Module-level switch flipped by FLDistributeTranspiler.__init__ when
# config.print_log is set.
PRINT_LOG = False


def log(*args):
    # Best-effort debug logging; prints the raw args tuple when enabled.
    if PRINT_LOG:
        print(args)


def same_or_split_var(p_name, var_name):
    # True when p_name is var_name itself or one of its split blocks
    # (which are named "<var>.block<N>").
    return p_name == var_name or p_name.startswith(var_name + ".block")


class FLDistributeTranspiler(object):
    """
    **FlDistributeTranspiler**

    Convert the fluid program to distributed data-parallelism programs.

    In pserver mode, the trainers' main program do forward, backward and optimizaiton.
    pserver's main_program will sum and scale.


    Examples:
        .. code-block:: python

          x = fluid.layers.data(name='x', shape=[13], dtype='float32')
          y = fluid.layers.data(name='y', shape=[1], dtype='float32')
          y_predict = fluid.layers.fc(input=x, size=1, act=None)

          cost = fluid.layers.square_error_cost(input=y_predict, label=y)
          avg_loss = fluid.layers.mean(cost)

          sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
          sgd_optimizer.minimize(avg_loss)

          # for pserver mode
          pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
          trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
          current_endpoint = "192.168.0.1:6174"
          trainer_id = 0
          trainers = 4
          role = "PSERVER"
          t = fluid.FlDistributeTranspiler()
          t.transpile(
              trainer_id, pservers=pserver_endpoints, trainers=trainers)
          if role == "PSERVER":
              pserver_program = t.get_pserver_program(current_endpoint)
              pserver_startup_program = t.get_startup_program(current_endpoint,
                                                              pserver_program)
          elif role == "TRAINER":
              trainer_program = t.get_trainer_program()

    """

    def __init__(self, config=None):
        # Fall back to the framework's default transpiler config when the
        # caller does not supply one.
        if config is not None:
            self.config = config
        else:
            self.config = DistributeTranspilerConfig()

        if self.config.split_method is None:
            self.config.split_method = RoundRobin

        global PRINT_LOG
        if self.config.print_log:
            PRINT_LOG = True
        # NOTE(review): assert-based validation disappears under `python -O`.
        assert self.config.min_block_size >= 8192
        assert self.config.split_method.__bases__[0] == PSDispatcher

    def _get_all_remote_sparse_update_op(self, main_program):
        # Collect ops performing remote-prefetch sparse updates
        # (lookup_table / nce / hierarchical_sigmoid with
        # remote_prefetch=True).
        sparse_update_ops = []
        sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
        for op in main_program.global_block().ops:
            if op.type in sparse_update_op_types and op.attr("remote_prefetch") is True:
                sparse_update_ops.append(op)
        return sparse_update_ops

    def transpile(
        self,
        trainer_id,
        program=None,
        pservers="127.0.0.1:6174",
        trainers=1,
        sync_mode=True,
        startup_program=None,
        current_endpoint="127.0.0.1:6174",
    ):
        """
        Run the transpiler. Transpile the input program.

        Args:
            trainer_id (int): id for current trainer worker, if you have
                n workers, the id may range from 0 ~ n-1
            program (Program|None): program to transpile,
                default is fluid.default_main_program().
            startup_program (Program|None): startup_program to transpile,
                default is fluid.default_startup_program().
            pservers (str): comma separated ip:port string for the pserver
                list.
            trainers (int|str): in pserver mode this is the number of
                trainers.
            sync_mode (bool): Do sync training or not, default is True.
            startup_program (Program|None): startup_program to transpile,
                default is fluid.default_main_program().
            current_endpoint (str): In pserver mode
                this argument is not used.

        Examples:
            .. code-block:: python

              transpiler = fluid.DistributeTranspiler()
              t.transpile(
                  trainer_id=0,
                  pservers="127.0.0.1:7000,127.0.0.1:7001",
                  trainers=2,
                  sync_mode=False,
                  current_endpoint="127.0.0.1:7000")
        """
        if program is None:
            program = default_main_program()
        if startup_program is None:
            startup_program = default_startup_program()
        self.origin_program = program
        self.startup_program = startup_program
        self.origin_startup_program = self.startup_program.clone()

        self.trainer_num = trainers
        self.sync_mode = sync_mode
        self.trainer_id = trainer_id
        pserver_endpoints = pservers.split(",")
        self.pserver_endpoints = pserver_endpoints
        self.vars_overview = VarsDistributed()
        self.optimize_ops, self.params_grads = self._get_optimize_pass()

        ps_dispatcher = self.config.split_method(self.pserver_endpoints)
        self.table_name = find_distributed_lookup_table(self.origin_program)
        # NOTE(review): `!= None` should idiomatically be `is not None`.
        self.has_distributed_lookup_table = self.table_name != None
        # Bidirectional name maps between params and their grads.
        self.param_name_to_grad_name = dict()
        self.grad_name_to_param_name = dict()
        for param_var, grad_var in self.params_grads:
            self.param_name_to_grad_name[param_var.name] = grad_var.name
            self.grad_name_to_param_name[grad_var.name] = param_var.name

        # get all sparse update ops
        self.sparse_update_ops = self._get_all_remote_sparse_update_op(
            self.origin_program
        )
        # use_sparse_update_param_name -> split_height_section
        self.sparse_param_to_height_sections = dict()

        # add distributed attrs to program
        self.origin_program._is_distributed = True
        self.origin_program._endpoints = self.pserver_endpoints
        self.origin_program._ps_endpoint = current_endpoint
        self.origin_program._is_chief = self.trainer_id == 0
        self.origin_program._distributed_lookup_table = (
            self.table_name if self.table_name else None
        )

        # split and create vars, then put splited vars in dicts for later use.
        # step 1: split and create vars, then put splited vars in dicts for later use.
        self._init_splited_vars()

        # step 2: insert send op to send gradient vars to parameter servers
        ps_dispatcher.reset()
        send_vars = []

        # in general cases, the number of pservers is times of 2, and this
        # will lead to uneven distribution among weights and bias:
        # fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
        # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
        # shuffle the map will avoid the uneven distribution above

        self.opti_name_to_send_dummy_out = dict()
        # recv_program starts as a clone of the main program with every op
        # removed; recv ops are appended to it further below.
        self.recv_program = self.origin_program.clone()
        all_ops = []
        for op in self.recv_program.global_block().ops:
            all_ops.append(op)
        delete_ops(self.recv_program.global_block(), all_ops)

        # Remember where the original compute ops end; everything appended
        # after this index is a send op (see get_trainer_program).
        self.split_num = len(program.global_block().ops)
        for opti_varname in self._opti_var_list:
            opti_var = program.global_block().var(opti_varname)
            eplist = ps_dispatcher.dispatch([opti_var])

            # Dummy control-dependency output so send ops can be ordered.
            dummy_output = program.global_block().create_var(
                name=framework.generate_control_dev_var_name()
            )
            self.opti_name_to_send_dummy_out[opti_varname] = dummy_output

            program.global_block().append_op(
                type="send",
                inputs={"X": [opti_var]},
                outputs={"Out": dummy_output},
                attrs={
                    "epmap": eplist,
                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                    OP_ROLE_VAR_ATTR_NAME: [
                        self._opti_to_param[opti_varname],
                        opti_varname,
                    ],
                    "sync_mode": not self.sync_mode,
                },
            )
            send_vars.append(opti_var)

        if self.sync_mode:
            # One barrier after all sends so pservers observe a full batch.
            send_barrier_out = program.global_block().create_var(
                name=framework.generate_control_dev_var_name()
            )
            input_deps = list(self.opti_name_to_send_dummy_out.values())

            program.global_block().append_op(
                type="send_barrier",
                inputs={"X": list(input_deps)},
                outputs={"Out": send_barrier_out},
                attrs={
                    "endpoints": pserver_endpoints,
                    "sync_mode": self.sync_mode,
                    "trainer_id": self.trainer_id,
                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                },
            )

        # step 3: insert recv op to receive parameters from parameter server
        recv_vars = []
        for _, var in enumerate(send_vars):
            recv_vars.append(program.global_block().var(self._opti_to_param[var.name]))
        ps_dispatcher.reset()
        eplist = ps_dispatcher.dispatch(recv_vars)
        for i, ep in enumerate(eplist):
            self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
            self.param_grad_ep_mapping[ep]["opti"].append(send_vars[i])

            distributed_var = self.vars_overview.get_distributed_var_by_slice(
                recv_vars[i].name
            )
            distributed_var.endpoint = ep

        # step4: Concat the parameters splits together after recv.
        all_recv_outputs = []
        for opti_varname in self._opti_var_list:
            opti_var = program.global_block().var(opti_varname)
            param_varname = self._opti_to_param[opti_varname]
            param_var = program.global_block().var(param_varname)
            eps = []
            table_names = []
            index = [v.name for v in recv_vars].index(param_varname)
            eps.append(eplist[index])
            # NOTE(review): `var` here leaks from the step-3 loop above, so
            # every entry records the same (last) name -- confirm
            # table_names is actually unused downstream.
            table_names.append(var.name)
            if self.sync_mode:
                recv_dep_in = send_barrier_out
            # get recv op_role_var, if not splited, the grad should have .trainer suffix
            # if splited, grad should be the original grad var name. ParallelExecutor
            # will use op_role_var to get expected device place to run this op.

            all_recv_outputs.extend([param_var])
            self.recv_program.global_block().append_op(
                type="recv",
                inputs={"X": []},
                outputs={"Out": [param_var]},
                attrs={
                    "epmap": eps,
                    "trainer_id": self.trainer_id,
                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                    OP_ROLE_VAR_ATTR_NAME: [param_varname, opti_varname],
                    "sync_mode": not self.sync_mode,
                },
            )

        if self.sync_mode:
            # form a WAW dependency
            self.recv_program.global_block()._insert_op(
                index=len(self._opti_var_list),
                type="fetch_barrier",
                inputs={},
                outputs={"Out": all_recv_outputs},
                attrs={
                    "endpoints": pserver_endpoints,
                    "trainer_id": self.trainer_id,
                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                },
            )

        self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)

        self._get_distributed_optimizer_vars()
        self.origin_program._parameters_on_pservers = self.vars_overview

    def get_trainer_program(self, wait_port=True):
        """
        Get transpiled trainer side program.

        Returns:
            Program: trainer side program.

        NOTE(review): despite the docstring, this actually returns the
        3-tuple (recv_program, origin_program, send_program) -- confirm
        callers unpack all three.
        """
        # remove optimize ops and add a send op to main_program
        # FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay?

        lr_ops = self._get_lr_ops()

        self.origin_program.__str__()

        # send_program keeps only the appended send ops; origin_program
        # keeps only the original compute ops (the first split_num ops).
        self.send_program = self.origin_program.clone()
        compute_ops = self.send_program.global_block().ops[0 : self.split_num]
        delete_ops(self.send_program.global_block(), compute_ops)
        send_ops = self.origin_program.global_block().ops[self.split_num :]
        delete_ops(self.origin_program.global_block(), send_ops)

        return self.recv_program, self.origin_program, self.send_program

    def _get_trainer_startup_program(self, recv_vars, eplist):
        """
        Get transpiled trainer side startup program.

        Args:
            recv_vars (list): Variable list to recv for current trainer_id
            eplist (list): A list of strings indicating

        Returns:
            Program: trainer side startup program.
        """
        startup_program = self.startup_program

        # FIXME(gongwb): delete not need ops.
        # note that: some parameter is not trainable and those ops can't be deleted.
        # NOTE(review): the loop below computes eps/table_names/index but
        # never uses them; the method currently returns startup_program
        # unchanged.
        for opti_varname in self._opti_var_list:
            opti_var = self.origin_program.global_block().var(opti_varname)
            param_varname = self._opti_to_param[opti_varname]
            var = self.origin_program.global_block().var(param_varname)

            # Get the eplist of recv vars
            eps = []
            table_names = []
            index = [v.name for v in recv_vars].index(param_varname)
            eps.append(eplist[index])

        return startup_program

    def get_pserver_program(self, endpoint):
        """
        Get parameter server side program.

        Args:
            endpoint (str): current parameter server endpoint.

        Returns:
            Program: the program for current parameter server to run.
        """
        # TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers.
        # NOTE: assume blocks of the same variable is not distributed
        # on the same pserver, only change param/grad varnames for
        # trainers to fetch.
        sys.stderr.write(
            "get_pserver_program() is deprecated, call get_pserver_programs() to get pserver main and startup in a single call.\n"
        )
        # step1
        pserver_program = Program()
        pserver_program.random_seed = self.origin_program.random_seed
        pserver_program._copy_dist_param_info_from(self.origin_program)

        # step2: Create vars to receive vars at parameter servers.
        recv_inputs = []
        for v in self.param_grad_ep_mapping[endpoint]["params"]:
            self._clone_var(pserver_program.global_block(), v)
        for v in self.param_grad_ep_mapping[endpoint]["opti"]:
            # create vars for each trainer in global scope, so
            # we don't need to create them when grad arrives.
            # change client side var name to origin name by
            # removing ".trainer_%d" suffix
            suff_idx = v.name.find(".opti.trainer_")
            if suff_idx >= 0:
                orig_var_name = v.name[:suff_idx]
            # NOTE(review): if the suffix is absent, orig_var_name keeps the
            # value from a previous iteration (or is unbound on the first),
            # giving a wrong lookup or NameError below -- confirm every
            # "opti" var carries the ".opti.trainer_" suffix.
            # NOTE: single_trainer_var must be created for multi-trainer
            # case to merge grads from multiple trainers
            single_trainer_var = pserver_program.global_block().var(orig_var_name)

            if self.sync_mode and self.trainer_num > 1:
                # One receive slot per trainer so per-trainer updates can be
                # summed before applying.
                for trainer_id in range(self.trainer_num):
                    var = pserver_program.global_block().create_var(
                        name="%s.opti.trainer_%d" % (orig_var_name, trainer_id),
                        persistable=False,
                        type=v.type,
                        dtype=v.dtype,
                        shape=v.shape,
                    )
                    recv_inputs.append(var)

        # step 3
        # Create a union-find data structure from optimize ops,
        # If two ops are connected, we could add these two ops
        # into one set.
        ufind = self._create_ufind(self.optimize_ops)
        # step 3.2
        # Iterate through the ops and append optimize op which
        # located on current pserver
        opt_op_on_pserver = []
        for _, op in enumerate(self.optimize_ops):
            if self._is_optimizer_op(op) and self._is_opt_op_on_pserver(endpoint, op):
                opt_op_on_pserver.append(op)

        # step 3.4
        # Iterate through the ops, and if an op and the optimize ops
        # which located on current pserver are in one set, then
        # append it into the sub program.

        global_ops = []

        # sparse grad name to param name
        sparse_grad_to_param = []

        # append lr decay ops to the child block if exists
        lr_ops = self._get_lr_ops()
        # record optimize blocks and we can run them on pserver parallel
        opti_blocks = []

        # append op to the current block
        grad_to_block_id = []
        pre_block_idx = pserver_program.num_blocks - 1
        # One sub-block per optimize var: sum the per-trainer copies and
        # scale by 1/trainer_num (plain gradient averaging).
        for idx, opt_op in enumerate(self._opti_var_list):
            per_opt_block = pserver_program._create_block(pre_block_idx)
            opti_blocks.append(per_opt_block)
            optimize_target_param_name = self._opti_to_param[opt_op]
            pserver_block = per_opt_block.program.global_block()
            # append grad merging ops before clip and weight decay
            # e.g. merge grad -> L2Decay op -> clip op -> optimize
            merged_var = pserver_block.vars[optimize_target_param_name]
            if self.sync_mode and self.trainer_num > 1:
                vars2merge = []
                for i in range(self.trainer_num):
                    per_trainer_name = "%s.opti.trainer_%d" % (
                        optimize_target_param_name,
                        i,
                    )
                    vars2merge.append(pserver_block.vars[per_trainer_name])
                per_opt_block.append_op(
                    type="sum",
                    inputs={"X": vars2merge},
                    outputs={"Out": merged_var},
                    attrs={"use_mkldnn": False},
                )
                per_opt_block.append_op(
                    type="scale",
                    inputs={"X": merged_var},
                    outputs={"Out": merged_var},
                    attrs={"scale": 1.0 / float(self.trainer_num)},
                )

        # In some case, some parameter server will have no parameter to optimize
        # So we give an empty optimize block to parameter server.
        attrs = {
            "optimize_blocks": opti_blocks,
            "endpoint": endpoint,
            "Fanin": self.trainer_num,
            "sync_mode": self.sync_mode,
        }

        # step5 append the listen_and_serv op
        pserver_program.global_block().append_op(
            type="fl_listen_and_serv",
            inputs={"X": recv_inputs},
            outputs={},
            attrs=attrs,
        )

        pserver_program._sync_with_cpp()
        # save pserver program to generate pserver side startup relatively.
        self.pserver_program = pserver_program
        return pserver_program

    def get_startup_program(self, endpoint, pserver_program=None, startup_program=None):
        """
        **Deprecated**

        Get startup program for current parameter server.
        Modify operator input variables if there are variables that
        were split to several blocks.

        Args:
            endpoint (str): current pserver endpoint.
            pserver_program (Program): deprecated, call get_pserver_program first.
            startup_program (Program): deprecated, should pass startup_program
                when initalizing

        Returns:
            Program: parameter server side startup program.
        """
        s_prog = Program()
        orig_s_prog = self.startup_program
        s_prog.random_seed = orig_s_prog.random_seed
        params = self.param_grad_ep_mapping[endpoint]["params"]

        def _get_splited_name_and_shape(varname):
            # Return (name, shape) of the split block of `varname` assigned
            # to this endpoint, or ("", []) when there is none.
            for idx, splited_param in enumerate(params):
                pname = splited_param.name
                if same_or_split_var(pname, varname) and varname != pname:
                    return pname, splited_param.shape
            return "", []

        # 1. create vars in pserver program to startup program
        pserver_vars = pserver_program.global_block().vars
        created_var_map = collections.OrderedDict()
        for _, var in six.iteritems(pserver_vars):
            tmpvar = s_prog.global_block()._clone_variable(var)
            created_var_map[var.name] = tmpvar

        # 2. rename op outputs
        for op in orig_s_prog.global_block().ops:
            new_outputs = collections.OrderedDict()
            # do not append startup op if var is not on this pserver
            op_on_pserver = False
            # TODO(gongwb): remove this line.
+ if op.type not in ["recv", "fetch_barrier", "concat"]: + for key in op.output_names: + newname, _ = _get_splited_name_and_shape(op.output(key)[0]) + if newname: + op_on_pserver = True + new_outputs[key] = created_var_map[newname] + elif op.output(key)[0] in pserver_vars: + op_on_pserver = True + new_outputs[key] = pserver_vars[op.output(key)[0]] + + if op_on_pserver: + # most startup program ops have no inputs + new_inputs = self._get_input_map_from_op(pserver_vars, op) + + if op.type in [ + "gaussian_random", + "fill_constant", + "uniform_random", + "truncated_gaussian_random", + ]: + op._set_attr("shape", list(new_outputs["Out"].shape)) + s_prog.global_block().append_op( + type=op.type, + inputs=new_inputs, + outputs=new_outputs, + attrs=op.all_attrs(), + ) + + return s_prog + + # ====================== private transpiler functions ===================== + def _get_slice_var_info(self, slice_var): + block_suffix = "block" + block_idx = 0 + offset = 0 + is_slice = False + + orig_var_name, block_name, _ = self._get_varname_parts(slice_var.name) + + if not block_name: + return is_slice, block_idx, offset + + def _get_distributed_optimizer_vars(self): + def _get_distributed_optimizer_var(endpoint): + opt_op_on_pserver = [] + for _, op in enumerate(self.optimize_ops): + if self._is_optimizer_op(op) and self._is_opt_op_on_pserver( + endpoint, op + ): + opt_op_on_pserver.append(op) + + for opt_op in opt_op_on_pserver: + dist_var = None + for key in opt_op.input_names: + if key == "Param": + param_name = opt_op.input(key)[0] + dist_var = ( + self.vars_overview.get_distributed_var_by_origin_and_ep( + param_name, endpoint + ) + ) + break + for key in opt_op.input_names: + if key in ["Param", "Grad", "LearningRate"]: + continue + + for ep in self.pserver_endpoints: + _get_distributed_optimizer_var(ep) + + def _update_dist_lookup_table_vars(self, param_list, grad_list, params_grads): + # TODO(wuyi): put find a way to put dist lookup table stuff all together. 
        # update self.table_param_grad and self.trainer_side_table_grad_list
        # NOTE(review): currently a no-op passthrough -- `program` is unused
        # and the input lists are returned unchanged.
        program = self.origin_program
        return param_list, grad_list

    def _init_splited_vars(self):
        # update these mappings for further transpile:
        # 1. param_var_mapping: param var name -> [splited params vars]
        # 2. grad_var_mapping: grad var name -> [splited grads vars]
        # 3. grad_param_mapping: grad.blockx -> param.blockx
        # 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}

        param_list = []
        grad_list = []
        param_grad_set = set()
        for p, g in self.params_grads:
            # skip parameter marked not trainable
            if type(p) == Parameter and p.trainable == False:
                continue
            if p.name not in param_grad_set:
                param_list.append(p)
                param_grad_set.add(p.name)
            if g.name not in param_grad_set:
                grad_list.append(g)
                param_grad_set.add(g.name)

        # To do : consider lookup table later
        param_list, grad_list = self._update_dist_lookup_table_vars(
            param_list, grad_list, self.params_grads
        )

        if self.config.slice_var_up:
            # when we slice var up into blocks, we will slice the var according to
            # pserver services' count. A pserver may have two or more listening ports.
            grad_blocks = slice_variable(
                grad_list, len(self.pserver_endpoints), self.config.min_block_size
            )
            param_blocks = slice_variable(
                param_list, len(self.pserver_endpoints), self.config.min_block_size
            )
        # NOTE(review): there is no else-branch; if slice_var_up is False,
        # grad_blocks/param_blocks are unbound below -- confirm the config
        # always enables slice_var_up for this transpiler.
        assert len(grad_blocks) == len(param_blocks)

        # origin_param_name -> [splited_param_vars]
        self.param_var_mapping = self._create_vars_from_blocklist(
            self.origin_program, param_blocks
        )

        for orig_name, splited_vars in self.param_var_mapping.items():
            orig_var = self.origin_program.global_block().var(orig_name)
            for splited_var in splited_vars:
                is_slice, block_id, offset = self._get_slice_var_info(splited_var)

                self.vars_overview.add_distributed_var(
                    origin_var=orig_var,
                    slice_var=splited_var,
                    block_id=block_id,
                    offset=offset,
                    is_slice=is_slice,
                    vtype="Param",
                )

        # origin_grad_name -> [splited_grad_vars]
        self.grad_var_mapping = self._create_vars_from_blocklist(
            self.origin_program, grad_blocks
        )
        # add_trainer_suffix=self.trainer_num > 1)
        # dict(grad_splited_var -> param_splited_var)
        self.grad_param_mapping = collections.OrderedDict()
        for g, p in zip(grad_blocks, param_blocks):
            g_name, g_bid, _ = g.split(":")
            p_name, p_bid, _ = p.split(":")
            self.grad_param_mapping[
                self.grad_var_mapping[g_name][int(g_bid)]
            ] = self.param_var_mapping[p_name][int(p_bid)]

        # create mapping of endpoint -> split var to create pserver side program
        self.param_grad_ep_mapping = collections.OrderedDict()
        [
            self.param_grad_ep_mapping.update({ep: {"params": [], "opti": []}})
            for ep in self.pserver_endpoints
        ]

        # For every optimizer output param, create a shadow
        # "<param>.opti.trainer_<id>" var that will be sent to pservers.
        opti_list = []
        opti_to_param = dict()
        param_to_opti = dict()
        for op in self.optimize_ops:
            if (op.type == "sgd") or (op.type == "adam") or (op.type == "momentum"):
                origin_name = op.output("ParamOut")
                var = self.origin_program.global_block().var(origin_name[0])
                new_var_name = "%s.opti.trainer_%d" % (origin_name[0], self.trainer_id)
                self.origin_program.global_block().create_var(
name=new_var_name, + persistable=True, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + ) + new_var = self.origin_program.global_block().var(new_var_name) + opti_list.append(new_var.name) + opti_to_param[new_var.name] = var.name + param_to_opti[var.name] = new_var.name + self.origin_program.global_block().append_op( + type="scale", + inputs={"X": var}, + outputs={"Out": new_var}, + attrs={"scale": 1.0}, + ) + self._param_to_opti = param_to_opti + self._opti_to_param = opti_to_param + self._opti_var_list = opti_list + + def _create_vars_from_blocklist( + self, program, block_list, add_trainer_suffix=False + ): + """ + Create vars for each split. + NOTE: only grads need to be named for different trainers, use + add_trainer_suffix to rename the grad vars. + Args: + program (ProgramDesc): ProgramDesc which gradients blong. + block_list (list[(varname, block_id, block_size)]): List of gradient blocks. + add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True. + Returns: + var_mapping (collections.OrderedDict(varname->[new_varname_variable])):A dict mapping + from original var name to each var split. 
+ """ + + # varname->[(block_id, current_block_size)] + block_map = collections.OrderedDict() + + var_mapping = collections.OrderedDict() + for block_str in block_list: + varname, offset, size = block_str.split(":") + if varname not in block_map: + block_map[varname] = [] + block_map[varname].append((int(offset), int(size))) + + for varname, splited in six.iteritems(block_map): + orig_var = program.global_block().var(varname) + if len(splited) == 1: + var_mapping[varname] = [program.global_block().var(orig_var.name)] + continue + return var_mapping + + def _clone_var(self, block, var, persistable=True): + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=persistable, + ) + + def _get_varname_parts(self, varname): + # returns origin, blockid, trainerid + orig_var_name = "" + trainer_part = "" + block_part = "" + trainer_idx = varname.find(".trainer_") + if trainer_idx >= 0: + trainer_part = varname[trainer_idx + 1 :] + else: + trainer_idx = len(varname) + block_index = varname.find(".block") + if block_index >= 0: + block_part = varname[block_index + 1 : trainer_idx] + else: + block_index = len(varname) + orig_var_name = varname[0 : min(block_index, trainer_idx)] + return orig_var_name, block_part, trainer_part + + def _is_op_connected(self, op1, op2): + # If one op's input is another op's output or + # one op's output is another op's input, we say + # the two operator is connected. 
+ if set(op1.desc.output_arg_names()) & set(op2.desc.input_arg_names()) or set( + op1.desc.input_arg_names() + ) & set(op2.desc.output_arg_names()): + return True + return False + + def _create_ufind(self, optimize_ops): + # Create a unit find data struct by optimize ops + ufind = UnionFind(optimize_ops) + for i in range(len(optimize_ops)): + for j in range(i, len(optimize_ops)): + op1 = optimize_ops[i] + op2 = optimize_ops[j] + if self._is_op_connected(op1, op2): + ufind.union(op1, op2) + return ufind + + def _is_optimizer_op(self, op): + if "Param" in op.input_names and "LearningRate" in op.input_names: + return True + return False + + def _is_opt_op_on_pserver(self, endpoint, op): + param_names = [p.name for p in self.param_grad_ep_mapping[endpoint]["params"]] + if op.input("Param")[0] in param_names: + return True + + def _get_input_map_from_op(self, varmap, op): + """Returns a dict from op input name to the vars in varmap.""" + iomap = collections.OrderedDict() + return iomap + + def _get_lr_ops(self): + lr_ops = [] + block = self.origin_program.global_block() + for op in block.ops: + role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME)) + return lr_ops + + def _is_opt_role_op(self, op): + # NOTE: depend on oprole to find out whether this op is for + # optimize + op_maker = core.op_proto_and_checker_maker + optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize + if op_maker.kOpRoleAttrName() in op.attr_names and int( + op.all_attrs()[op_maker.kOpRoleAttrName()] + ) == int(optimize_role): + return True + return False + + def _get_optimize_pass(self): + """ + Get optimizer operators, parameters and gradients from origin_program + Returns: + opt_ops (list): optimize operators. + params_grads (dict): parameter->gradient. 
+ """ + block = self.origin_program.global_block() + opt_ops = [] + params_grads = [] + # tmp set to dedup + optimize_params = set() + origin_var_dict = self.origin_program.global_block().vars + for op in block.ops: + if self._is_opt_role_op(op): + opt_ops.append(op) + if op.attr(OP_ROLE_VAR_ATTR_NAME): + param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0] + grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1] + if not param_name in optimize_params: + optimize_params.add(param_name) + log("adding param_grad pair: ", param_name, grad_name) + params_grads.append( + [origin_var_dict[param_name], origin_var_dict[grad_name]] + ) + return opt_ops, params_grads diff --git a/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/fl_strategy_base.py b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/fl_strategy_base.py new file mode 100755 index 000000000..4c1a84e12 --- /dev/null +++ b/VisualFL/depends/PaddleFL/python/paddle_fl/paddle_fl/core/strategy/fl_strategy_base.py @@ -0,0 +1,130 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .fl_distribute_transpiler import FLDistributeTranspiler + + +class FLStrategyBase(object): + """ + FLStrategyBase is federated learning algorithm container + """ + + def __init__(self): + self._fed_avg = False + self._dpsgd = False + self._inner_step = 1 + pass + + def minimize(self, optimizer=None, losses=[]): + """ + minmize can do minimization as paddle.fluid.Optimizer.minimize does + this function can be overloaded so that for some FLStrategy, the + program should be transpiled before minimize + + Args: + optimizer(paddle.fluid.optimizer): the user defined optimizer + losses(List(Variable)): list of loss variables in paddle.fluid + """ + for loss in losses: + optimizer.minimize(loss) + + def _build_trainer_program_for_job( + self, + trainer_id=0, + program=None, + ps_endpoints=[], + trainers=0, + sync_mode=True, + startup_program=None, + job=None, + ): + pass + + def _build_server_programs_for_job( + self, + program=None, + ps_endpoints=[], + trainers=0, + sync_mode=True, + startup_program=None, + job=None, + ): + pass + + +class FedAvgStrategy(FLStrategyBase): + """ + FedAvgStrategy: this is model averaging optimization proposed in + H. Brendan McMahan, Eider Moore, Daniel Ramage, Blaise Aguera y Arcas. Federated Learning of Deep Networks using Model Averaging. 
2017 + """ + + def __init__(self): + super(FedAvgStrategy, self).__init__() + + def minimize(self, optimizer=None, losses=[]): + """ + minimize the first loss as in paddle.fluid + """ + optimizer.minimize(losses[0]) + + def _build_trainer_program_for_job( + self, + trainer_id=0, + program=None, + ps_endpoints=[], + trainers=0, + sync_mode=True, + startup_program=None, + job=None, + ): + transpiler = FLDistributeTranspiler() + transpiler.transpile( + trainer_id, + program=program, + pservers=",".join(ps_endpoints), + trainers=trainers, + sync_mode=sync_mode, + startup_program=startup_program, + ) + recv, main, send = transpiler.get_trainer_program() + job._trainer_startup_programs.append(startup_program) + job._trainer_main_programs.append(main) + job._trainer_send_programs.append(send) + job._trainer_recv_programs.append(recv) + + def _build_server_programs_for_job( + self, + program=None, + ps_endpoints=[], + trainers=0, + sync_mode=True, + startup_program=None, + job=None, + ): + transpiler = FLDistributeTranspiler() + trainer_id = 0 + transpiler.transpile( + trainer_id, + program=program, + pservers=",".join(ps_endpoints), + trainers=trainers, + sync_mode=sync_mode, + startup_program=startup_program, + ) + job.set_server_endpoints(ps_endpoints) + for endpoint in ps_endpoints: + main_prog = transpiler.get_pserver_program(endpoint) + startup_prog = transpiler.get_startup_program(endpoint, main_prog) + job._server_startup_programs.append(startup_prog) + job._server_main_programs.append(main_prog) diff --git a/VisualFL/depends/README.md b/VisualFL/depends/README.md new file mode 100755 index 000000000..03f91492f --- /dev/null +++ b/VisualFL/depends/README.md @@ -0,0 +1,8 @@ +## + +This project is heavily depends on the Project [PaddleFL](https://github.com/PaddlePaddle/PaddleFL). + +While we meet some problem here. There are many packages we don't need right now and trying to install their requirements could be really painful. 
+So, we have to simply copy and modify [PaddleFL's](https://github.com/PaddlePaddle/PaddleFL/tree/f1a6f8951ad78feb594064d165db951df1a0e0bd) codes here. + +We would report issues directly to the upstream and, these files would be delete once issues resolved. diff --git a/VisualFL/deploy_tools/MANIFEST.in b/VisualFL/deploy_tools/MANIFEST.in new file mode 100755 index 000000000..6f454c4c5 --- /dev/null +++ b/VisualFL/deploy_tools/MANIFEST.in @@ -0,0 +1,3 @@ +include visualfl_deploy/data/visualfl.tar.gz +include visualfl_deploy/template/* + diff --git a/VisualFL/deploy_tools/README.md b/VisualFL/deploy_tools/README.md new file mode 100755 index 000000000..942e19daf --- /dev/null +++ b/VisualFL/deploy_tools/README.md @@ -0,0 +1,89 @@ + +## Wefe VisualFL Deploy Toolkit + +### Prerequisites + +Too run VisualFL, following dependency or tools required: + +- machine to install wefe_visualfl_deploy_tools: + + - python virtualenv with Python>=3 + + - setup SSH password less login to machine(s) for deploy visualfl framework. + +- machine(s) to deploy visualfl framework: + + - Python>=3.7(with pip) + + - an isolated directory (each directory will be deployed with a copy of code) + + +### Build package + +```bash +cd VisualFL/deploy_tools/visualfl_deploy +python visualfl_deploy/_build.py +cd .. +pyhon setup.py sdist +``` +### upload `VisualFL/deploy_tools/dist/visualfl_deploy-1.0.tar.gz` to server. + +### Deploy + +1. install visualfl deploy toolkit + + ``` bash + # ceate a python virtual envirement (recommanded) or use an exist one. + cd {base_dir} + python -m venv venv + source venv/bin/activate + python -m pip install -U pip && python -m pip install visualfl_deploy-1.0.tar.gz + ``` + +2. generate deploy template + + 1) standalone deployment + ```bash + wefe_visualfl_deploy template standalone + ``` + 2) cluster deployment + ```bash + wefe_visualfl_deploy template cluster + ``` +3. 
read comments in generated template `standalone_template.yaml` or `template.yaml` and modify as you want. + +4. run deploy cmd + + 1) standalone deployment + ```bash + wefe_visualfl_deploy deploy deploy --config standalone_template.yaml + ``` + 2) cluster deployment + ```bash + wefe_visualfl_deploy deploy deploy --config template.yaml + ``` + +### Services start and stop + +Services could be start/stop with scripts in `VisualFL/script` or, use visualfl deploy toolkits: + +1.standalone deployment +```bash +wefe_visualfl_deploy services all start standalone_template.yaml +``` +1.cluster deployment +```bash +wefe_visualfl_deploy services all start template.yaml +``` + + +### Run examples + +Jobs could be submitted at each deployed machine with master service started. + +```bash +cd {base_dir} +source venv/bin/activate +export PYTHONPATH=$PYTHONPATH:{base_dir}/VisualFL +sh VisualFL/examples/paddle_clas/run.sh 127.0.0.1:10002 +``` diff --git a/VisualFL/deploy_tools/__init__.py b/VisualFL/deploy_tools/__init__.py new file mode 100644 index 000000000..3da16e031 --- /dev/null +++ b/VisualFL/deploy_tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ diff --git a/VisualFL/deploy_tools/pyproject.toml b/VisualFL/deploy_tools/pyproject.toml new file mode 100755 index 000000000..07de284aa --- /dev/null +++ b/VisualFL/deploy_tools/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/VisualFL/deploy_tools/setup.cfg b/VisualFL/deploy_tools/setup.cfg new file mode 100755 index 000000000..580897dc2 --- /dev/null +++ b/VisualFL/deploy_tools/setup.cfg @@ -0,0 +1,15 @@ +[metadata] +name = visualfl_deploy +version = 1.0 + +[options] +packages = find: +include_package_data = True +install_requires = + fabric + typer + pyyaml + +[options.entry_points] +console_scripts = + wefe_visualfl_deploy = visualfl_deploy.cli:app \ No newline at end of file diff --git a/VisualFL/deploy_tools/setup.py b/VisualFL/deploy_tools/setup.py new file mode 100644 index 000000000..56dc47ff5 --- /dev/null +++ b/VisualFL/deploy_tools/setup.py @@ -0,0 +1,17 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import setuptools + +setuptools.setup() diff --git a/VisualFL/deploy_tools/visualfl_deploy/__init__.py b/VisualFL/deploy_tools/visualfl_deploy/__init__.py new file mode 100644 index 000000000..a6a150554 --- /dev/null +++ b/VisualFL/deploy_tools/visualfl_deploy/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +from pathlib import Path + +__base_dir__ = Path(__file__).parent +__visualfl_tarball__ = __base_dir__.joinpath("data", "visualfl.tar.gz").absolute() +__template__ = __base_dir__.joinpath("template").absolute() +__BASE_NAME__ = "VisualFL" diff --git a/VisualFL/deploy_tools/visualfl_deploy/_build.py b/VisualFL/deploy_tools/visualfl_deploy/_build.py new file mode 100644 index 000000000..523a528e5 --- /dev/null +++ b/VisualFL/deploy_tools/visualfl_deploy/_build.py @@ -0,0 +1,47 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import tarfile + +from visualfl_deploy import __base_dir__, __visualfl_tarball__, __BASE_NAME__ + +_project_base = __base_dir__.parent.parent + + + +def _build(): + __visualfl_tarball__.parent.mkdir(exist_ok=True, parents=True) + with tarfile.open(__visualfl_tarball__, "w:gz") as tar: + for path in [ + _project_base.joinpath("visualfl"), + _project_base.joinpath("depends", "PaddleDetection", "ppdet"), + _project_base.joinpath("depends", "PaddleFL", "python", "paddle_fl","paddle_fl"), + _project_base.joinpath("data"), + _project_base.joinpath("examples"), + _project_base.joinpath("script"), + _project_base.joinpath("requirements.txt"), + _project_base.joinpath("config.properties"), + ]: + tar.add(path, f"{__BASE_NAME__}/{os.path.basename(path)}") + + +def _clean(): + if __visualfl_tarball__.exists(): + os.remove(__visualfl_tarball__) + + +if __name__ == "__main__": + _clean() + _build() diff --git a/VisualFL/deploy_tools/visualfl_deploy/_deploy.py b/VisualFL/deploy_tools/visualfl_deploy/_deploy.py new file mode 100644 index 000000000..e1dbc7c62 --- /dev/null +++ b/VisualFL/deploy_tools/visualfl_deploy/_deploy.py @@ -0,0 +1,101 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 The FedVision Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path + +import fabric +import typer +import yaml +from visualfl_deploy import __visualfl_tarball__, __BASE_NAME__ + +app = typer.Typer(help="deploy tools") + + +@app.command() +def deploy( + config: Path = typer.Option( + ..., + file_okay=True, + exists=True, + dir_okay=False, + readable=True, + resolve_path=True, + ) +): + with config.open() as f: + config_dict = yaml.safe_load(f) + machines = config_dict.get("machines", []) + typer.echo( + f"deploying {len(machines)} machines: {[machine['name'] for machine in machines]}" + ) + with typer.progressbar(machines, length=len(machines)) as bar: + for machine in bar: + _maybe_create_python_venv(machine) + _upload_code(machine) + _install_deps(machine) + typer.echo(f"deploy done") + + +def _upload_code(machine): + tarfile = os.path.abspath(__visualfl_tarball__) + base_dir = Path(machine["base_dir"]) + with fabric.Connection(machine["ssh_string"]) as c: + c.run(f"mkdir -p {base_dir}") + with c.cd(str(base_dir)): + c.put(tarfile, f"{base_dir}") + c.run(f"tar -xf {tarfile} -C {base_dir}") + c.run(f"rm {os.path.join(base_dir, os.path.basename(tarfile))}") + + +def _maybe_create_python_venv(machine: dict): + with fabric.Connection(machine["ssh_string"]) as c: + version = c.run( + f"{machine['python_for_venv_create']} " + f"-c 'import sys; assert sys.version_info.major >= 3 and sys.version_info.minor >= 7'", + warn=True, + ) + if version.failed: + raise RuntimeError(f"python executable {machine['python']} not valid") + + base_dir = Path(machine["base_dir"]) + c.run(f"mkdir -p 
{base_dir}") + with c.cd(str(base_dir)): + if c.run( + f"test -f {base_dir.joinpath('venv/bin/python')}", warn=True + ).failed: + c.run(f"{machine['python_for_venv_create']} -m venv venv") + + +def _install_deps(machine): + with fabric.Connection(machine["ssh_string"]) as c: + base_dir = Path(machine["base_dir"]) + with c.cd(str(base_dir)): + c.run(f"venv/bin/python -m pip install -U pip --quiet") + c.run( + f"venv/bin/python -m pip install -r {__BASE_NAME__}/requirements.txt --log deps_install.log --quiet" + ) diff --git a/VisualFL/deploy_tools/visualfl_deploy/_generate_template.py b/VisualFL/deploy_tools/visualfl_deploy/_generate_template.py new file mode 100644 index 000000000..3c11721d1 --- /dev/null +++ b/VisualFL/deploy_tools/visualfl_deploy/_generate_template.py @@ -0,0 +1,39 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import os +import shutil + +import typer +from visualfl_deploy import __template__ + +app = typer.Typer(help="template tools") + + +@app.command(name="cluster") +def generate(): + """ + generate template + """ + shutil.copy(os.path.join(__template__, "template.yaml"), os.getcwd()) + + +@app.command(name="standalone") +def standalone_template(): + """ + generate template for standalone deploy + """ + shutil.copy(os.path.join(__template__, "standalone_template.yaml"), os.getcwd()) + diff --git a/VisualFL/deploy_tools/visualfl_deploy/_service.py b/VisualFL/deploy_tools/visualfl_deploy/_service.py new file mode 100644 index 000000000..cd0f0e9c4 --- /dev/null +++ b/VisualFL/deploy_tools/visualfl_deploy/_service.py @@ -0,0 +1,306 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 The FedVision Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from pathlib import Path + +import fabric +import typer +import yaml +from visualfl_deploy import __BASE_NAME__ + +app = typer.Typer(help="services [start|stop] tools") +all_app = typer.Typer(help="[start|stop] all services") + +cluster_manager_app = typer.Typer(help="[start|stop] cluster manager service") +cluster_worker_app = typer.Typer(help="[start|stop] cluster worker service") +master_app = typer.Typer(help="[start|stop] master service") +app.add_typer(all_app, name="all") +app.add_typer(cluster_manager_app, name="cluster-manager") +app.add_typer(cluster_worker_app, name="cluster-worker") +app.add_typer(master_app, name="master") + + +@all_app.command(name="start", help="start all services") +def start_all( + config: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False) +): + with config.open() as f: + config_dict = yaml.safe_load(f) + + machines_map = {} + for machine_config in config_dict["machines"]: + machines_map[machine_config["name"]] = machine_config + + # start cluster + cluster_address_map = {} + for cluster_config in config_dict.get("clusters", []): + cluster_name = cluster_config["name"] + typer.echo(f"starting cluster {cluster_name}") + manager_config = cluster_config["manager"] + manager_machine = machines_map[manager_config["machine"]] + status = start_cluster_manager( + manager_machine["ssh_string"], + manager_machine["base_dir"], + manager_config["port"], + ) + typer.echo(f"start cluster {cluster_name} done, success: {status}\n") + + cluster_address = f"{manager_machine['ip']}:{manager_config['port']}" + cluster_address_map[cluster_name] = cluster_address + + # if status: + typer.echo(f"starting cluster workers for cluster {cluster_name}") + for worker_config in cluster_config.get("workers", []): + typer.echo(f"starting worker {worker_config['name']}") + worker_machine = machines_map[worker_config["machine"]] + if "data_base_dir" in worker_machine: + data_base_dir = worker_machine["data_base_dir"] + else: + 
data_base_dir = None + status = start_cluster_worker( + machine_ssh=worker_machine["ssh_string"], + machine_base_dir=worker_machine["base_dir"], + name=worker_config["name"], + local_ip=worker_machine["ip"], + port_start=int(worker_config["ports"].split("-")[0]), + port_end=int(worker_config["ports"].split("-")[1]), + max_tasks=worker_config["max_tasks"], + cluster_manager_address=cluster_address, + data_base_dir=data_base_dir, + ) + typer.echo( + f"start worker {worker_config['name']} done, success: {status}\n" + ) + + # start master + for master_config in config_dict.get("masters", []): + typer.echo(f"starting master {master_config['name']}") + master_machine = machines_map[master_config["machine"]] + status = start_master( + machine_ssh=master_machine["ssh_string"], + machine_base_dir=master_machine["base_dir"], + submit_port=master_config["submit_port"], + member_id=master_config["name"], + cluster_manager_address=cluster_address_map[master_config["cluster"]], + local=master_config.get("local",False) + ) + typer.echo(f"start master {master_config['name']} done, success: {status}\n") + + typer.echo() + + +@all_app.command(name="stop", help="stop all services") +def stop_all( + config: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False) +): + with config.open() as f: + config_dict = yaml.safe_load(f) + + machines_map = {} + for machine_config in config_dict["machines"]: + machines_map[machine_config["name"]] = machine_config + + # stop cluster + for cluster_config in config_dict.get("clusters", []): + cluster_name = cluster_config["name"] + typer.echo(f"stopping cluster {cluster_name}") + manager_config = cluster_config["manager"] + manager_machine = machines_map[manager_config["machine"]] + + for worker_config in cluster_config.get("workers", []): + typer.echo(f"stopping worker {worker_config['name']}") + worker_machine = machines_map[worker_config["machine"]] + stop_cluster_worker( + machine_ssh=worker_machine["ssh_string"], + 
machine_base_dir=worker_machine["base_dir"], + name=worker_config["name"], + ) + stop_cluster_manager( + machine_ssh=manager_machine["ssh_string"], + machine_base_dir=manager_machine["base_dir"], + manager_port=manager_config["port"], + ) + + # stop master + for master_config in config_dict.get("masters", []): + typer.echo(f"stopping master {master_config['name']}") + master_machine = machines_map[master_config["machine"]] + stop_master( + machine_ssh=master_machine["ssh_string"], + machine_base_dir=master_machine["base_dir"], + submit_port=master_config["submit_port"], + ) + + typer.echo("stop all") + + +@cluster_manager_app.command(name="start", help="start cluster manager") +def start_cluster_manager( + machine_ssh: str = typer.Argument(..., help="machine ssh string: user@host:port"), + machine_base_dir: str = typer.Argument(..., help="deployed base name"), + manager_port: int = typer.Argument( + ..., help="port number for cluster manager to serve" + ), +): + with fabric.Connection(machine_ssh) as c: + with c.cd(machine_base_dir): + if c.run( + f"PYTHON_EXECUTABLE={os.path.join(machine_base_dir, 'venv/bin/python')} " + f"{__BASE_NAME__}/script/cluster_manager.sh start {manager_port}", + warn=True, + ).failed: + typer.echo(f"failed: can't start cluster manager") + return False + else: + typer.echo( + f"{machine_ssh}:{machine_base_dir} started cluster manager: port={manager_port}" + ) + return True + + +@cluster_manager_app.command(name="stop", help="stop cluster manager") +def stop_cluster_manager( + machine_ssh: str = typer.Argument(..., help="machine ssh string: user@host:port"), + machine_base_dir: str = typer.Argument(..., help="deployed base name"), + manager_port: int = typer.Argument( + ..., help="port number for cluster manager to serve" + ), +): + with fabric.Connection(machine_ssh) as c: + with c.cd(machine_base_dir): + if c.run( + f"PYTHON_EXECUTABLE={os.path.join(machine_base_dir, 'venv/bin/python')} " + f"{__BASE_NAME__}/script/cluster_manager.sh 
stop {manager_port}", + warn=True, + ).failed: + typer.echo(f"failed: can't stop cluster manager") + else: + typer.echo( + f"success: {machine_ssh}:{machine_base_dir} stop cluster manager: port={manager_port}" + ) + + +@cluster_worker_app.command(name="start", help="start cluster worker") +def start_cluster_worker( + machine_ssh: str = typer.Argument(..., help="machine ssh string: user@host:port"), + machine_base_dir: str = typer.Argument(..., help="deployed base name"), + name: str = typer.Argument(..., help="worker name"), + local_ip: str = typer.Argument(..., help="local ip"), + port_start: int = typer.Argument(..., help="port start"), + port_end: int = typer.Argument(..., help="port start"), + max_tasks: int = typer.Argument(..., help="num of maximum parallel tasks"), + cluster_manager_address=typer.Argument(..., help="cluster manager address"), + data_base_dir: str = typer.Option(None, "--data-dir", help="data dir"), +): + if data_base_dir is None or isinstance(data_base_dir, typer.params.OptionInfo): + data_base_dir = os.path.join(machine_base_dir, __BASE_NAME__, "data") + with fabric.Connection(machine_ssh) as c: + with c.cd(machine_base_dir): + if c.run( + f"PYTHON_EXECUTABLE={os.path.join(machine_base_dir, 'venv/bin/python')} " + f"{__BASE_NAME__}/script/cluster_worker.sh start " + f"{name} {local_ip} {port_start} {port_end} {max_tasks} {cluster_manager_address} {data_base_dir}", + warn=True, + ).failed: + typer.echo(f"failed: can't start cluster worker named {name}") + return False + else: + typer.echo( + f"{machine_ssh}:{machine_base_dir} started cluster worker: name={name}" + ) + return True + + +@cluster_worker_app.command(name="stop", help="stop cluster worker") +def stop_cluster_worker( + machine_ssh: str = typer.Argument(..., help="machine ssh string: user@host:port"), + machine_base_dir: str = typer.Argument(..., help="deployed base name"), + name: str = typer.Argument(..., help="worker name"), +): + with fabric.Connection(machine_ssh) as c: + with 
c.cd(machine_base_dir): + if c.run( + f"PYTHON_EXECUTABLE={os.path.join(machine_base_dir, 'venv/bin/python')} " + f"{__BASE_NAME__}/script/cluster_worker.sh stop {name}", + warn=True, + ).failed: + typer.echo(f"failed: can't stop cluster worker") + else: + typer.echo( + f"success: {machine_ssh}:{machine_base_dir} stop cluster worker: name={name}" + ) + + +@master_app.command(name="start", help="start master") +def start_master( + machine_ssh: str = typer.Argument(..., help="machine ssh string: user@host:port"), + machine_base_dir: str = typer.Argument(..., help="deployed base name"), + submit_port: int = typer.Argument(..., help="submit port"), + member_id: str = typer.Argument(..., help="party id"), + cluster_manager_address: str = typer.Argument(..., help="cluster manager address"), + local: bool = typer.Argument(..., help="is local template"), +): + with fabric.Connection(machine_ssh) as c: + with c.cd(machine_base_dir): + if c.run( + f"PYTHON_EXECUTABLE={os.path.join(machine_base_dir, 'venv/bin/python')} " + f"{__BASE_NAME__}/script/master.sh start " + f"{submit_port} {member_id} {cluster_manager_address} {local}", + warn=True, + ).failed: + typer.echo(f"failed: can't start master at port {submit_port}") + return False + else: + typer.echo( + f"{machine_ssh}:{machine_base_dir} started master: port={submit_port}" + ) + return True + + +@master_app.command(name="stop", help="stop master") +def stop_master( + machine_ssh: str = typer.Argument(..., help="machine ssh string: user@host:port"), + machine_base_dir: str = typer.Argument(..., help="deployed base name"), + submit_port: int = typer.Argument(..., help="submit port"), +): + with fabric.Connection(machine_ssh) as c: + with c.cd(machine_base_dir): + if c.run( + f"PYTHON_EXECUTABLE={os.path.join(machine_base_dir, 'venv/bin/python')} " + f"{__BASE_NAME__}/script/master.sh stop {submit_port}", + warn=True, + ).failed: + typer.echo(f"failed: can't stop master") + else: + typer.echo( + f"success: 
{machine_ssh}:{machine_base_dir} stop master: port={submit_port}" + ) + + +if __name__ == '__main__': + start_all() \ No newline at end of file diff --git a/VisualFL/deploy_tools/visualfl_deploy/cli.py b/VisualFL/deploy_tools/visualfl_deploy/cli.py new file mode 100644 index 000000000..beae99ee2 --- /dev/null +++ b/VisualFL/deploy_tools/visualfl_deploy/cli.py @@ -0,0 +1,26 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import typer +from visualfl_deploy import _deploy, _generate_template, _service + +app = typer.Typer() + +app.add_typer(_deploy.app, name="deploy") +app.add_typer(_service.app, name="services") +app.add_typer(_generate_template.app, name="template") + +if __name__ == "__main__": + app() diff --git a/VisualFL/deploy_tools/visualfl_deploy/template/__init__.py b/VisualFL/deploy_tools/visualfl_deploy/template/__init__.py new file mode 100644 index 000000000..ed5e36eba --- /dev/null +++ b/VisualFL/deploy_tools/visualfl_deploy/template/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/VisualFL/deploy_tools/visualfl_deploy/template/standalone_template.yaml b/VisualFL/deploy_tools/visualfl_deploy/template/standalone_template.yaml new file mode 100755 index 000000000..b76ac9aaa --- /dev/null +++ b/VisualFL/deploy_tools/visualfl_deploy/template/standalone_template.yaml @@ -0,0 +1,70 @@ +# standalone deploy config demo +# deploy machine: single +# coordinator: name=coordinator1, port=10000 +# clusters: single, name=cluster1, port=10001, single worker +# masters: 4, share cluster `cluster1` +# deploy structure: +# +#+----------------------------------------------------------------------------------------+ +#| | +#| +-----+------+ +-----+------+ +-----+------+ +-----+------+ | +#| | master1 | | master2 | | master3 | | master4 | | +#| | | | | | | | | | +#| +-----+------+ +-----+------+ +-----+------+ +-----+------+ | +#| | | | | | +#| v v v v | +#| +-----+-----------------+------------------+-----------------+-------+ | +#| | | | +#| | +-----------+ +-----------+ +-----------+ | | +#| | |worker1 | |worker2 | |worker3 | | | | | +#| | | | | | | | | | | | +#| | +-----------+ +-----------+ +-----------+ cluster1 | | +#| | | | +#| +--------------------------------------------------------------------+ | +#| | +#| machine1 | +#| | +#+----------------------------------------------------------------------------------------+ + + +machines: + - name: machine1 + ip: 127.0.0.1 + ssh_string: 127.0.0.1:22 + base_dir: /data + python_for_venv_create: python3 # use to create venv, python3.7+ required + + +clusters: + - name: cluster1 + manager: + 
machine: machine1 + port: 10001 + workers: + - name: worker1 + machine: machine1 + ports: 12000-12099 + max_tasks: 10 + - name: worker2 + machine: machine1 + ports: 12000-12099 + max_tasks: 10 + - name: worker3 + machine: machine1 + ports: 12000-12099 + max_tasks: 10 + - name: worker4 + machine: machine1 + ports: 12000-12099 + max_tasks: 10 + + +masters: + - name: master1 + machine: machine1 + submit_port: 10002 + cluster: cluster1 + local: false + + + diff --git a/VisualFL/deploy_tools/visualfl_deploy/template/template.yaml b/VisualFL/deploy_tools/visualfl_deploy/template/template.yaml new file mode 100644 index 000000000..8c9804e45 --- /dev/null +++ b/VisualFL/deploy_tools/visualfl_deploy/template/template.yaml @@ -0,0 +1,82 @@ +# cluster deploy config demo +# deploy machine: machine1, machine2, machine3 +# coordinator: name=coordinator1, port=10000, at machine3 +# clusters: +# - cluster1, name=cluster1, port=10001, workers: worker1, worker2 at machine1 +# - cluster2, name=cluster2, port=10001, workers: worker1, worker2 at machine2 +# masters: 2 +# - master1, use cluster `cluster1`, connect to coordinator `coordinator1` at machine1 +# - master2, use cluster `cluster2`, connect to coordinator `coordinator1` at machine1 + # deploy structure: + +#+---------------------------------------------+ | +----------------------------------------------+ +#| | | | | +#| +-------------+-----------+------+-------------+ | +#| | | | | | | | +#| | master1 | | | | master2 | | +#| +-------------+ | | +-------------+ | +#| +---------------------------------------+ | | +----------------------------------------+ | +#| | | | | | | | +#| | +-------------+ +-------------+ | | | | +-------------+ +-------------+ | | +#| | | worker1 | | worker2 | | | | | | worker1 | | worker2 | | | +#| | +-------------+ +-------------+ | | | | +-------------+ +-------------+ | | +#| | cluster1| | | | cluster2 | | +#| +---------------------------------------+ | | +----------------------------------------+ | 
+#| | | | +#| machine1 | | machine2 | +#+---------------------------------------------+ +----------------------------------------------+ + +machines: + - name: machine1 + ip: xxx + ssh_string: xxx@xxxx # [user@]ip:port + base_dir: /data/visualfl + python_for_venv_create: python3 # use to create venv + + - name: machine2 + ip: xxx + ssh_string: xxx@xxxx + base_dir: /data/visualfl + python_for_venv_create: python3 + + +clusters: + - name: cluster1 + manager: + machine: machine1 + port: 10001 + workers: + - name: worker1 + machine: machine1 + ports: 12000-12999 + max_tasks: 10 + - name: worker2 + machine: machine1 + ports: 13000-13999 + max_tasks: 10 + + - name: cluster2 + manager: + machine: machine2 + port: 10001 + workers: + - name: worker1 + machine: machine2 + ports: 12000-12999 + max_tasks: 10 + - name: worker2 + machine: machine2 + ports: 13000-13999 + max_tasks: 10 + +masters: + - name: master1 + machine: machine1 + submit_port: 10002 + cluster: cluster1 + + - name: master2 + machine: machine2 + submit_port: 10002 + cluster: cluster2 + diff --git a/VisualFL/examples/paddle_clas/apply.sh b/VisualFL/examples/paddle_clas/apply.sh new file mode 100755 index 000000000..5cc1c9d60 --- /dev/null +++ b/VisualFL/examples/paddle_clas/apply.sh @@ -0,0 +1,18 @@ +#!/bin/bash + + + +# export Environment Variables +# PYTHONPATH Python default search path for module files, PaddleFL, PaddleDetection +# PYTHON_EXECUTABLE python executable path, such as python | python3 | venv/bin/python + +DIR="$(cd "$(dirname "$0")" >/dev/null 2>&1 && pwd)" + +usage="Usage: run.sh " + +if [ $# -le 0 ]; then + echo "$usage" + exit 1 +fi + +python -m visualfl.client.apply apply --config "${DIR}/config.yaml" --endpoint "$1" diff --git a/VisualFL/examples/paddle_clas/config.yaml b/VisualFL/examples/paddle_clas/config.yaml new file mode 100755 index 000000000..4a439b720 --- /dev/null +++ b/VisualFL/examples/paddle_clas/config.yaml @@ -0,0 +1,28 @@ +job_id: 06ac2812dc004aa38ee1de9588dfdac8 +task_id: 
06ac2812dc004aa38ee1de9588dfdac8 +job_type: paddle_fl +role: promoter +member_id: master1 +callback_url: https://www.xxx.com +env: + worker_num: 2 + local_worker_num: 2 + local_trainer_indexs: [0,1] + device: cpu + use_vdl: true + server_endpoint: 127.0.0.1:12000 + aggregator_endpoint: 127.0.0.1:12001 + aggregator_assignee: worker2 +data_set: + name: test + download_url: http://xxx.com +algorithm_config: + program: paddle_clas + max_iter: 10 + inner_step: 10 + architecture: LeNet + num_classes: 102 + base_lr: 0.01 + batch_size: 128 + need_shuffle: True + image_shape: [3, 224, 224] diff --git a/VisualFL/examples/paddle_clas/infer.sh b/VisualFL/examples/paddle_clas/infer.sh new file mode 100755 index 000000000..c3cf33fa9 --- /dev/null +++ b/VisualFL/examples/paddle_clas/infer.sh @@ -0,0 +1,18 @@ +#!/bin/bash + + + +# export Environment Variables +# PYTHONPATH Python default search path for module files, PaddleFL, PaddleDetection +# PYTHON_EXECUTABLE python executable path, such as python | python3 | venv/bin/python + +DIR="$(cd "$(dirname "$0")" >/dev/null 2>&1 && pwd)" + +usage="Usage: run.sh " + +if [ $# -le 0 ]; then + echo "$usage" + exit 1 +fi + +python -m visualfl.client.infer infer --config "${DIR}/config.yaml" --endpoint "$1" diff --git a/VisualFL/examples/paddle_clas/run.sh b/VisualFL/examples/paddle_clas/run.sh new file mode 100755 index 000000000..d039114b5 --- /dev/null +++ b/VisualFL/examples/paddle_clas/run.sh @@ -0,0 +1,17 @@ +#!/bin/bash + + +# export Environment Variables +# PYTHONPATH Python default search path for module files, PaddleFL, PaddleDetection +# PYTHON_EXECUTABLE python executable path, such as python | python3 | venv/bin/python + +DIR="$(cd "$(dirname "$0")" >/dev/null 2>&1 && pwd)" + +usage="Usage: run.sh " + +if [ $# -le 0 ]; then + echo "$usage" + exit 1 +fi + +python -m visualfl.client.submitter submit --config "${DIR}/config.yaml" --endpoint "$1" diff --git a/VisualFL/examples/paddle_detection/apply.sh 
b/VisualFL/examples/paddle_detection/apply.sh new file mode 100755 index 000000000..5cc1c9d60 --- /dev/null +++ b/VisualFL/examples/paddle_detection/apply.sh @@ -0,0 +1,18 @@ +#!/bin/bash + + + +# export Environment Variables +# PYTHONPATH Python default search path for module files, PaddleFL, PaddleDetection +# PYTHON_EXECUTABLE python executable path, such as python | python3 | venv/bin/python + +DIR="$(cd "$(dirname "$0")" >/dev/null 2>&1 && pwd)" + +usage="Usage: run.sh " + +if [ $# -le 0 ]; then + echo "$usage" + exit 1 +fi + +python -m visualfl.client.apply apply --config "${DIR}/config.yaml" --endpoint "$1" diff --git a/VisualFL/examples/paddle_detection/config.yaml b/VisualFL/examples/paddle_detection/config.yaml new file mode 100755 index 000000000..e612d8e45 --- /dev/null +++ b/VisualFL/examples/paddle_detection/config.yaml @@ -0,0 +1,28 @@ +job_id: job_detection0001 +task_id: job_detection0001 +job_type: paddle_fl +role: promoter +member_id: master1 +callback_url: https://www.xxx.com +env: + worker_num: 2 + local_worker_num: 2 + local_trainer_indexs: [0,1] + device: cpu + use_vdl: true + server_endpoint: 127.0.0.1:12000 + aggregator_endpoint: 127.0.0.1:12003 + aggregator_assignee: worker2 +data_set: + name: xxx + download_url: http://x.x +algorithm_config: + program: paddle_detection + max_iter: 1000 + inner_step: 10 + architecture: yolov3 + num_classes: 3 + base_lr: 0.01 + batch_size: 1 + need_shuffle: True + image_shape: [3, 608, 608] diff --git a/VisualFL/examples/paddle_detection/infer.sh b/VisualFL/examples/paddle_detection/infer.sh new file mode 100755 index 000000000..30d96fc8d --- /dev/null +++ b/VisualFL/examples/paddle_detection/infer.sh @@ -0,0 +1,18 @@ +#!/bin/bash + + + +# export Environment Variables +# PYTHONPATH Python default search path for module files, PaddleFL, PaddleDetection +# PYTHON_EXECUTABLE python executable path, such as python | python3 | venv/bin/python + +DIR="$(cd "$(dirname "$0")" >/dev/null 2>&1 && pwd)" + 
+usage="Usage: run.sh " + +if [ $# -le 0 ]; then + echo "$usage" + exit 1 +fi + +python -m visualfl.client.infer infer --config "${DIR}/config.yaml" --endpoint "$1" diff --git a/VisualFL/examples/paddle_detection/run.sh b/VisualFL/examples/paddle_detection/run.sh new file mode 100755 index 000000000..4a44b2f04 --- /dev/null +++ b/VisualFL/examples/paddle_detection/run.sh @@ -0,0 +1,18 @@ +#!/bin/bash + + + +# export Environment Variables +# PYTHONPATH Python default search path for module files,PaddleFL, PaddleDetection +# PYTHON_EXECUTABLE python executable path, such as python | python3 | venv/bin/python + +DIR="$(cd "$(dirname "$0")" >/dev/null 2>&1 && pwd)" + +usage="Usage: run.sh " + +if [ $# -le 0 ]; then + echo "$usage" + exit 1 +fi + +python -m visualfl.client.submitter submit --config "${DIR}/config.yaml" --endpoint "$1" diff --git a/VisualFL/examples/paddle_detection/yolov3_mobilenet_v1_fruit.yml b/VisualFL/examples/paddle_detection/yolov3_mobilenet_v1_fruit.yml new file mode 100755 index 000000000..7134bb46b --- /dev/null +++ b/VisualFL/examples/paddle_detection/yolov3_mobilenet_v1_fruit.yml @@ -0,0 +1,136 @@ +architecture: YOLOv3 +use_gpu: true +max_iters: 20000 +log_iter: 20 +save_dir: output +snapshot_iter: 200 +metric: VOC +map_type: 11point +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar +weights: output/yolov3_mobilenet_v1_fruit/best_model +num_classes: 3 +finetune_exclude_pretrained_params: ["yolo_output"] +use_fine_grained_loss: false + +YOLOv3: + backbone: MobileNet + yolo_head: YOLOv3Head + +MobileNet: + norm_type: sync_bn + norm_decay: 0. + conv_group_scale: 1 + with_extra_blocks: false + +YOLOv3Head: + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + anchors: + [ + [10, 13], + [16, 30], + [33, 23], + [30, 61], + [62, 45], + [59, 119], + [116, 90], + [156, 198], + [373, 326], + ] + norm_decay: 0. 
+ yolo_loss: YOLOv3Loss + nms: + background_label: -1 + keep_top_k: 100 + nms_threshold: 0.45 + nms_top_k: 1000 + normalized: false + score_threshold: 0.01 + +YOLOv3Loss: + ignore_thresh: 0.7 + label_smooth: true + +LearningRate: + base_lr: 0.00001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 15000 + - 18000 + - !LinearWarmup + start_factor: 0. + steps: 100 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +#_READER_: 'yolov3_reader.yml' +# will merge TrainReader into yolov3_reader.yml +TrainReader: + inputs_def: + image_shape: [3, 608, 608] + fields: ["image", "gt_bbox", "gt_class", "gt_score"] + num_max_boxes: 50 + use_dataloader: false + dataset: !VOCDataSet + dataset_dir: fruit + anno_path: train.txt + with_background: false + use_default_label: false + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !ExpandImage + max_ratio: 4.0 + mean: [123.675, 116.28, 103.53] + prob: 0.5 + - !RandomInterpImage + max_size: 0 + target_size: 608 + - !RandomFlipImage + is_normalized: true + prob: 0.5 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: true + is_channel_first: false + - !PadBox + num_max_boxes: 50 + - !BboxXYXY2XYWH {} + batch_transforms: + - !RandomShape + sizes: [608] + - !Permute + channel_first: true + to_bgr: false + batch_size: 1 + shuffle: true + mixup_epoch: -1 + +EvalReader: + batch_size: 1 + inputs_def: + image_shape: [3, 608, 608] + fields: ["image", "im_size", "im_id", "gt_bbox", "gt_class", "is_difficult"] + num_max_boxes: 50 + dataset: !VOCDataSet + dataset_dir: fruit + anno_path: val.txt + use_default_label: false + with_background: false + +TestReader: + batch_size: 1 + dataset: !ImageFolder + anno_path: fruit/label_list.txt + use_default_label: false + with_background: false diff --git a/VisualFL/requirements.txt b/VisualFL/requirements.txt new file mode 100755 index 
000000000..999017a78 --- /dev/null +++ b/VisualFL/requirements.txt @@ -0,0 +1,15 @@ +attr>=0.3.1 +grpcio>=1.33.2,!=1.34.0 +aiohttp>=3.7,<3.8 +aiofiles==0.8.0 +loguru>=0.5 +protobuf==3.14.0 +jsonschema==3.2.0 +PyYAML>=5.3.1 +click==7.1.2 +paddlepaddle==1.8.5 +peewee==3.13.3 +cachetools==4.1.1 +# enhance +visualdl==2.0.5 +paddle-serving-client==0.4.0 \ No newline at end of file diff --git a/VisualFL/requirements_dev.txt b/VisualFL/requirements_dev.txt new file mode 100755 index 000000000..b29515049 --- /dev/null +++ b/VisualFL/requirements_dev.txt @@ -0,0 +1,4 @@ +pre-commit +grpcio-tools>=1.33.2,!=1.34.0 +requests +typer diff --git a/VisualFL/script/cluster_manager.sh b/VisualFL/script/cluster_manager.sh new file mode 100755 index 000000000..4b2b48b9c --- /dev/null +++ b/VisualFL/script/cluster_manager.sh @@ -0,0 +1,48 @@ +#!/bin/bash + + +DIR="$(cd "$(dirname "$0")" >/dev/null 2>&1 && pwd)" +PROJECT_BASE=$(dirname "${DIR}") + +# shellcheck source=env.sh +. "${PROJECT_BASE}/script/env.sh" +# shellcheck source=service.sh +. "${PROJECT_BASE}/script/service.sh" + +usage="Usage: cluster_manager.sh (start|stop) " +if [ $# -le 1 ]; then + echo "$usage" + exit 1 +fi + +if [ -z "${PYTHON_EXECUTABLE}" ]; then + echo "python executable not set" + exit 1 +fi + +start_cluster_manager() { + local re='^[0-9]+$' + if ! 
[[ $1 =~ $re ]]; then + echo "error: port should be number" >&2 + echo "$usage" + exit 1 + fi + mkdir -p "$PROJECT_BASE/logs/nohup" + nohup "${PYTHON_EXECUTABLE}" -m visualfl.client.cluster_manager --port "${1}" >>"${PROJECT_BASE}/logs/nohup/manager" 2>&1 & +} + +case "$1" in +start) + start_service "$2" clustermanager start_cluster_manager "$2" + exit 0 + ;; +stop) + stop_service_by_port "$2" clustermanager + exit 0 + ;; +*) + echo bad command + echo "$usage" + exit 1 + ;; +esac diff --git a/VisualFL/script/cluster_worker.sh b/VisualFL/script/cluster_worker.sh new file mode 100755 index 000000000..63b1c89c6 --- /dev/null +++ b/VisualFL/script/cluster_worker.sh @@ -0,0 +1,96 @@ +#!/bin/bash + + + +DIR="$(cd "$(dirname "$0")" >/dev/null 2>&1 && pwd)" +PROJECT_BASE=$(dirname "${DIR}") + + +# shellcheck source=env.sh +. "${PROJECT_BASE}/script/env.sh" +# shellcheck source=service.sh +. "${PROJECT_BASE}/script/service.sh" + +usage="Usage: cluster_worker.sh (start|stop) [ ]" +if [ $# -le 1 ]; then + echo "$usage" + exit 1 +fi + +if [ -z "${PYTHON_EXECUTABLE}" ]; then + echo "python executable not set" + exit 1 +fi + +start_cluster_worker() { + local pid + pid=$( + ps aux | grep "visualfl.client.cluster_worker" | grep "name ${1}" | grep -v grep | awk '{print $2}' + ) + if [[ -z ${pid} ]]; then + mkdir -p "$PROJECT_BASE/logs/nohup" + nohup "${PYTHON_EXECUTABLE}" -m visualfl.client.cluster_worker --name "$1" --worker-ip "$2" --port-start "$3" --port-end "$4" --max-tasks "$5" --manager-address "$6" --data-base-dir "$7" >>"${PROJECT_BASE}/logs/nohup/worker" 2>&1 & + for ((i = 1; i <= 20; i++)); do + sleep 0.1 + pid=$( + ps aux | grep "visualfl.client.cluster_worker" | grep "name ${1}" | grep -v grep | awk '{print $2}' + ) + if [[ -n ${pid} ]]; then + echo "cluster worker service start successfully. pid: ${pid}" + exit 0 + fi + done + echo "cluster worker service start failed" + exit 1 + else + echo "cluster worker service named <$1> already started. 
pid: $pid" + exit 1 + fi +} + +stop_cluster_worker() { + local pid + pid=$( + ps aux | grep "visualfl.client.cluster_worker" | grep "name ${1}" | grep -v grep | awk '{print $2}' + ) + if [[ -n ${pid} ]]; then + echo "killing: $(ps aux | grep "${pid}" | grep -v grep)" + for ((i = 1; i <= 20; i++)); do + sleep 0.1 + kill "${pid}" + pid=$( + ps aux | grep "visualfl.client.cluster_worker" | grep "name ${1}" | grep -v grep | awk '{print $2}' + ) + if [[ -z ${pid} ]]; then + echo "killed by SIGTERM" + exit 0 + fi + done + if [[ $(kill -9 "${pid}") -eq 0 ]]; then + echo "killed by SIGKILL" + exit 0 + else + echo "Kill error" + exit 1 + fi + else + echo "cluster worker named <${1}> service not running" + exit 1 + fi +} + +case "$1" in +start) + start_cluster_worker "$2" "$3" "$4" "$5" "$6" "$7" "$8" + exit 0 + ;; +stop) + stop_cluster_worker "$2" + exit 0 + ;; +*) + echo bad command + echo "$usage" + exit 1 + ;; +esac diff --git a/VisualFL/script/env.sh b/VisualFL/script/env.sh new file mode 100755 index 000000000..349b4b63c --- /dev/null +++ b/VisualFL/script/env.sh @@ -0,0 +1,24 @@ +#!/bin/bash + + + +# export Environment Variables +# PYTHONPATH Python default search path for module files, add Fedvision, PaddleFL, PaddleDetection +# PYTHON_EXECUTABLE python executable path, such as python | python3 | venv/bin/python + +DIR="$(cd "$(dirname "$0")" >/dev/null 2>&1 && pwd)" +PROJECT_BASE=$(dirname "${DIR}") + +unset PYTHONPATH + +DEPS_DIR="${PROJECT_BASE}/depends" +if [ -d "${DEPS_DIR}" ]; then + # development layout + export PYTHONPATH="${PROJECT_BASE}:${DEPS_DIR}/PaddleDetection:${DEPS_DIR}/PaddleFL/python:${PYTHONPATH}" +else + export PYTHONPATH="${PROJECT_BASE}:${PYTHONPATH}" +fi + +if [ -z "${PYTHON_EXECUTABLE}" ]; then + export PYTHON_EXECUTABLE= +fi diff --git a/VisualFL/script/master.sh b/VisualFL/script/master.sh new file mode 100755 index 000000000..6928eb0cb --- /dev/null +++ b/VisualFL/script/master.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +# Copyright (c) 2020 The 
FedVision Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +DIR="$(cd "$(dirname "$0")" >/dev/null 2>&1 && pwd)" +PROJECT_BASE=$(dirname "${DIR}") + +# shellcheck source=env.sh +. "${PROJECT_BASE}/script/env.sh" +# shellcheck source=service.sh +. "${PROJECT_BASE}/script/service.sh" + +usage="Usage: [PYTHON_EXECUTABLE=...] master.sh (start|stop) [ ]" +if [ $# -le 1 ]; then + echo "$usage" + exit 1 +fi + +if [ -z "${PYTHON_EXECUTABLE}" ]; then + echo "fedvision python executable not set" + exit 1 +fi + +start_master() { + local re='^[0-9]+$' + if ! 
[[ $1 =~ $re ]]; then + echo "error: port should be number" >&2 + echo "$usage" + exit 1 + fi + mkdir -p "$PROJECT_BASE/logs/nohup" + nohup "${PYTHON_EXECUTABLE}" -m visualfl.client.master --submitter-port "$1" --member-id "$2" --cluster-address "$3" --local "$4" >>"${PROJECT_BASE}/logs/nohup/master" 2>&1 & +} + +case "$1" in +start) + start_service "$2" master start_master "$2" "$3" "$4" "$5" + exit 0 + ;; +stop) + stop_service_by_port "$2" master + exit 0 + ;; +*) + echo bad command + echo "$usage" + exit 1 + ;; +esac diff --git a/VisualFL/script/service.sh b/VisualFL/script/service.sh new file mode 100755 index 000000000..fd91c07c8 --- /dev/null +++ b/VisualFL/script/service.sh @@ -0,0 +1,50 @@ +#!/bin/bash + + + +# start_service +start_service() { + local pid + pid=$(lsof -i:"$1" | grep 'LISTEN' | awk '{print $2}' | uniq) + if [[ -z ${pid} ]]; then + $3 "${@:4}" + for ((i = 1; i <= 20; i++)); do + sleep 0.1 + pid=$(lsof -i:"$1" | grep 'LISTEN' | awk '{print $2}' | uniq) + if [[ -n ${pid} ]]; then + echo "$2 service start successfully. pid: ${pid}" + exit 0 + fi + done + echo "$2 service start failed" + exit 1 + else + echo "$2 service already started at port $1. pid: $pid" + exit 1 + fi +} + +# stop_service +stop_service_by_port() { + local pid + pid=$(lsof -i:"$1" | grep 'LISTEN' | awk '{print $2}' | uniq) + if [[ -n ${pid} ]]; then + echo "killing: $(ps aux | grep "${pid}" | grep -v grep)" + for ((i = 1; i <= 20; i++)); do + sleep 0.1 + kill "$pid" + pid=$(lsof -i:"$1" | grep 'LISTEN' | awk '{print $2}' | uniq) + if [[ -z ${pid} ]]; then + echo "killed by SIGTERM" + exit 0 + fi + done + echo $pid + kill -9 "$pid" + echo "killed by SIGKILL" + exit 0 + else + echo "$2 service not running" + exit 1 + fi +} diff --git a/VisualFL/visualfl/__init__.py b/VisualFL/visualfl/__init__.py new file mode 100644 index 000000000..692c8e43f --- /dev/null +++ b/VisualFL/visualfl/__init__.py @@ -0,0 +1,38 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +__version__ = "1.0" +__basedir__ = os.path.dirname(os.path.abspath(__file__)) +__logs_dir__ = os.path.abspath(os.path.join(__basedir__, os.path.pardir, "logs")) +__config_path__ = os.path.abspath(os.path.join(__basedir__, os.path.pardir, "config.properties")) +# __fl_job_config_dir__ = os.path.abspath(os.path.join(__basedir__, os.path.pardir, "visualfl/fl_job_config")) +__data_dir__ = os.path.abspath(os.path.join(__basedir__, os.path.pardir, "data")) + +VISUALFL_DATA_BASE_ENV = "VISUALFL_DATA_BASE_ENV" + + +def get_data_dir(): + if VISUALFL_DATA_BASE_ENV in os.environ and os.path.exists( + os.environ.get(VISUALFL_DATA_BASE_ENV) + ): + return os.path.abspath(os.environ.get(VISUALFL_DATA_BASE_ENV)) + else: + return __data_dir__ + + +if __name__ == '__main__': + print(get_data_dir()) \ No newline at end of file diff --git a/VisualFL/visualfl/algorithm/__init__.py b/VisualFL/visualfl/algorithm/__init__.py new file mode 100644 index 000000000..3da16e031 --- /dev/null +++ b/VisualFL/visualfl/algorithm/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/VisualFL/visualfl/algorithm/paddle_clas/README.md b/VisualFL/visualfl/algorithm/paddle_clas/README.md new file mode 100755 index 000000000..5213783b3 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/README.md @@ -0,0 +1,7 @@ +## Note + +This folder implement cnn algorithms directly by paddle_fl. + +## How to use + +See [example](../../../../examples/paddle_clas) diff --git a/VisualFL/visualfl/algorithm/paddle_clas/__init__.py b/VisualFL/visualfl/algorithm/paddle_clas/__init__.py new file mode 100644 index 000000000..3da16e031 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/VisualFL/visualfl/algorithm/paddle_clas/cv_utils.py b/VisualFL/visualfl/algorithm/paddle_clas/cv_utils.py new file mode 100644 index 000000000..f5804921e --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/cv_utils.py @@ -0,0 +1,98 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import cv2 +import numpy as np + + +class DecodeImage(object): + def __init__(self, to_rgb=True): + self.to_rgb = to_rgb + + def __call__(self, img): + data = np.frombuffer(img, dtype='uint8') + img = cv2.imdecode(data, 1) + if self.to_rgb: + assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( + img.shape) + img = img[:, :, ::-1] + + return img + + +class ResizeImage(object): + def __init__(self, resize_short=None): + self.resize_short = resize_short + + def __call__(self, img): + img_h, img_w = img.shape[:2] + percent = float(self.resize_short) / min(img_w, img_h) + w = int(round(img_w * percent)) + h = int(round(img_h * percent)) + return cv2.resize(img, (w, h)) + + +class CropImage(object): + def __init__(self, size): + if type(size) is int: + self.size = (size, size) + else: + self.size = size + + def __call__(self, img): + w, h = self.size + img_h, img_w = img.shape[:2] + w_start = (img_w - w) // 2 + h_start = (img_h - h) // 2 + + w_end = w_start + w + h_end = h_start + h + return img[h_start:h_end, w_start:w_end, :] + + +class NormalizeImage(object): + def __init__(self, scale=None, mean=None, std=None): + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, img): + return (img.astype('float32') * self.scale - self.mean) / self.std + + +class ToTensor(object): + def __init__(self): + pass + + def __call__(self, img): + img = img.transpose((2, 0, 1)) + return img diff --git a/VisualFL/visualfl/algorithm/paddle_clas/fl_master.py b/VisualFL/visualfl/algorithm/paddle_clas/fl_master.py new file mode 100644 index 000000000..ae99a24e0 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/fl_master.py @@ -0,0 +1,120 @@ +# Copyright 2021 Tianmian Tech. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json
+import logging
+import importlib
+
+import click
+from paddle import fluid
+from paddle_fl.core.master.job_generator import JobGenerator
+from paddle_fl.core.strategy.fl_strategy_base import (
+    FedAvgStrategy,
+)
+
+class Model(object):
+    def __init__(self):
+        self.feeds = None
+        self.startup_program = None
+        self.loss = None
+
+
+    def build_program(self,inputs,label,num_classes,architecture='CNN'):
+        module = importlib.import_module('visualfl.algorithm.paddle_clas.models')
+        model = getattr(module, architecture)()
+        out = model.net(input=inputs,class_dim=num_classes)
+
+        predict = fluid.layers.softmax(out)
+        cost = fluid.layers.cross_entropy(input=predict, label=label)
+        accuracy = fluid.layers.accuracy(input=predict, label=label)
+        self.loss = fluid.layers.mean(cost)
+        self.startup_program = fluid.default_startup_program()
+
+        self.feeds = [inputs, label]
+        self.targets = [self.loss, accuracy]
+
+
+@click.command()
+@click.option("--ps-endpoint", type=str, required=True)
+@click.option(
+    "-c",
+    "--config",
+    type=click.Path(file_okay=True, dir_okay=False, exists=True),
+    required=True,
+)
+@click.option(
+    "--algorithm-config", type=click.Path(exists=True, file_okay=True, dir_okay=False)
+)
+def fl_master(algorithm_config, ps_endpoint, config):
+    logging.basicConfig(
+        level=logging.DEBUG, format="%(asctime)s-%(levelname)s: %(message)s"
+    )
+
+    logger = logging.getLogger(__name__)  # noqa: F841
+    with open(config) as f:
+        config_json = json.load(f)
+    worker_num = config_json["worker_num"]
+
+    with open(algorithm_config) as f:
+        algorithm_config_dict = json.load(f)
+
+    inner_step = algorithm_config_dict.get("inner_step")
+    base_lr = algorithm_config_dict.get("base_lr", 0.001)
+    image_shape = algorithm_config_dict.get("image_shape")
+    num_classes = algorithm_config_dict.get("num_classes")
+    architecture = algorithm_config_dict.get("architecture")
+
+    # NOTE: "float32" (not "float64") — the preprocessing pipeline (NormalizeImage/
+    # ToTensor in cv_utils) emits float32 arrays and infer.py declares this same
+    # input as float32; float64 here made feed dtypes inconsistent across the patch.
+    inputs = fluid.layers.data(name="img", shape=image_shape, dtype="float32")
+    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+
+    model = Model()
+    model.build_program(inputs,label,num_classes,architecture)
+
+    job_generator = JobGenerator()
+    job_generator.set_losses([model.loss])
+    job_generator.set_optimizer(fluid.optimizer.Adam(base_lr))
+    job_generator.set_startup_program(model.startup_program)
+    job_generator.set_infer_feed_and_target_names(
+        [feed.name for feed in model.feeds], [target.name for target in model.targets]
+    )
+    job_generator.set_feeds(model.feeds)
+
+    strategy = FedAvgStrategy()
+    strategy.fed_avg = True
+    strategy._inner_step = inner_step
+
+    endpoints = [ps_endpoint]
+    output = "compile"
+    job_generator.generate_fl_job(
+        strategy, server_endpoints=endpoints, worker_num=worker_num, output=output
+    )
+
+
+if __name__ == "__main__":
+    fl_master()
+
+
diff --git a/VisualFL/visualfl/algorithm/paddle_clas/fl_trainer.py b/VisualFL/visualfl/algorithm/paddle_clas/fl_trainer.py
new file mode 100644
index 000000000..70e2fe011
--- /dev/null
+++ b/VisualFL/visualfl/algorithm/paddle_clas/fl_trainer.py
@@ -0,0 +1,250 @@
+# Copyright 2021 Tianmian Tech. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 The FedVision Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import click +import paddle +from visualfl.paddle_fl.trainer._trainer import FedAvgTrainer +from visualfl.utils import data_loader +from visualfl.utils.tools import * +from visualfl.utils.consts import TaskStatus,ComponentName,TaskResultType + +@click.command() +@click.option("--job-id", type=str, required=True) +@click.option("--task-id", type=str, required=True) +@click.option("--scheduler-ep", type=str, required=True) +@click.option("--trainer-id", type=int, required=True) +@click.option("--trainer-ep", type=str, required=True) +@click.option( + "--main-program", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--startup-program", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--send-program", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--recv-program", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--feed-names", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--target-names", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--strategy", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--feeds", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--config", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--algorithm-config", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) + +def fl_trainer( + job_id: str, + task_id: str, + trainer_id: int, + trainer_ep: str, + scheduler_ep: str, + main_program, + startup_program, + send_program, + recv_program, + feed_names, + target_names, + strategy, + feeds, + config, + algorithm_config, +): 
+    import numpy as np
+    import paddle.fluid as fluid
+    from visualfl import get_data_dir
+    from ppdet.utils import checkpoint
+
+    logging.basicConfig(
+        filename="trainer.log",
+        filemode="w",
+        format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
+        datefmt="%d-%m-%Y %H:%M:%S",
+        level=logging.DEBUG,
+    )
+
+    try:
+
+        with open(config) as f:
+            config_json = json.load(f)
+        device = config_json.get("device", "cpu")
+        use_vdl = config_json.get("use_vdl", False)
+        resume_checkpoint = config_json.get("resume", False)
+        save_model_dir = "model"
+        save_checkpoint_dir = "checkpoint"
+
+
+        with open(algorithm_config) as f:
+            algorithm_config_dict = json.load(f)
+
+        batch_size = algorithm_config_dict.get("batch_size", 1024)
+        need_shuffle = algorithm_config_dict.get("need_shuffle", True)
+        max_iter = algorithm_config_dict.get("max_iter")
+        download_url = algorithm_config_dict.get("download_url")
+        data_name = algorithm_config_dict.get("data_name")
+
+        logging.debug(f"training program begin")
+        trainer = FedAvgTrainer(scheduler_ep=scheduler_ep, trainer_ep=trainer_ep)
+        logging.debug(f"job program loading")
+        trainer.load_job(
+            main_program=main_program,
+            startup_program=startup_program,
+            send_program=send_program,
+            recv_program=recv_program,
+            feed_names=feed_names,
+            target_names=target_names,
+            strategy=strategy,
+        )
+        logging.debug(f"job program loaded")
+        place = fluid.CPUPlace() if device != "cuda" else fluid.CUDAPlace(0)
+
+        logging.debug(f"trainer starting with place {place}")
+        trainer.start(place)
+        logging.debug(f"trainer started")
+
+        logging.debug(f"loading data")
+        feed_list = trainer.load_feed_list(feeds)
+        feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
+        logging.debug(f"data loader ready")
+
+        data_dir = data_loader.job_download(download_url, job_id, get_data_dir())
+        labelpath = os.path.join(data_dir,"label_list.txt")
+        TaskDao(task_id).save_task_result({"label_path":labelpath},ComponentName.CLASSIFY,TaskResultType.LABEL)
+        reader = data_loader.train(data_dir=data_dir)
+        if need_shuffle:
+            reader = fluid.io.shuffle(
+                reader=reader,
+                buf_size=1000,
+            )
+        train_loader = paddle.batch(reader=reader, batch_size=batch_size)
+
+        epoch_id = -1
+        step = 0
+        TaskDao(task_id).init_task_progress(max_iter)
+        TaskDao(task_id).start_task()
+        # if resume_checkpoint:
+        #     try:
+        #         epoch_id = TaskDao(task_id).get_task_progress()
+        #         checkpoint.load_checkpoint(trainer.exe, trainer._main_program, f"checkpoint/{epoch_id}")
+        #         logging.debug(f"use_checkpoint epoch_id: {epoch_id}")
+        #     except Exception as e:
+        #         logging.error(f"task id {task_id} train error {e}")
+        #         raise Exception(f"train error as task id {task_id} ")
+
+        if use_vdl:
+            from visualdl import LogWriter
+            vdl_writer = LogWriter("vdl_log")
+
+        while epoch_id < max_iter:
+            epoch_id += 1
+            if not trainer.scheduler_agent.join(epoch_id):
+                logging.debug(f"not join, waiting next round")
+                continue
+
+            logging.debug(f"epoch {epoch_id} start train")
+
+            for step_id, data in enumerate(train_loader()):
+                outs = trainer.run(feeder.feed(data), fetch=trainer._target_names)
+                if use_vdl:
+                    stats = {
+                        k: np.array(v).mean() for k, v in zip(trainer._target_names, outs)
+                    }
+                    for loss_name, loss_value in stats.items():
+                        vdl_writer.add_scalar(loss_name, loss_value, step)
+                        save_data_to_db(task_id, loss_name,loss_value,step,ComponentName.CLASSIFY)
+                step += 1
+                logging.debug(f"step: {step}, outs: {outs}")
+
+            # save model
+            logging.debug(f"saving model at {epoch_id}-th epoch")
+            trainer.save_model(os.path.join(save_model_dir,str(epoch_id)))
+
+            # info scheduler
+            trainer.scheduler_agent.finish()
+            checkpoint.save(trainer.exe, trainer._main_program, os.path.join(save_checkpoint_dir,str(epoch_id)))
+            TaskDao(task_id).add_task_progress(1)
+
+        TaskDao(task_id).update_task_status(TaskStatus.SUCCESS)
+        TaskDao(task_id).finish_task_progress()
+        TaskDao(task_id).update_serving_model(type=TaskResultType.LOSS)
+        logging.debug(f"reach max iter, finish training")
+    except Exception as e:
+        logging.error(f"task id {task_id} train error {e}")
+        TaskDao(task_id).update_task_status(TaskStatus.ERROR, str(e))
+        raise Exception(f"train error as task id {task_id} ")
+
+
+if __name__ == "__main__":
+    fl_trainer()
diff --git a/VisualFL/visualfl/algorithm/paddle_clas/infer.py b/VisualFL/visualfl/algorithm/paddle_clas/infer.py
new file mode 100644
index 000000000..e83e43062
--- /dev/null
+++ b/VisualFL/visualfl/algorithm/paddle_clas/infer.py
@@ -0,0 +1,201 @@
+# Copyright 2021 Tianmian Tech. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import os +import sys +import json +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../../'))) +from ppdet.utils import checkpoint +import argparse +import numpy as np +import paddle +import paddle.fluid as fluid +from visualfl.algorithm.paddle_clas import models +import cv_utils as utils +from visualfl.db.task_dao import TaskDao +from visualfl.utils.consts import ComponentName,TaskResultType +import logging + +logging.basicConfig( + filename="infer.log", + filemode="w", + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%d-%M-%Y %H:%M:%S", + level=logging.DEBUG, +) + + +def parse_args(): + def str2bool(v): + return v.lower() in ("true", "t", "1") + + parser = argparse.ArgumentParser() + parser.add_argument("--job_id", type=str, required=False) + parser.add_argument("--task_id", type=str,required=True) + parser.add_argument("-i", "--infer_dir", type=str,required=True) + parser.add_argument("-c", "--config", type=str,required=True) + parser.add_argument("--weights", type=str,required=True) + parser.add_argument("--output_dir", type=str,required=False) + parser.add_argument("--use_gpu", type=str2bool, default=False) + + return parser.parse_args() + + +def create_predictor(args): + with open(args.config) as f: + config = json.load(f) + + def create_input(config): + image = fluid.layers.data( + name='image', shape=config.get("image_shape"), dtype='float32') + return image + + def create_model(architecture, model, input, class_dim=1000): + if architecture == "GoogLeNet": + out, _, _ = model.net(input=input, class_dim=class_dim) + else: + out = model.net(input=input, class_dim=class_dim) + out = fluid.layers.softmax(out) + return out + + architecture = config.get("architecture") + model = models.__dict__[architecture]() + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + startup_prog = fluid.Program() + infer_prog 
= fluid.Program() + with fluid.program_guard(infer_prog, startup_prog): + with fluid.unique_name.guard(): + image = create_input(config) + out = create_model(architecture, model, image,config.get("num_classes")) + exe.run(startup_prog) + + infer_prog = infer_prog.clone(for_test=True) + exe.run(startup_prog) + # fluid.load( + # program=infer_prog, model_path=args.weights, executor=exe) + checkpoint.load_params(exe, infer_prog, args.weights) + + return exe, infer_prog, [image.name], [out.name] + + +def create_operators(): + size = 224 + img_mean = [0.485, 0.456, 0.406] + img_std = [0.229, 0.224, 0.225] + img_scale = 1.0 / 255.0 + + decode_op = utils.DecodeImage() + resize_op = utils.ResizeImage(resize_short=256) + crop_op = utils.CropImage(size=(size, size)) + normalize_op = utils.NormalizeImage( + scale=img_scale, mean=img_mean, std=img_std) + totensor_op = utils.ToTensor() + + return [decode_op, resize_op, crop_op, normalize_op, totensor_op] + + +def preprocess(fname, ops): + data = open(fname, 'rb').read() + for op in ops: + data = op(data) + + return data + + +def postprocess(outputs, topk=5): + output = outputs[0] + prob = np.array(output).flatten() + index = prob.argsort(axis=0)[-topk:][::-1].astype('int32') + return zip(index, prob[index]) + + +def get_image_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp'] + if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end: + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + if single_file.split('.')[-1] in img_end: + imgs_lists.append(os.path.join(img_file, single_file)) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + return imgs_lists + + +def main(): + args = parse_args() + operators = create_operators() + exe, program, feed_names, 
fetch_names = create_predictor(args) + image_list = get_image_list(args.infer_dir) + model = TaskDao(args.task_id).get_task_result(TaskResultType.INFER) + infer_result={} + if model: + infer_result = json.loads(model.result) + infer_result.update({"status": "running"}) + TaskDao(args.task_id).save_task_result(infer_result, ComponentName.CLASSIFY,type=TaskResultType.INFER) + task_result = TaskDao(args.task_id).get_task_result(TaskResultType.LABEL) + if not task_result: + raise Exception(f"task result is None as task id: {args.task_id}") + label_file = json.loads(task_result.result).get("label_path") + cats = [] + with open(label_file) as f: + for line in f.readlines(): + lines = line.replace('\n','').split(' ') + cats.append(' '.join(lines[1:])) + img_probs = [] + for idx, filename in enumerate(image_list): + data = preprocess(filename, operators) + data = np.expand_dims(data, axis=0) + outputs = exe.run(program, + feed={feed_names[0]: data}, + fetch_list=fetch_names, + return_numpy=False) + probs = postprocess(outputs) + logging.debug("current image: {}".format(filename)) + infer_probs = [] + for idx, prob in probs: + logging.debug("\tclass id: {:d}, probability: {:.4f}".format(idx, prob)) + infer_probs.append({"class_id":idx,"class_name":str(cats[idx]),"prob":prob}) + infer_dict = {"image": os.path.basename(filename), "infer_probs": infer_probs} + img_probs.append(infer_dict) + infer_result["result"] = img_probs + infer_result["status"] = "finish" + logging.debug(f"infer result: {infer_result}") + TaskDao(task_id=args.task_id).save_task_result(infer_result, ComponentName.CLASSIFY, type=TaskResultType.INFER) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/__init__.py b/VisualFL/visualfl/algorithm/paddle_clas/models/__init__.py new file mode 100755 index 000000000..88794d7ba --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/__init__.py @@ -0,0 +1,46 @@ +#copyright (c) 
2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from .lenet import LeNet +from .custom import CNN +from .alexnet import AlexNet +from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x1_0, MobileNetV1_x0_75, MobileNetV1 +from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2_x1_0, MobileNetV2_x1_5, MobileNetV2_x2_0, MobileNetV2 +from .mobilenet_v3 import MobileNetV3_small_x0_25, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_25, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25 +from .googlenet import GoogLeNet +from .vgg import VGG11, VGG13, VGG16, VGG19 +from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 +from .resnet_vc import ResNet50_vc, ResNet101_vc, ResNet152_vc +from .resnet_vd import ResNet18_vd, ResNet34_vd, ResNet50_vd, ResNet101_vd, ResNet152_vd, ResNet200_vd +from .resnext import ResNeXt50_64x4d, ResNeXt101_64x4d, ResNeXt152_64x4d, ResNeXt50_32x4d, ResNeXt101_32x4d, ResNeXt152_32x4d +from .resnext_vd import ResNeXt50_vd_64x4d, ResNeXt101_vd_64x4d, ResNeXt152_vd_64x4d, ResNeXt50_vd_32x4d, ResNeXt101_vd_32x4d, ResNeXt152_vd_32x4d +from .inception_v4 import InceptionV4 +from .se_resnet_vd import SE_ResNet18_vd, SE_ResNet34_vd, SE_ResNet50_vd, SE_ResNet101_vd, SE_ResNet152_vd, SE_ResNet200_vd +from .se_resnext import 
SE_ResNeXt50_32x4d, SE_ResNeXt101_32x4d, SE_ResNeXt152_32x4d +from .se_resnext_vd import SE_ResNeXt50_vd_32x4d, SE_ResNeXt101_vd_32x4d, SENet154_vd +from .dpn import DPN68, DPN92, DPN98, DPN107, DPN131 +from .shufflenet_v2_swish import ShuffleNetV2_swish, ShuffleNetV2_x0_5_swish, ShuffleNetV2_x1_0_swish, ShuffleNetV2_x1_5_swish, ShuffleNetV2_x2_0_swish +from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2_x1_0, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2 +from .xception import Xception41, Xception65, Xception71 +from .xception_deeplab import Xception41_deeplab, Xception65_deeplab, Xception71_deeplab +from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264 +from .squeezenet import SqueezeNet1_0, SqueezeNet1_1 +from .darknet import DarkNet53 +from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl +from .efficientnet import EfficientNet, EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7 +from .res2net import Res2Net50_48w_2s, Res2Net50_26w_4s, Res2Net50_14w_8s, Res2Net50_26w_6s, Res2Net50_26w_8s, Res2Net101_26w_4s, Res2Net152_26w_4s +from .res2net_vd import Res2Net50_vd_48w_2s, Res2Net50_vd_26w_4s, Res2Net50_vd_14w_8s, Res2Net50_vd_26w_6s, Res2Net50_vd_26w_8s, Res2Net101_vd_26w_4s, Res2Net152_vd_26w_4s, Res2Net200_vd_26w_4s +from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C +from .autodl import DARTS_6M, DARTS_4M +from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/alexnet.py 
b/VisualFL/visualfl/algorithm/paddle_clas/models/alexnet.py new file mode 100755 index 000000000..638aaa544 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/alexnet.py @@ -0,0 +1,95 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid + +__all__ = ['AlexNet'] + + +class AlexNet(): + def __init__(self): + pass + + def net(self, input, class_dim=1000): + + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=input, + num_filters=96, + filter_size=11, + conv_stride=4, + conv_padding=5, + pool_size=2, + pool_stride=2, + act="relu", + ) + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + num_filters=256, + filter_size=5, + conv_stride=1, + conv_padding=2, + pool_size=2, + pool_stride=2, + act="relu", + ) + conv_pool_3 = fluid.nets.simple_img_conv_pool( + input=conv_pool_2, + num_filters=384, + filter_size=3, + conv_stride=1, + conv_padding=1, + pool_size=1, + pool_stride=1, + act="relu", + ) + conv_pool_4 = fluid.nets.simple_img_conv_pool( + input=conv_pool_3, + num_filters=384, + filter_size=3, + conv_stride=1, + conv_padding=1, + pool_size=1, + pool_stride=1, + act="relu", + ) + conv_pool_5 = fluid.nets.simple_img_conv_pool( + input=conv_pool_4, + num_filters=256, + filter_size=3, + conv_stride=1, + 
conv_padding=1, + pool_size=2, + pool_stride=2, + act="relu", + ) + fc1 = fluid.layers.fc( + input=conv_pool_5, size=4096, act="relu" + ) + fc1 = fluid.layers.dropout(fc1, 0.5) + fc2 = fluid.layers.fc( + input=fc1, size=4096, act="relu" + ) + fc2 = fluid.layers.dropout(fc2, 0.5) + out = fluid.layers.fc( + input=fc2, size=class_dim + ) + + return out diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/autodl.py b/VisualFL/visualfl/algorithm/paddle_clas/models/autodl.py new file mode 100755 index 000000000..915bc1631 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/autodl.py @@ -0,0 +1,562 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +# +# Based on: +# -------------------------------------------------------- +# DARTS +# Copyright (c) 2018, Hanxiao Liu. 
+# Licensed under the Apache License, Version 2.0; +# -------------------------------------------------------- + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import sys +import numpy as np +import time +import functools +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Xavier +from paddle.fluid.initializer import Normal +from paddle.fluid.initializer import Constant + + +from collections import namedtuple +Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') + +arch_dict = { + 'DARTS_6M': Genotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 2), ('sep_conv_3x3', 1), ('skip_connect', 4), ('sep_conv_3x3', 3)], normal_concat=range(2, 6), reduce=[('sep_conv_5x5', 0), ('max_pool_3x3', 1), ('dil_conv_5x5', 2), ('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('dil_conv_5x5', 3), ('dil_conv_3x3', 1), ('sep_conv_3x3', 2)], reduce_concat=range(2, 6)), + 'DARTS_4M': Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 1)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('avg_pool_3x3', 1), ('skip_connect', 3), ('skip_connect', 2), ('sep_conv_3x3', 0), ('sep_conv_5x5', 2)], reduce_concat=range(2, 6)), +} + +__all__ = list(arch_dict.keys()) + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + +OPS = { + 'none' : lambda input, C, stride, name, affine: Zero(input, stride, name), + 'avg_pool_3x3' : lambda input, C, stride, name, affine: 
fluid.layers.pool2d(input, 3, 'avg', pool_stride=stride, pool_padding=1, name=name), + 'max_pool_3x3' : lambda input, C, stride, name, affine: fluid.layers.pool2d(input, 3, 'max', pool_stride=stride, pool_padding=1, name=name), + 'skip_connect' : lambda input,C, stride, name, affine: Identity(input, name) if stride == 1 else FactorizedReduce(input, C, name=name, affine=affine), + 'sep_conv_3x3' : lambda input,C, stride, name, affine: SepConv(input, C, C, 3, stride, 1, name=name, affine=affine), + 'sep_conv_5x5' : lambda input,C, stride, name, affine: SepConv(input, C, C, 5, stride, 2, name=name, affine=affine), + 'sep_conv_7x7' : lambda input,C, stride, name, affine: SepConv(input, C, C, 7, stride, 3, name=name, affine=affine), + 'dil_conv_3x3' : lambda input,C, stride, name, affine: DilConv(input, C, C, 3, stride, 2, 2, name=name, affine=affine), + 'dil_conv_5x5' : lambda input,C, stride, name, affine: DilConv(input, C, C, 5, stride, 4, 2, name=name, affine=affine), + 'conv_7x1_1x7' : lambda input,C, stride, name, affine: SevenConv(input, C, name=name, affine=affine) +} + +def ReLUConvBN(input, C_out, kernel_size, stride, padding, name='', + affine=True): + relu_a = fluid.layers.relu(input) + conv2d_a = fluid.layers.conv2d( + relu_a, + C_out, + kernel_size, + stride, + padding, + bias_attr=False) + if affine: + reluconvbn_out = fluid.layers.batch_norm( + conv2d_a, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'op.2.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'op.2.bias'), + moving_mean_name=name + 'op.2.running_mean', + moving_variance_name=name + 'op.2.running_var') + else: + reluconvbn_out = fluid.layers.batch_norm( + conv2d_a, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'op.2.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'op.2.bias'), + moving_mean_name=name + 'op.2.running_mean', + moving_variance_name=name + 
'op.2.running_var') + return reluconvbn_out + +def DilConv(input, + C_in, + C_out, + kernel_size, + stride, + padding, + dilation, + name='', + affine=True): + relu_a = fluid.layers.relu(input) + conv2d_a = fluid.layers.conv2d( + relu_a, + C_in, + kernel_size, + stride, + padding, + dilation, + groups=C_in, + bias_attr=False, + use_cudnn=False) + conv2d_b = fluid.layers.conv2d( + conv2d_a, + C_out, + 1, + bias_attr=False) + if affine: + dilconv_out = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + else: + dilconv_out = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + return dilconv_out + +def SepConv(input, + C_in, + C_out, + kernel_size, + stride, + padding, + name='', + affine=True): + relu_a = fluid.layers.relu(input) + conv2d_a = fluid.layers.conv2d( + relu_a, + C_in, + kernel_size, + stride, + padding, + groups=C_in, + bias_attr=False, + use_cudnn=False) + conv2d_b = fluid.layers.conv2d( + conv2d_a, + C_in, + 1, + bias_attr=False) + if affine: + bn_a = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + else: + bn_a = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'op.3.weight'), + bias_attr=ParamAttr( + 
initializer=Constant(0.), + learning_rate=0., + name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + + relu_b = fluid.layers.relu(bn_a) + conv2d_d = fluid.layers.conv2d( + relu_b, + C_in, + kernel_size, + 1, + padding, + groups=C_in, + bias_attr=False, + use_cudnn=False) + conv2d_e = fluid.layers.conv2d( + conv2d_d, + C_out, + 1, + bias_attr=False) + if affine: + sepconv_out = fluid.layers.batch_norm( + conv2d_e, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'op.7.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'op.7.bias'), + moving_mean_name=name + 'op.7.running_mean', + moving_variance_name=name + 'op.7.running_var') + else: + sepconv_out = fluid.layers.batch_norm( + conv2d_e, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'op.7.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'op.7.bias'), + moving_mean_name=name + 'op.7.running_mean', + moving_variance_name=name + 'op.7.running_var') + return sepconv_out + +def SevenConv(input, C_out, stride, name='', affine=True): + relu_a = fluid.layers.relu(input) + conv2d_a = fluid.layers.conv2d( + relu_a, + C_out, (1, 7), (1, stride), (0, 3), + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.1.weight'), + bias_attr=False) + conv2d_b = fluid.layers.conv2d( + conv2d_a, + C_out, (7, 1), (stride, 1), (3, 0), + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'op.2.weight'), + bias_attr=False) + if affine: + out = fluid.layers.batch_norm( + conv2d_b, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + else: + out = fluid.layers.batch_norm( + conv2d_b, + 
param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'op.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'op.3.bias'), + moving_mean_name=name + 'op.3.running_mean', + moving_variance_name=name + 'op.3.running_var') + +def Identity(input, name=''): + return input + +def Zero(input, stride, name=''): + ones = np.ones(input.shape[-2:]) + ones[::stride, ::stride] = 0 + ones = fluid.layers.assign(ones) + return input * ones + +def FactorizedReduce(input, C_out, name='', affine=True): + relu_a = fluid.layers.relu(input) + conv2d_a = fluid.layers.conv2d( + relu_a, + C_out // 2, + 1, + 2, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'conv_1.weight'), + bias_attr=False) + h_end = relu_a.shape[2] + w_end = relu_a.shape[3] + slice_a = fluid.layers.slice(relu_a, [2, 3], [1, 1], [h_end, w_end]) + conv2d_b = fluid.layers.conv2d( + slice_a, + C_out // 2, + 1, + 2, + param_attr=ParamAttr( + initializer=Xavier( + uniform=False, fan_in=0), + name=name + 'conv_2.weight'), + bias_attr=False) + out = fluid.layers.concat([conv2d_a, conv2d_b], axis=1) + if affine: + out = fluid.layers.batch_norm( + out, + param_attr=ParamAttr( + initializer=Constant(1.), name=name + 'bn.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=name + 'bn.bias'), + moving_mean_name=name + 'bn.running_mean', + moving_variance_name=name + 'bn.running_var') + else: + out = fluid.layers.batch_norm( + out, + param_attr=ParamAttr( + initializer=Constant(1.), + learning_rate=0., + name=name + 'bn.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + learning_rate=0., + name=name + 'bn.bias'), + moving_mean_name=name + 'bn.running_mean', + moving_variance_name=name + 'bn.running_var') + return out + +class Cell(): + def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, + reduction_prev): + + if reduction_prev: + self.preprocess0 = 
functools.partial(FactorizedReduce, C_out=C) + else: + self.preprocess0 = functools.partial( + ReLUConvBN, C_out=C, kernel_size=1, stride=1, padding=0) + self.preprocess1 = functools.partial( + ReLUConvBN, C_out=C, kernel_size=1, stride=1, padding=0) + if reduction: + op_names, indices = zip(*genotype.reduce) + concat = genotype.reduce_concat + else: + op_names, indices = zip(*genotype.normal) + concat = genotype.normal_concat + print(op_names, indices, concat, reduction) + self._compile(C, op_names, indices, concat, reduction) + + def _compile(self, C, op_names, indices, concat, reduction): + assert len(op_names) == len(indices) + self._steps = len(op_names) // 2 + self._concat = concat + self.multiplier = len(concat) + + self._ops = [] + for name, index in zip(op_names, indices): + stride = 2 if reduction and index < 2 else 1 + op = functools.partial(OPS[name], C=C, stride=stride, affine=True) + self._ops += [op] + self._indices = indices + + def forward(self, s0, s1, drop_prob, is_train, name): + self.training = is_train + preprocess0_name = name + 'preprocess0.' + preprocess1_name = name + 'preprocess1.' + s0 = self.preprocess0(s0, name=preprocess0_name) + s1 = self.preprocess1(s1, name=preprocess1_name) + out = [s0, s1] + for i in range(self._steps): + h1 = out[self._indices[2 * i]] + h2 = out[self._indices[2 * i + 1]] + op1 = self._ops[2 * i] + op2 = self._ops[2 * i + 1] + h3 = op1(h1, name=name + '_ops.' + str(2 * i) + '.') + h4 = op2(h2, name=name + '_ops.' 
+ str(2 * i + 1) + '.') + if self.training and drop_prob > 0.: + if h3 != h1: + h3 = fluid.layers.dropout( + h3, + drop_prob, + dropout_implementation='upscale_in_train') + if h4 != h2: + h4 = fluid.layers.dropout( + h4, + drop_prob, + dropout_implementation='upscale_in_train') + s = h3 + h4 + out += [s] + return fluid.layers.concat([out[i] for i in self._concat], axis=1) + +def AuxiliaryHeadImageNet(input, num_classes, aux_name='auxiliary_head'): + relu_a = fluid.layers.relu(input) + pool_a = fluid.layers.pool2d(relu_a, 5, 'avg', 2) + conv2d_a = fluid.layers.conv2d( + pool_a, + 128, + 1, + name=aux_name + '.features.2', + bias_attr=False) + bn_a_name = aux_name + '.features.3' + bn_a = fluid.layers.batch_norm( + conv2d_a, + act='relu', + name=bn_a_name, + param_attr=ParamAttr( + initializer=Constant(1.), name=bn_a_name + '.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=bn_a_name + '.bias'), + moving_mean_name=bn_a_name + '.running_mean', + moving_variance_name=bn_a_name + '.running_var') + conv2d_b = fluid.layers.conv2d( + bn_a, + 768, + 2, + name=aux_name + '.features.5', + bias_attr=False) + bn_b_name = aux_name + '.features.6' + bn_b = fluid.layers.batch_norm( + conv2d_b, + act='relu', + name=bn_b_name, + param_attr=ParamAttr( + initializer=Constant(1.), name=bn_b_name + '.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=bn_b_name + '.bias'), + moving_mean_name=bn_b_name + '.running_mean', + moving_variance_name=bn_b_name + '.running_var') + pool_b = fluid.layers.adaptive_pool2d(bn_b, (1, 1), "avg") + fc_name = aux_name + '.classifier' + fc = fluid.layers.fc(pool_b, + num_classes, + name=fc_name, + param_attr=ParamAttr( + initializer=Normal(scale=1e-3), + name=fc_name + '.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name=fc_name + '.bias')) + return fc + + +def StemConv0(input, C_out): + conv_a = fluid.layers.conv2d( + input, + C_out // 2, + 3, + stride=2, + padding=1, + bias_attr=False) + bn_a = 
fluid.layers.batch_norm( + conv_a, + act='relu', + param_attr=ParamAttr( + initializer=Constant(1.), name='stem0.1.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name='stem0.1.bias'), + moving_mean_name='stem0.1.running_mean', + moving_variance_name='stem0.1.running_var') + + conv_b = fluid.layers.conv2d( + bn_a, + C_out, + 3, + stride=2, + padding=1, + bias_attr=False) + bn_b = fluid.layers.batch_norm( + conv_b, + param_attr=ParamAttr( + initializer=Constant(1.), name='stem0.3.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name='stem0.3.bias'), + moving_mean_name='stem0.3.running_mean', + moving_variance_name='stem0.3.running_var') + return bn_b + +def StemConv1(input, C_out): + relu_a = fluid.layers.relu(input) + conv_a = fluid.layers.conv2d( + relu_a, + C_out, + 3, + stride=2, + padding=1, + bias_attr=False) + bn_a = fluid.layers.batch_norm( + conv_a, + param_attr=ParamAttr( + initializer=Constant(1.), name='stem1.1.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), name='stem1.1.bias'), + moving_mean_name='stem1.1.running_mean', + moving_variance_name='stem1.1.running_var') + return bn_a + +class NetworkImageNet(object): + def __init__(self, arch='DARTS_6M'): + self.params = train_parameters + self.class_num = 1000 + self.init_channel = 48 + self._layers = 14 + self._auxiliary = False + self.drop_path_prob = 0 + genotype = arch_dict[arch] + + C = self.init_channel + layers = self._layers + C_prev_prev, C_prev, C_curr = C, C, C + self.cells = [] + reduction_prev = True + for i in range(layers): + if i in [layers // 3, 2 * layers // 3]: + C_curr *= 2 + reduction = True + else: + reduction = False + cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, + reduction_prev) + reduction_prev = reduction + self.cells += [cell] + C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr + if i == 2 * layers // 3: + C_to_auxiliary = C_prev + + def net(self, input, class_dim=1000, is_train=True): + self.logits_aux = None + 
num_channel = self.init_channel + s0 = StemConv0(input, num_channel) + s1 = StemConv1(s0, num_channel) + for i, cell in enumerate(self.cells): + name = 'cells.' + str(i) + '.' + s0, s1 = s1, cell.forward(s0, s1, self.drop_path_prob, is_train, + name) + if i == int(2 * self._layers // 3): + if self._auxiliary and is_train: + self.logits_aux = AuxiliaryHeadImageNet(s1, self.class_num) + out = fluid.layers.adaptive_pool2d(s1, (1, 1), "avg") + self.logits = fluid.layers.fc(out, + size=self.class_num, + param_attr=ParamAttr( + initializer=Normal(scale=1e-4), + name='classifier.weight'), + bias_attr=ParamAttr( + initializer=Constant(0.), + name='classifier.bias')) + return self.logits + +def DARTS_6M(): + return NetworkImageNet(arch = 'DARTS_6M') +def DARTS_4M(): + return NetworkImageNet(arch = 'DARTS_4M') diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/custom.py b/VisualFL/visualfl/algorithm/paddle_clas/models/custom.py new file mode 100644 index 000000000..8d52759aa --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/custom.py @@ -0,0 +1,39 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid + +__all__ = ['CNN'] + + +class CNN(): + def __init__(self): + pass + + def net(self,input, class_dim=1000): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=input, + num_filters=20, + filter_size=5, + pool_size=2, + pool_stride=2, + act="relu", + ) + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + num_filters=50, + filter_size=5, + pool_size=2, + pool_stride=2, + act="relu", + ) + + out = fluid.layers.fc( + input=conv_pool_2, size=class_dim + ) + + return out \ No newline at end of file diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/darknet.py b/VisualFL/visualfl/algorithm/paddle_clas/models/darknet.py new file mode 100755 index 000000000..895809a6c --- /dev/null +++ 
b/VisualFL/visualfl/algorithm/paddle_clas/models/darknet.py @@ -0,0 +1,115 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +import math +__all__ = ["DarkNet53"] + + +class DarkNet53(): + def __init__(self): + + pass + + def net(self, input, class_dim=1000): + DarkNet_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)} + stages, block_func = DarkNet_cfg[53] + stages = stages[0:5] + conv1 = self.conv_bn_layer( + input, + ch_out=32, + filter_size=3, + stride=1, + padding=1, + name="yolo_input") + conv = self.downsample( + conv1, ch_out=conv1.shape[1] * 2, name="yolo_input.downsample") + + for i, stage in enumerate(stages): + conv = self.layer_warp( + block_func, conv, 32 * (2**i), stage, name="stage.{}".format(i)) + if i < len(stages) - 1: # do not downsaple in the last stage + conv = self.downsample( + conv, + ch_out=conv.shape[1] * 2, + name="stage.{}.downsample".format(i)) + pool = fluid.layers.pool2d( + input=conv, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc( + input=pool, + size=class_dim, + param_attr=ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name='fc_weights'), + bias_attr=ParamAttr(name='fc_offset')) + return out + + def 
conv_bn_layer(self, + input, + ch_out, + filter_size, + stride, + padding, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + act=None, + param_attr=ParamAttr(name=name + ".conv.weights"), + bias_attr=False) + + bn_name = name + ".bn" + out = fluid.layers.batch_norm( + input=conv, + act='relu', + param_attr=ParamAttr(name=bn_name + '.scale'), + bias_attr=ParamAttr(name=bn_name + '.offset'), + moving_mean_name=bn_name + '.mean', + moving_variance_name=bn_name + '.var') + return out + + def downsample(self, + input, + ch_out, + filter_size=3, + stride=2, + padding=1, + name=None): + return self.conv_bn_layer( + input, + ch_out=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + name=name) + + def basicblock(self, input, ch_out, name=None): + conv1 = self.conv_bn_layer(input, ch_out, 1, 1, 0, name=name + ".0") + conv2 = self.conv_bn_layer(conv1, ch_out * 2, 3, 1, 1, name=name + ".1") + out = fluid.layers.elementwise_add(x=input, y=conv2, act=None) + return out + + def layer_warp(self, block_func, input, ch_out, count, name=None): + res_out = block_func(input, ch_out, name='{}.0'.format(name)) + for j in range(1, count): + res_out = block_func(res_out, ch_out, name='{}.{}'.format(name, j)) + return res_out diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/densenet.py b/VisualFL/visualfl/algorithm/paddle_clas/models/densenet.py new file mode 100755 index 000000000..d3a7c1da7 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/densenet.py @@ -0,0 +1,201 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. 
+#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import math +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + "DenseNet", "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201", + "DenseNet264" +] + + +class DenseNet(): + def __init__(self, layers=121): + self.layers = layers + + def net(self, input, bn_size=4, dropout=0, class_dim=1000): + layers = self.layers + supported_layers = [121, 161, 169, 201, 264] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + densenet_spec = { + 121: (64, 32, [6, 12, 24, 16]), + 161: (96, 48, [6, 12, 36, 24]), + 169: (64, 32, [6, 12, 32, 32]), + 201: (64, 32, [6, 12, 48, 32]), + 264: (64, 32, [6, 12, 64, 48]) + } + + num_init_features, growth_rate, block_config = densenet_spec[layers] + conv = fluid.layers.conv2d( + input=input, + num_filters=num_init_features, + filter_size=7, + stride=2, + padding=3, + act=None, + param_attr=ParamAttr(name="conv1_weights"), + bias_attr=False) + conv = fluid.layers.batch_norm( + input=conv, + act='relu', + param_attr=ParamAttr(name='conv1_bn_scale'), + bias_attr=ParamAttr(name='conv1_bn_offset'), + moving_mean_name='conv1_bn_mean', + moving_variance_name='conv1_bn_variance') + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + num_features = num_init_features + for i, num_layers in enumerate(block_config): + conv = self.make_dense_block( 
+ conv, + num_layers, + bn_size, + growth_rate, + dropout, + name='conv' + str(i + 2)) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + conv = self.make_transition( + conv, num_features // 2, name='conv' + str(i + 2) + '_blk') + num_features = num_features // 2 + conv = fluid.layers.batch_norm( + input=conv, + act='relu', + param_attr=ParamAttr(name='conv5_blk_bn_scale'), + bias_attr=ParamAttr(name='conv5_blk_bn_offset'), + moving_mean_name='conv5_blk_bn_mean', + moving_variance_name='conv5_blk_bn_variance') + conv = fluid.layers.pool2d( + input=conv, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(conv.shape[1] * 1.0) + out = fluid.layers.fc( + input=conv, + size=class_dim, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name="fc_weights"), + bias_attr=ParamAttr(name='fc_offset')) + return out + + def make_transition(self, input, num_output_features, name=None): + bn_ac = fluid.layers.batch_norm( + input, + act='relu', + param_attr=ParamAttr(name=name + '_bn_scale'), + bias_attr=ParamAttr(name + '_bn_offset'), + moving_mean_name=name + '_bn_mean', + moving_variance_name=name + '_bn_variance') + + bn_ac_conv = fluid.layers.conv2d( + input=bn_ac, + num_filters=num_output_features, + filter_size=1, + stride=1, + act=None, + bias_attr=False, + param_attr=ParamAttr(name=name + "_weights")) + pool = fluid.layers.pool2d( + input=bn_ac_conv, pool_size=2, pool_stride=2, pool_type='avg') + return pool + + def make_dense_block(self, + input, + num_layers, + bn_size, + growth_rate, + dropout, + name=None): + conv = input + for layer in range(num_layers): + conv = self.make_dense_layer( + conv, + growth_rate, + bn_size, + dropout, + name=name + '_' + str(layer + 1)) + return conv + + def make_dense_layer(self, input, growth_rate, bn_size, dropout, name=None): + bn_ac = fluid.layers.batch_norm( + input, + act='relu', + param_attr=ParamAttr(name=name + '_x1_bn_scale'), + 
bias_attr=ParamAttr(name + '_x1_bn_offset'), + moving_mean_name=name + '_x1_bn_mean', + moving_variance_name=name + '_x1_bn_variance') + bn_ac_conv = fluid.layers.conv2d( + input=bn_ac, + num_filters=bn_size * growth_rate, + filter_size=1, + stride=1, + act=None, + bias_attr=False, + param_attr=ParamAttr(name=name + "_x1_weights")) + bn_ac = fluid.layers.batch_norm( + bn_ac_conv, + act='relu', + param_attr=ParamAttr(name=name + '_x2_bn_scale'), + bias_attr=ParamAttr(name + '_x2_bn_offset'), + moving_mean_name=name + '_x2_bn_mean', + moving_variance_name=name + '_x2_bn_variance') + bn_ac_conv = fluid.layers.conv2d( + input=bn_ac, + num_filters=growth_rate, + filter_size=3, + stride=1, + padding=1, + act=None, + bias_attr=False, + param_attr=ParamAttr(name=name + "_x2_weights")) + if dropout: + bn_ac_conv = fluid.layers.dropout( + x=bn_ac_conv, dropout_prob=dropout) + bn_ac_conv = fluid.layers.concat([input, bn_ac_conv], axis=1) + return bn_ac_conv + + +def DenseNet121(): + model = DenseNet(layers=121) + return model + + +def DenseNet161(): + model = DenseNet(layers=161) + return model + + +def DenseNet169(): + model = DenseNet(layers=169) + return model + + +def DenseNet201(): + model = DenseNet(layers=201) + return model + + +def DenseNet264(): + model = DenseNet(layers=264) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/dpn.py b/VisualFL/visualfl/algorithm/paddle_clas/models/dpn.py new file mode 100755 index 000000000..36777b623 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/dpn.py @@ -0,0 +1,334 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. 
+#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np +import time +import sys +import math + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + +__all__ = ["DPN", "DPN68", "DPN92", "DPN98", "DPN107", "DPN131"] + + +class DPN(object): + def __init__(self, layers=68): + self.layers = layers + + def net(self, input, class_dim=1000): + # get network args + args = self.get_net_args(self.layers) + bws = args['bw'] + inc_sec = args['inc_sec'] + rs = args['r'] + k_r = args['k_r'] + k_sec = args['k_sec'] + G = args['G'] + init_num_filter = args['init_num_filter'] + init_filter_size = args['init_filter_size'] + init_padding = args['init_padding'] + + ## define Dual Path Network + + # conv1 + conv1_x_1 = fluid.layers.conv2d( + input=input, + num_filters=init_num_filter, + filter_size=init_filter_size, + stride=2, + padding=init_padding, + groups=1, + act=None, + bias_attr=False, + name="conv1", + param_attr=ParamAttr(name="conv1_weights"), ) + + conv1_x_1 = fluid.layers.batch_norm( + input=conv1_x_1, + act='relu', + is_test=False, + name="conv1_bn", + param_attr=ParamAttr(name='conv1_bn_scale'), + bias_attr=ParamAttr('conv1_bn_offset'), + moving_mean_name='conv1_bn_mean', + moving_variance_name='conv1_bn_variance', ) + + convX_x_x = fluid.layers.pool2d( + input=conv1_x_1, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max', + name="pool1") + + #conv2 - conv5 + match_list, num = [], 0 + for gc in range(4): + bw = bws[gc] + inc = 
inc_sec[gc] + R = (k_r * bw) // rs[gc] + if gc == 0: + _type1 = 'proj' + _type2 = 'normal' + match = 1 + else: + _type1 = 'down' + _type2 = 'normal' + match = match + k_sec[gc - 1] + match_list.append(match) + + convX_x_x = self.dual_path_factory( + convX_x_x, R, R, bw, inc, G, _type1, name="dpn" + str(match)) + for i_ly in range(2, k_sec[gc] + 1): + num += 1 + if num in match_list: + num += 1 + convX_x_x = self.dual_path_factory( + convX_x_x, R, R, bw, inc, G, _type2, name="dpn" + str(num)) + + conv5_x_x = fluid.layers.concat(convX_x_x, axis=1) + conv5_x_x = fluid.layers.batch_norm( + input=conv5_x_x, + act='relu', + is_test=False, + name="final_concat_bn", + param_attr=ParamAttr(name='final_concat_bn_scale'), + bias_attr=ParamAttr('final_concat_bn_offset'), + moving_mean_name='final_concat_bn_mean', + moving_variance_name='final_concat_bn_variance', ) + pool5 = fluid.layers.pool2d( + input=conv5_x_x, + pool_size=7, + pool_stride=1, + pool_padding=0, + pool_type='avg', ) + + stdv = 0.01 + fc6 = fluid.layers.fc(input=pool5, + size=class_dim, + param_attr=ParamAttr(initializer=fluid.initializer.Uniform(-stdv, stdv), name='fc_weights'), + bias_attr=ParamAttr(name='fc_offset')) + + return fc6 + + def get_net_args(self, layers): + if layers == 68: + k_r = 128 + G = 32 + k_sec = [3, 4, 12, 3] + inc_sec = [16, 32, 32, 64] + bw = [64, 128, 256, 512] + r = [64, 64, 64, 64] + init_num_filter = 10 + init_filter_size = 3 + init_padding = 1 + elif layers == 92: + k_r = 96 + G = 32 + k_sec = [3, 4, 20, 3] + inc_sec = [16, 32, 24, 128] + bw = [256, 512, 1024, 2048] + r = [256, 256, 256, 256] + init_num_filter = 64 + init_filter_size = 7 + init_padding = 3 + elif layers == 98: + k_r = 160 + G = 40 + k_sec = [3, 6, 20, 3] + inc_sec = [16, 32, 32, 128] + bw = [256, 512, 1024, 2048] + r = [256, 256, 256, 256] + init_num_filter = 96 + init_filter_size = 7 + init_padding = 3 + elif layers == 107: + k_r = 200 + G = 50 + k_sec = [4, 8, 20, 3] + inc_sec = [20, 64, 64, 128] + bw = [256, 
512, 1024, 2048] + r = [256, 256, 256, 256] + init_num_filter = 128 + init_filter_size = 7 + init_padding = 3 + elif layers == 131: + k_r = 160 + G = 40 + k_sec = [4, 8, 28, 3] + inc_sec = [16, 32, 32, 128] + bw = [256, 512, 1024, 2048] + r = [256, 256, 256, 256] + init_num_filter = 128 + init_filter_size = 7 + init_padding = 3 + else: + raise NotImplementedError + net_arg = { + 'k_r': k_r, + 'G': G, + 'k_sec': k_sec, + 'inc_sec': inc_sec, + 'bw': bw, + 'r': r + } + net_arg['init_num_filter'] = init_num_filter + net_arg['init_filter_size'] = init_filter_size + net_arg['init_padding'] = init_padding + + return net_arg + + def dual_path_factory(self, + data, + num_1x1_a, + num_3x3_b, + num_1x1_c, + inc, + G, + _type='normal', + name=None): + kw = 3 + kh = 3 + pw = (kw - 1) // 2 + ph = (kh - 1) // 2 + + # type + if _type is 'proj': + key_stride = 1 + has_proj = True + if _type is 'down': + key_stride = 2 + has_proj = True + if _type is 'normal': + key_stride = 1 + has_proj = False + + # PROJ + if type(data) is list: + data_in = fluid.layers.concat([data[0], data[1]], axis=1) + else: + data_in = data + + if has_proj: + c1x1_w = self.bn_ac_conv( + data=data_in, + num_filter=(num_1x1_c + 2 * inc), + kernel=(1, 1), + pad=(0, 0), + stride=(key_stride, key_stride), + name=name + "_match") + data_o1, data_o2 = fluid.layers.split( + c1x1_w, + num_or_sections=[num_1x1_c, 2 * inc], + dim=1, + name=name + "_match_conv_Slice") + else: + data_o1 = data[0] + data_o2 = data[1] + + # MAIN + c1x1_a = self.bn_ac_conv( + data=data_in, + num_filter=num_1x1_a, + kernel=(1, 1), + pad=(0, 0), + name=name + "_conv1") + c3x3_b = self.bn_ac_conv( + data=c1x1_a, + num_filter=num_3x3_b, + kernel=(kw, kh), + pad=(pw, ph), + stride=(key_stride, key_stride), + num_group=G, + name=name + "_conv2") + c1x1_c = self.bn_ac_conv( + data=c3x3_b, + num_filter=(num_1x1_c + inc), + kernel=(1, 1), + pad=(0, 0), + name=name + "_conv3") + + c1x1_c1, c1x1_c2 = fluid.layers.split( + c1x1_c, + 
num_or_sections=[num_1x1_c, inc], + dim=1, + name=name + "_conv3_Slice") + + # OUTPUTS + summ = fluid.layers.elementwise_add( + x=data_o1, y=c1x1_c1, name=name + "_elewise") + dense = fluid.layers.concat( + [data_o2, c1x1_c2], axis=1, name=name + "_concat") + + return [summ, dense] + + def bn_ac_conv(self, + data, + num_filter, + kernel, + pad, + stride=(1, 1), + num_group=1, + name=None): + bn_ac = fluid.layers.batch_norm( + input=data, + act='relu', + is_test=False, + name=name + '.output.1', + param_attr=ParamAttr(name=name + '_bn_scale'), + bias_attr=ParamAttr(name + '_bn_offset'), + moving_mean_name=name + '_bn_mean', + moving_variance_name=name + '_bn_variance', ) + bn_ac_conv = fluid.layers.conv2d( + input=bn_ac, + num_filters=num_filter, + filter_size=kernel, + stride=stride, + padding=pad, + groups=num_group, + act=None, + bias_attr=False, + param_attr=ParamAttr(name=name + "_weights")) + return bn_ac_conv + + +def DPN68(): + model = DPN(layers=68) + return model + + +def DPN92(): + model = DPN(layers=92) + return model + + +def DPN98(): + model = DPN(layers=98) + return model + + +def DPN107(): + model = DPN(layers=107) + return model + + +def DPN131(): + model = DPN(layers=131) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/efficientnet.py b/VisualFL/visualfl/algorithm/paddle_clas/models/efficientnet.py new file mode 100755 index 000000000..dcac10d8c --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/efficientnet.py @@ -0,0 +1,454 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle.fluid as fluid +import collections +import re +import math +import copy +from .layers import conv2d, init_batch_norm_layer, init_fc_layer + + +__all__ = ['EfficientNet', 'EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', 'EfficientNetB4', + 'EfficientNetB5', 'EfficientNetB6', 'EfficientNetB7'] + +GlobalParams = 
collections.namedtuple('GlobalParams', [ + 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', + 'num_classes', 'width_coefficient', 'depth_coefficient', + 'depth_divisor', 'min_depth', 'drop_connect_rate', ]) + +BlockArgs = collections.namedtuple('BlockArgs', [ + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', + 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) + +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + + +def efficientnet_params(model_name): + """ Map EfficientNet model name to parameter coefficients. """ + params_dict = { + # Coefficients: width,depth,resolution,dropout + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + } + return params_dict[model_name] + + +def efficientnet(width_coefficient=None, depth_coefficient=None, + dropout_rate=0.2, drop_connect_rate=0.2): + """ Get block arguments according to parameter and coefficients. 
""" + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + dropout_rate=dropout_rate, + drop_connect_rate=drop_connect_rate, + num_classes=1000, + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + depth_divisor=8, + min_depth=None + ) + + return blocks_args, global_params + + +def get_model_params(model_name, override_params): + """ Get the block args and global params for a given model """ + if model_name.startswith('efficientnet'): + w, d, _, p = efficientnet_params(model_name) + blocks_args, global_params = efficientnet(width_coefficient=w, depth_coefficient=d, dropout_rate=p) + else: + raise NotImplementedError('model name is not pre-defined: %s' % model_name) + if override_params: + global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +def round_filters(filters, global_params): + """ Calculate and round number of filters based on depth multiplier. """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """ Round number of filters based on depth multiplier. 
""" + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +class EfficientNet(): + def __init__(self, name='b0', padding_type='SAME', override_params=None, is_test=False, use_se=True): + valid_names = ['b' + str(i) for i in range(8)] + assert name in valid_names, 'efficient name should be in b0~b7' + model_name = 'efficientnet-' + name + self._blocks_args, self._global_params = get_model_params(model_name, override_params) + self._bn_mom = self._global_params.batch_norm_momentum + self._bn_eps = self._global_params.batch_norm_epsilon + self.is_test = is_test + self.padding_type = padding_type + self.use_se = use_se + + def net(self, input, class_dim=1000, is_test=False): + + conv = self.extract_features(input, is_test=is_test) + + out_channels = round_filters(1280, self._global_params) + conv = self.conv_bn_layer(conv, + num_filters=out_channels, + filter_size=1, + bn_act='swish', + bn_mom=self._bn_mom, + bn_eps=self._bn_eps, + padding_type=self.padding_type, + name='', + conv_name='_conv_head', + bn_name='_bn1') + + pool = fluid.layers.pool2d(input=conv, pool_type='avg', global_pooling=True, use_cudnn=False) + + if self._global_params.dropout_rate: + pool = fluid.layers.dropout(pool, self._global_params.dropout_rate, dropout_implementation='upscale_in_train') + + param_attr, bias_attr = init_fc_layer(class_dim, '_fc') + out = fluid.layers.fc(pool, class_dim, name='_fc', param_attr=param_attr, bias_attr=bias_attr) + return out + + def _drop_connect(self, inputs, prob, is_test): + if is_test: + return inputs + keep_prob = 1.0 - prob + random_tensor = keep_prob + fluid.layers.uniform_random_batch_size_like(inputs, [-1, 1, 1, 1], min=0., max=1.) 
+ binary_tensor = fluid.layers.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + + def _expand_conv_norm(self, inputs, block_args, is_test, name=None): + # Expansion phase + oup = block_args.input_filters * block_args.expand_ratio # number of output channels + + if block_args.expand_ratio != 1: + conv = self.conv_bn_layer(inputs, + num_filters=oup, + filter_size=1, + bn_act=None, + bn_mom=self._bn_mom, + bn_eps=self._bn_eps, + padding_type=self.padding_type, + name=name, + conv_name=name + '_expand_conv', + bn_name='_bn0') + + return conv + + def _depthwise_conv_norm(self, inputs, block_args, is_test, name=None): + k = block_args.kernel_size + s = block_args.stride + if isinstance(s, list) or isinstance(s, tuple): + s = s[0] + oup = block_args.input_filters * block_args.expand_ratio # number of output channels + + conv = self.conv_bn_layer(inputs, + num_filters=oup, + filter_size=k, + stride=s, + num_groups=oup, + bn_act=None, + padding_type=self.padding_type, + bn_mom=self._bn_mom, + bn_eps=self._bn_eps, + name=name, + use_cudnn=False, + conv_name=name + '_depthwise_conv', + bn_name='_bn1') + + return conv + + def _project_conv_norm(self, inputs, block_args, is_test, name=None): + final_oup = block_args.output_filters + conv = self.conv_bn_layer(inputs, + num_filters=final_oup, + filter_size=1, + bn_act=None, + padding_type=self.padding_type, + bn_mom=self._bn_mom, + bn_eps=self._bn_eps, + name=name, + conv_name=name + '_project_conv', + bn_name='_bn2') + return conv + + def conv_bn_layer(self, input, filter_size, num_filters, stride=1, num_groups=1, padding_type="SAME", conv_act=None, + bn_act='swish', use_cudnn=True, use_bn=True, bn_mom=0.9, bn_eps=1e-05, use_bias=False, name=None, + conv_name=None, bn_name=None): + conv = conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + groups=num_groups, + act=conv_act, + padding_type=padding_type, + use_cudnn=use_cudnn, + name=conv_name, + 
use_bias=use_bias) + + if use_bn == False: + return conv + else: + bn_name = name + bn_name + param_attr, bias_attr = init_batch_norm_layer(bn_name) + return fluid.layers.batch_norm(input=conv, + act=bn_act, + momentum=bn_mom, + epsilon=bn_eps, + name=bn_name, + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + param_attr=param_attr, + bias_attr=bias_attr) + + def _conv_stem_norm(self, inputs, is_test): + out_channels = round_filters(32, self._global_params) + bn = self.conv_bn_layer(inputs, num_filters=out_channels, filter_size=3, stride=2, bn_act=None, + bn_mom=self._bn_mom, padding_type=self.padding_type, + bn_eps=self._bn_eps, name='', conv_name='_conv_stem', bn_name='_bn0') + + return bn + + def mb_conv_block(self, inputs, block_args, is_test=False, drop_connect_rate=None, name=None): + # Expansion and Depthwise Convolution + oup = block_args.input_filters * block_args.expand_ratio # number of output channels + has_se = self.use_se and (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1) + id_skip = block_args.id_skip # skip connection and drop connect + conv = inputs + if block_args.expand_ratio != 1: + conv = fluid.layers.swish(self._expand_conv_norm(conv, block_args, is_test, name)) + + conv = fluid.layers.swish(self._depthwise_conv_norm(conv, block_args, is_test, name)) + + # Squeeze and Excitation + if has_se: + num_squeezed_channels = max(1, int(block_args.input_filters * block_args.se_ratio)) + conv = self.se_block(conv, num_squeezed_channels, oup, name) + + conv = self._project_conv_norm(conv, block_args, is_test, name) + + # Skip connection and drop connect + input_filters, output_filters = block_args.input_filters, block_args.output_filters + if id_skip and block_args.stride == 1 and input_filters == output_filters: + if drop_connect_rate: + conv = self._drop_connect(conv, drop_connect_rate, self.is_test) + conv = fluid.layers.elementwise_add(conv, inputs) + + return conv + + def se_block(self, 
inputs, num_squeezed_channels, oup, name): + x_squeezed = fluid.layers.pool2d( + input=inputs, + pool_type='avg', + global_pooling=True, + use_cudnn=False) + x_squeezed = conv2d(x_squeezed, + num_filters=num_squeezed_channels, + filter_size=1, + use_bias=True, + padding_type=self.padding_type, + act='swish', + name=name + '_se_reduce') + x_squeezed = conv2d(x_squeezed, + num_filters=oup, + filter_size=1, + use_bias=True, + padding_type=self.padding_type, + name=name + '_se_expand') + se_out = inputs * fluid.layers.sigmoid(x_squeezed) + return se_out + + def extract_features(self, inputs, is_test): + """ Returns output of the final convolution layer """ + + conv = fluid.layers.swish(self._conv_stem_norm(inputs, is_test=is_test)) + + block_args_copy = copy.deepcopy(self._blocks_args) + idx = 0 + block_size = 0 + for block_arg in block_args_copy: + block_arg = block_arg._replace( + input_filters=round_filters(block_arg.input_filters, self._global_params), + output_filters=round_filters(block_arg.output_filters, self._global_params), + num_repeat=round_repeats(block_arg.num_repeat, self._global_params) + ) + block_size += 1 + for _ in range(block_arg.num_repeat - 1): + block_size += 1 + + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params) + ) + + # The first block needs to take care of stride and filter size increase. + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / block_size + conv = self.mb_conv_block(conv, block_args, is_test, drop_connect_rate, '_blocks.' 
+ str(idx) + '.') + + idx += 1 + if block_args.num_repeat > 1: + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / block_size + conv = self.mb_conv_block(conv, block_args, is_test, drop_connect_rate, '_blocks.' + str(idx) + '.') + idx += 1 + + return conv + + def shortcut(self, input, data_residual): + return fluid.layers.elementwise_add(input, data_residual) + + +class BlockDecoder(object): + """ Block Decoder for readability, straight from the official TensorFlow repository """ + + @staticmethod + def _decode_block_string(block_string): + """ Gets a block through a string notation of arguments. """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 2 and options['s'][0] == options['s'][1])) + + return BlockArgs( + kernel_size=int(options['k']), + num_repeat=int(options['r']), + input_filters=int(options['i']), + output_filters=int(options['o']), + expand_ratio=int(options['e']), + id_skip=('noskip' not in block_string), + se_ratio=float(options['se']) if 'se' in options else None, + stride=[int(options['s'][0])]) + + @staticmethod + def _encode_block_string(block): + """Encodes a block to a string.""" + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """ + 
Decodes a list of string notations to specify blocks inside the network. + + :param string_list: a list of strings, each string is a notation of block + :return: a list of BlockArgs namedtuples of block args + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """ + Encodes a list of BlockArgs to a list of strings. + + :param blocks_args: a list of BlockArgs namedtuples of block args + :return: a list of strings, each string is a notation of block + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def EfficientNetB0(is_test=False, padding_type='SAME', override_params=None, use_se=True): + model = EfficientNet(name='b0', is_test=is_test, padding_type=padding_type, override_params=override_params, use_se=use_se) + return model + + +def EfficientNetB1(is_test=False, padding_type='SAME', override_params=None, use_se=True): + model = EfficientNet(name='b1', is_test=is_test, padding_type=padding_type, override_params=override_params, use_se=use_se) + return model + + +def EfficientNetB2(is_test=False, padding_type='SAME', override_params=None, use_se=True): + model = EfficientNet(name='b2', is_test=is_test, padding_type=padding_type, override_params=override_params, use_se=use_se) + return model + + +def EfficientNetB3(is_test=False, padding_type='SAME', override_params=None, use_se=True): + model = EfficientNet(name='b3', is_test=is_test, padding_type=padding_type, override_params=override_params, use_se=use_se) + return model + + +def EfficientNetB4(is_test=False, padding_type='SAME', override_params=None, use_se=True): + model = EfficientNet(name='b4', is_test=is_test, padding_type=padding_type, override_params=override_params, use_se=use_se) + return model + + +def 
EfficientNetB5(is_test=False, padding_type='SAME', override_params=None, use_se=True): + model = EfficientNet(name='b5', is_test=is_test, padding_type=padding_type, override_params=override_params, use_se=use_se) + return model + + +def EfficientNetB6(is_test=False, padding_type='SAME', override_params=None, use_se=True): + model = EfficientNet(name='b6', is_test=is_test, padding_type=padding_type, override_params=override_params, use_se=use_se) + return model + + +def EfficientNetB7(is_test=False, padding_type='SAME', override_params=None, use_se=True): + model = EfficientNet(name='b7', is_test=is_test, padding_type=padding_type, override_params=override_params, use_se=use_se) + return model \ No newline at end of file diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/googlenet.py b/VisualFL/visualfl/algorithm/paddle_clas/models/googlenet.py new file mode 100755 index 000000000..5b92d5775 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/googlenet.py @@ -0,0 +1,237 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + +__all__ = ['GoogLeNet'] + + +class GoogLeNet(): + def __init__(self): + + pass + + def conv_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + channels = input.shape[1] + stdv = (3.0 / (filter_size**2 * channels))**0.5 + param_attr = ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + "_weights") + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=act, + param_attr=param_attr, + bias_attr=False, + name=name) + return conv + + def xavier(self, channels, filter_size, name): + stdv = (3.0 / (filter_size**2 * channels))**0.5 + param_attr = ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + "_weights") + + return param_attr + + def inception(self, + input, + channels, + filter1, + filter3R, + filter3, + filter5R, + filter5, + proj, + name=None): + conv1 = self.conv_layer( + input=input, + num_filters=filter1, + filter_size=1, + stride=1, + act=None, + name="inception_" + name + "_1x1") + conv3r = self.conv_layer( + input=input, + num_filters=filter3R, + filter_size=1, + stride=1, + act=None, + name="inception_" + name + "_3x3_reduce") + conv3 = self.conv_layer( + input=conv3r, + num_filters=filter3, + filter_size=3, + stride=1, + act=None, + name="inception_" + name + "_3x3") + conv5r = self.conv_layer( + input=input, + num_filters=filter5R, + filter_size=1, + stride=1, + act=None, + name="inception_" + name + "_5x5_reduce") + conv5 = self.conv_layer( + input=conv5r, + num_filters=filter5, + filter_size=5, + stride=1, + act=None, + name="inception_" + name + "_5x5") + pool = fluid.layers.pool2d( + input=input, + pool_size=3, + 
pool_stride=1, + pool_padding=1, + pool_type='max') + convprj = fluid.layers.conv2d( + input=pool, + filter_size=1, + num_filters=proj, + stride=1, + padding=0, + name="inception_" + name + "_3x3_proj", + param_attr=ParamAttr( + name="inception_" + name + "_3x3_proj_weights"), + bias_attr=False) + cat = fluid.layers.concat(input=[conv1, conv3, conv5, convprj], axis=1) + cat = fluid.layers.relu(cat) + return cat + + def net(self, input, class_dim=1000): + conv = self.conv_layer( + input=input, + num_filters=64, + filter_size=7, + stride=2, + act=None, + name="conv1") + pool = fluid.layers.pool2d( + input=conv, pool_size=3, pool_type='max', pool_stride=2) + + conv = self.conv_layer( + input=pool, + num_filters=64, + filter_size=1, + stride=1, + act=None, + name="conv2_1x1") + conv = self.conv_layer( + input=conv, + num_filters=192, + filter_size=3, + stride=1, + act=None, + name="conv2_3x3") + pool = fluid.layers.pool2d( + input=conv, pool_size=3, pool_type='max', pool_stride=2) + + ince3a = self.inception(pool, 192, 64, 96, 128, 16, 32, 32, "ince3a") + ince3b = self.inception(ince3a, 256, 128, 128, 192, 32, 96, 64, + "ince3b") + pool3 = fluid.layers.pool2d( + input=ince3b, pool_size=3, pool_type='max', pool_stride=2) + + ince4a = self.inception(pool3, 480, 192, 96, 208, 16, 48, 64, "ince4a") + ince4b = self.inception(ince4a, 512, 160, 112, 224, 24, 64, 64, + "ince4b") + ince4c = self.inception(ince4b, 512, 128, 128, 256, 24, 64, 64, + "ince4c") + ince4d = self.inception(ince4c, 512, 112, 144, 288, 32, 64, 64, + "ince4d") + ince4e = self.inception(ince4d, 528, 256, 160, 320, 32, 128, 128, + "ince4e") + pool4 = fluid.layers.pool2d( + input=ince4e, pool_size=3, pool_type='max', pool_stride=2) + + ince5a = self.inception(pool4, 832, 256, 160, 320, 32, 128, 128, + "ince5a") + ince5b = self.inception(ince5a, 832, 384, 192, 384, 48, 128, 128, + "ince5b") + pool5 = fluid.layers.pool2d( + input=ince5b, pool_size=7, pool_type='avg', pool_stride=7) + dropout = 
fluid.layers.dropout(x=pool5, dropout_prob=0.4) + out = fluid.layers.fc(input=dropout, + size=class_dim, + act='softmax', + param_attr=self.xavier(1024, 1, "out"), + name="out", + bias_attr=ParamAttr(name="out_offset")) + + pool_o1 = fluid.layers.pool2d( + input=ince4a, pool_size=5, pool_type='avg', pool_stride=3) + conv_o1 = self.conv_layer( + input=pool_o1, + num_filters=128, + filter_size=1, + stride=1, + act=None, + name="conv_o1") + fc_o1 = fluid.layers.fc(input=conv_o1, + size=1024, + act='relu', + param_attr=self.xavier(2048, 1, "fc_o1"), + name="fc_o1", + bias_attr=ParamAttr(name="fc_o1_offset")) + dropout_o1 = fluid.layers.dropout(x=fc_o1, dropout_prob=0.7) + out1 = fluid.layers.fc(input=dropout_o1, + size=class_dim, + act='softmax', + param_attr=self.xavier(1024, 1, "out1"), + name="out1", + bias_attr=ParamAttr(name="out1_offset")) + + pool_o2 = fluid.layers.pool2d( + input=ince4d, pool_size=5, pool_type='avg', pool_stride=3) + conv_o2 = self.conv_layer( + input=pool_o2, + num_filters=128, + filter_size=1, + stride=1, + act=None, + name="conv_o2") + fc_o2 = fluid.layers.fc(input=conv_o2, + size=1024, + act='relu', + param_attr=self.xavier(2048, 1, "fc_o2"), + name="fc_o2", + bias_attr=ParamAttr(name="fc_o2_offset")) + dropout_o2 = fluid.layers.dropout(x=fc_o2, dropout_prob=0.7) + out2 = fluid.layers.fc(input=dropout_o2, + size=class_dim, + act='softmax', + param_attr=self.xavier(1024, 1, "out2"), + name="out2", + bias_attr=ParamAttr(name="out2_offset")) + + # last fc layer is "out" + return out, out1, out2 diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/hrnet.py b/VisualFL/visualfl/algorithm/paddle_clas/models/hrnet.py new file mode 100755 index 000000000..ea6b98e74 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/hrnet.py @@ -0,0 +1,322 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. 
+# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +import math +from paddle.fluid.param_attr import ParamAttr + +__all__ = ["HRNet", "HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C", "HRNet_W44_C", "HRNet_W48_C", "HRNet_W60_C", + "HRNet_W64_C", "SE_HRNet_W18_C", "SE_HRNet_W30_C", "SE_HRNet_W32_C", "SE_HRNet_W40_C", "SE_HRNet_W44_C", + "SE_HRNet_W48_C", "SE_HRNet_W60_C", "SE_HRNet_W64_C"] + + +class HRNet(): + def __init__(self, width=18, has_se=False): + self.width = width + self.has_se = has_se + self.channels = { + 18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]], + 30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]], + 32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]], + 40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]], + 44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]], + 48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]], + 60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]], + 64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]] + } + + + def net(self, input, class_dim=1000): + width = self.width + channels_2, channels_3, channels_4 = self.channels[width] + num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3 + + x = self.conv_bn_layer(input=input, filter_size=3, num_filters=64, stride=2, if_act=True, name='layer1_1') + x = self.conv_bn_layer(input=x, 
filter_size=3, num_filters=64, stride=2, if_act=True, name='layer1_2') + + la1 = self.layer1(x, name='layer2') + tr1 = self.transition_layer([la1], [256], channels_2, name='tr1') + st2 = self.stage(tr1, num_modules_2, channels_2, name='st2') + tr2 = self.transition_layer(st2, channels_2, channels_3, name='tr2') + st3 = self.stage(tr2, num_modules_3, channels_3, name='st3') + tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3') + st4 = self.stage(tr3, num_modules_4, channels_4, name='st4') + + #classification + last_cls = self.last_cls_out(x=st4, name='cls_head') + y = last_cls[0] + last_num_filters = [256, 512, 1024] + for i in range(3): + y = fluid.layers.elementwise_add(last_cls[i+1], + self.conv_bn_layer(input=y, filter_size=3, + num_filters=last_num_filters[i], stride=2, + name='cls_head_add'+str(i+1))) + + y = self.conv_bn_layer(input=y, filter_size=1, num_filters=2048, stride=1, name='cls_head_last_conv') + pool = fluid.layers.pool2d(input=y, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc(input=pool, size=class_dim, + param_attr=ParamAttr(name='fc_weights', initializer=fluid.initializer.Uniform(-stdv, stdv)), + bias_attr=ParamAttr(name='fc_offset')) + return out + + + def layer1(self, input, name=None): + conv = input + for i in range(4): + conv = self.bottleneck_block(conv, num_filters=64, downsample=True if i == 0 else False, name=name+'_'+str(i+1)) + return conv + + + def transition_layer(self, x, in_channels, out_channels, name=None): + num_in = len(in_channels) + num_out = len(out_channels) + out = [] + for i in range(num_out): + if i < num_in: + if in_channels[i] != out_channels[i]: + residual = self.conv_bn_layer(x[i], filter_size=3, num_filters=out_channels[i], name=name+'_layer_'+str(i+1)) + out.append(residual) + else: + out.append(x[i]) + else: + residual = self.conv_bn_layer(x[-1], filter_size=3, num_filters=out_channels[i], stride=2, + name=name+'_layer_'+str(i+1)) + 
out.append(residual) + return out + + + def branches(self, x, block_num, channels, name=None): + out = [] + for i in range(len(channels)): + residual = x[i] + for j in range(block_num): + residual = self.basic_block(residual, channels[i], name=name+'_branch_layer_'+str(i+1)+'_'+str(j+1)) + out.append(residual) + return out + + + def fuse_layers(self, x, channels, multi_scale_output=True, name=None): + out = [] + for i in range(len(channels) if multi_scale_output else 1): + residual = x[i] + for j in range(len(channels)): + if j > i: + y = self.conv_bn_layer(x[j], filter_size=1, num_filters=channels[i], if_act=False, + name=name+'_layer_'+str(i+1)+'_'+str(j+1)) + y = fluid.layers.resize_nearest(input=y, scale=2 ** (j - i)) + residual = fluid.layers.elementwise_add(x=residual, y=y, act=None) + elif j < i: + y = x[j] + for k in range(i - j): + if k == i - j - 1: + y = self.conv_bn_layer(y, filter_size=3, num_filters=channels[i], stride=2, if_act=False, + name=name+'_layer_'+str(i+1)+'_'+str(j+1)+'_'+str(k+1)) + else: + y = self.conv_bn_layer(y, filter_size=3, num_filters=channels[j], stride=2, + name=name+'_layer_'+str(i+1)+'_'+str(j+1)+'_'+str(k+1)) + residual = fluid.layers.elementwise_add(x=residual, y=y, act=None) + + residual = fluid.layers.relu(residual) + out.append(residual) + return out + + + def high_resolution_module(self, x, channels, multi_scale_output=True, name=None): + residual = self.branches(x, 4, channels, name=name) + out = self.fuse_layers(residual, channels, multi_scale_output=multi_scale_output, name=name) + return out + + + def stage(self, x, num_modules, channels, multi_scale_output=True, name=None): + out = x + for i in range(num_modules): + if i == num_modules - 1 and multi_scale_output == False: + out = self.high_resolution_module(out, channels, multi_scale_output=False, name=name+'_'+str(i+1)) + else: + out = self.high_resolution_module(out, channels, name=name+'_'+str(i+1)) + + return out + + + def last_cls_out(self, x, name=None): + out 
= [] + num_filters_list = [32, 64, 128, 256] + for i in range(len(x)): + out.append(self.bottleneck_block(input=x[i], num_filters=num_filters_list[i], name=name+'conv_'+str(i+1), + downsample=True)) + + return out + + + def basic_block(self, input, num_filters, stride=1, downsample=False, name=None): + residual = input + conv = self.conv_bn_layer(input=input, filter_size=3, num_filters=num_filters, stride=stride, name=name+'_conv1') + conv = self.conv_bn_layer(input=conv, filter_size=3, num_filters=num_filters, if_act=False, name=name+'_conv2') + if downsample: + residual = self.conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, if_act=False, + name=name+'_downsample') + if self.has_se: + conv = self.squeeze_excitation( + input=conv, + num_channels=num_filters, + reduction_ratio=16, + name=name+'_fc') + return fluid.layers.elementwise_add(x=residual, y=conv, act='relu') + + + def bottleneck_block(self, input, num_filters, stride=1, downsample=False, name=None): + residual = input + conv = self.conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, name=name+'_conv1') + conv = self.conv_bn_layer(input=conv, filter_size=3, num_filters=num_filters, stride=stride, name=name+'_conv2') + conv = self.conv_bn_layer(input=conv, filter_size=1, num_filters=num_filters*4, if_act=False, name=name+'_conv3') + if downsample: + residual = self.conv_bn_layer(input=input, filter_size=1, num_filters=num_filters*4, if_act=False, + name=name+'_downsample') + if self.has_se: + conv = self.squeeze_excitation( + input=conv, + num_channels=num_filters * 4, + reduction_ratio=16, + name=name+'_fc') + return fluid.layers.elementwise_add(x=residual, y=conv, act='relu') + + + def squeeze_excitation(self, input, num_channels, reduction_ratio, name=None): + pool = fluid.layers.pool2d( + input=input, pool_size=0, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + squeeze = fluid.layers.fc(input=pool, + size=num_channels / 
reduction_ratio, + act='relu', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform( + -stdv, stdv),name=name+'_sqz_weights'), + bias_attr=ParamAttr(name=name+'_sqz_offset')) + stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) + excitation = fluid.layers.fc(input=squeeze, + size=num_channels, + act='sigmoid', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform( + -stdv, stdv),name=name+'_exc_weights'), + bias_attr=ParamAttr(name=name+'_exc_offset')) + scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) + return scale + + + def conv_bn_layer(self,input, filter_size, num_filters, stride=1, padding=1, num_groups=1, if_act=True, name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size-1)//2, + groups=num_groups, + act=None, + param_attr=ParamAttr(initializer=MSRA(), name=name+'_weights'), + bias_attr=False) + bn_name = name + '_bn' + bn = fluid.layers.batch_norm(input=conv, + param_attr = ParamAttr(name=bn_name+"_scale", initializer=fluid.initializer.Constant(1.0)), + bias_attr=ParamAttr(name=bn_name+"_offset", initializer=fluid.initializer.Constant(0.0)), + moving_mean_name=bn_name+'_mean', + moving_variance_name=bn_name+'_variance') + if if_act: + bn = fluid.layers.relu(bn) + return bn + + +def HRNet_W18_C(): + model = HRNet(width=18) + return model + + +def HRNet_W30_C(): + model = HRNet(width=30) + return model + + +def HRNet_W32_C(): + model = HRNet(width=32) + return model + + +def HRNet_W40_C(): + model = HRNet(width=40) + return model + + +def HRNet_W44_C(): + model = HRNet(width=44) + return model + + +def HRNet_W48_C(): + model = HRNet(width=48) + return model + +def HRNet_W60_C(): + model = HRNet(width=60) + return model + + +def HRNet_W64_C(): + model = HRNet(width=64) + return model + + +def SE_HRNet_W18_C(): + model = HRNet(width=18, has_se=True) + return model + + +def SE_HRNet_W30_C(): + 
model = HRNet(width=30, has_se=True) + return model + +def SE_HRNet_W32_C(): + model = HRNet(width=32, has_se=True) + return model + + +def SE_HRNet_W40_C(): + model = HRNet(width=40, has_se=True) + return model + + +def SE_HRNet_W44_C(): + model = HRNet(width=44, has_se=True) + return model + + +def SE_HRNet_W48_C(): + model = HRNet(width=48, has_se=True) + return model + + +def SE_HRNet_W60_C(): + model = HRNet(width=60, has_se=True) + return model + + +def SE_HRNet_W64_C(): + model = HRNet(width=64, has_se=True) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/inception_v4.py b/VisualFL/visualfl/algorithm/paddle_clas/models/inception_v4.py new file mode 100755 index 000000000..6d2ad928e --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/inception_v4.py @@ -0,0 +1,345 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + +__all__ = ['InceptionV4'] + + +class InceptionV4(): + def __init__(self): + + pass + + def net(self, input, class_dim=1000): + x = self.inception_stem(input) + + for i in range(4): + x = self.inceptionA(x, name=str(i + 1)) + x = self.reductionA(x) + + for i in range(7): + x = self.inceptionB(x, name=str(i + 1)) + x = self.reductionB(x) + + for i in range(3): + x = self.inceptionC(x, name=str(i + 1)) + + pool = fluid.layers.pool2d( + input=x, pool_type='avg', global_pooling=True) + + drop = fluid.layers.dropout(x=pool, dropout_prob=0.2) + + stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0) + out = fluid.layers.fc( + input=drop, + size=class_dim, + param_attr=ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name="final_fc_weights"), + bias_attr=ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name="final_fc_offset")) + return out + + def conv_bn_layer(self, + data, + num_filters, + filter_size, + stride=1, + padding=0, + groups=1, + act='relu', + name=None): + conv = fluid.layers.conv2d( + input=data, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False, + name=name) + bn_name = name + "_bn" + return fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def inception_stem(self, data, name=None): + conv = self.conv_bn_layer( + data, 32, 3, stride=2, act='relu', name="conv1_3x3_s2") + conv = self.conv_bn_layer(conv, 32, 3, act='relu', name="conv2_3x3_s1") + conv = self.conv_bn_layer( + conv, 
64, 3, padding=1, act='relu', name="conv3_3x3_s1") + + pool1 = fluid.layers.pool2d( + input=conv, pool_size=3, pool_stride=2, pool_type='max') + conv2 = self.conv_bn_layer( + conv, 96, 3, stride=2, act='relu', name="inception_stem1_3x3_s2") + concat = fluid.layers.concat([pool1, conv2], axis=1) + + conv1 = self.conv_bn_layer( + concat, 64, 1, act='relu', name="inception_stem2_3x3_reduce") + conv1 = self.conv_bn_layer( + conv1, 96, 3, act='relu', name="inception_stem2_3x3") + + conv2 = self.conv_bn_layer( + concat, 64, 1, act='relu', name="inception_stem2_1x7_reduce") + conv2 = self.conv_bn_layer( + conv2, + 64, (7, 1), + padding=(3, 0), + act='relu', + name="inception_stem2_1x7") + conv2 = self.conv_bn_layer( + conv2, + 64, (1, 7), + padding=(0, 3), + act='relu', + name="inception_stem2_7x1") + conv2 = self.conv_bn_layer( + conv2, 96, 3, act='relu', name="inception_stem2_3x3_2") + + concat = fluid.layers.concat([conv1, conv2], axis=1) + + conv1 = self.conv_bn_layer( + concat, 192, 3, stride=2, act='relu', name="inception_stem3_3x3_s2") + pool1 = fluid.layers.pool2d( + input=concat, pool_size=3, pool_stride=2, pool_type='max') + + concat = fluid.layers.concat([conv1, pool1], axis=1) + + return concat + + def inceptionA(self, data, name=None): + pool1 = fluid.layers.pool2d( + input=data, pool_size=3, pool_padding=1, pool_type='avg') + conv1 = self.conv_bn_layer( + pool1, 96, 1, act='relu', name="inception_a" + name + "_1x1") + + conv2 = self.conv_bn_layer( + data, 96, 1, act='relu', name="inception_a" + name + "_1x1_2") + + conv3 = self.conv_bn_layer( + data, 64, 1, act='relu', name="inception_a" + name + "_3x3_reduce") + conv3 = self.conv_bn_layer( + conv3, + 96, + 3, + padding=1, + act='relu', + name="inception_a" + name + "_3x3") + + conv4 = self.conv_bn_layer( + data, + 64, + 1, + act='relu', + name="inception_a" + name + "_3x3_2_reduce") + conv4 = self.conv_bn_layer( + conv4, + 96, + 3, + padding=1, + act='relu', + name="inception_a" + name + "_3x3_2") + conv4 = 
self.conv_bn_layer( + conv4, + 96, + 3, + padding=1, + act='relu', + name="inception_a" + name + "_3x3_3") + + concat = fluid.layers.concat([conv1, conv2, conv3, conv4], axis=1) + + return concat + + def reductionA(self, data, name=None): + pool1 = fluid.layers.pool2d( + input=data, pool_size=3, pool_stride=2, pool_type='max') + + conv2 = self.conv_bn_layer( + data, 384, 3, stride=2, act='relu', name="reduction_a_3x3") + + conv3 = self.conv_bn_layer( + data, 192, 1, act='relu', name="reduction_a_3x3_2_reduce") + conv3 = self.conv_bn_layer( + conv3, 224, 3, padding=1, act='relu', name="reduction_a_3x3_2") + conv3 = self.conv_bn_layer( + conv3, 256, 3, stride=2, act='relu', name="reduction_a_3x3_3") + + concat = fluid.layers.concat([pool1, conv2, conv3], axis=1) + + return concat + + def inceptionB(self, data, name=None): + pool1 = fluid.layers.pool2d( + input=data, pool_size=3, pool_padding=1, pool_type='avg') + conv1 = self.conv_bn_layer( + pool1, 128, 1, act='relu', name="inception_b" + name + "_1x1") + + conv2 = self.conv_bn_layer( + data, 384, 1, act='relu', name="inception_b" + name + "_1x1_2") + + conv3 = self.conv_bn_layer( + data, 192, 1, act='relu', name="inception_b" + name + "_1x7_reduce") + conv3 = self.conv_bn_layer( + conv3, + 224, (1, 7), + padding=(0, 3), + act='relu', + name="inception_b" + name + "_1x7") + conv3 = self.conv_bn_layer( + conv3, + 256, (7, 1), + padding=(3, 0), + act='relu', + name="inception_b" + name + "_7x1") + + conv4 = self.conv_bn_layer( + data, + 192, + 1, + act='relu', + name="inception_b" + name + "_7x1_2_reduce") + conv4 = self.conv_bn_layer( + conv4, + 192, (1, 7), + padding=(0, 3), + act='relu', + name="inception_b" + name + "_1x7_2") + conv4 = self.conv_bn_layer( + conv4, + 224, (7, 1), + padding=(3, 0), + act='relu', + name="inception_b" + name + "_7x1_2") + conv4 = self.conv_bn_layer( + conv4, + 224, (1, 7), + padding=(0, 3), + act='relu', + name="inception_b" + name + "_1x7_3") + conv4 = self.conv_bn_layer( + conv4, + 
256, (7, 1), + padding=(3, 0), + act='relu', + name="inception_b" + name + "_7x1_3") + + concat = fluid.layers.concat([conv1, conv2, conv3, conv4], axis=1) + + return concat + + def reductionB(self, data, name=None): + pool1 = fluid.layers.pool2d( + input=data, pool_size=3, pool_stride=2, pool_type='max') + + conv2 = self.conv_bn_layer( + data, 192, 1, act='relu', name="reduction_b_3x3_reduce") + conv2 = self.conv_bn_layer( + conv2, 192, 3, stride=2, act='relu', name="reduction_b_3x3") + + conv3 = self.conv_bn_layer( + data, 256, 1, act='relu', name="reduction_b_1x7_reduce") + conv3 = self.conv_bn_layer( + conv3, + 256, (1, 7), + padding=(0, 3), + act='relu', + name="reduction_b_1x7") + conv3 = self.conv_bn_layer( + conv3, + 320, (7, 1), + padding=(3, 0), + act='relu', + name="reduction_b_7x1") + conv3 = self.conv_bn_layer( + conv3, 320, 3, stride=2, act='relu', name="reduction_b_3x3_2") + + concat = fluid.layers.concat([pool1, conv2, conv3], axis=1) + + return concat + + def inceptionC(self, data, name=None): + pool1 = fluid.layers.pool2d( + input=data, pool_size=3, pool_padding=1, pool_type='avg') + conv1 = self.conv_bn_layer( + pool1, 256, 1, act='relu', name="inception_c" + name + "_1x1") + + conv2 = self.conv_bn_layer( + data, 256, 1, act='relu', name="inception_c" + name + "_1x1_2") + + conv3 = self.conv_bn_layer( + data, 384, 1, act='relu', name="inception_c" + name + "_1x1_3") + conv3_1 = self.conv_bn_layer( + conv3, + 256, (1, 3), + padding=(0, 1), + act='relu', + name="inception_c" + name + "_1x3") + conv3_2 = self.conv_bn_layer( + conv3, + 256, (3, 1), + padding=(1, 0), + act='relu', + name="inception_c" + name + "_3x1") + + conv4 = self.conv_bn_layer( + data, 384, 1, act='relu', name="inception_c" + name + "_1x1_4") + conv4 = self.conv_bn_layer( + conv4, + 448, (1, 3), + padding=(0, 1), + act='relu', + name="inception_c" + name + "_1x3_2") + conv4 = self.conv_bn_layer( + conv4, + 512, (3, 1), + padding=(1, 0), + act='relu', + name="inception_c" + name + 
"_3x1_2") + conv4_1 = self.conv_bn_layer( + conv4, + 256, (1, 3), + padding=(0, 1), + act='relu', + name="inception_c" + name + "_1x3_3") + conv4_2 = self.conv_bn_layer( + conv4, + 256, (3, 1), + padding=(1, 0), + act='relu', + name="inception_c" + name + "_3x1_3") + + concat = fluid.layers.concat( + [conv1, conv2, conv3_1, conv3_2, conv4_1, conv4_2], axis=1) + + return concat diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/layers.py b/VisualFL/visualfl/algorithm/paddle_clas/models/layers.py new file mode 100755 index 000000000..5900f8ccc --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/layers.py @@ -0,0 +1,222 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle.fluid as fluid +import math +import warnings + +def initial_type(name, + input, + op_type, + fan_out, + init="google", + use_bias=False, + filter_size=0, + stddev=0.02): + if init == "kaiming": + if op_type == 'conv': + fan_in = input.shape[1] * filter_size * filter_size + elif op_type == 'deconv': + fan_in = fan_out * filter_size * filter_size + else: + if len(input.shape) > 2: + fan_in = input.shape[1] * input.shape[2] * input.shape[3] + else: + fan_in = input.shape[1] + bound = 1 / math.sqrt(fan_in) + param_attr = fluid.ParamAttr( + name=name + "_weights", + initializer=fluid.initializer.Uniform( + low=-bound, high=bound)) + if use_bias == True: + bias_attr = fluid.ParamAttr( + name=name + '_offset', + initializer=fluid.initializer.Uniform( + low=-bound, high=bound)) + else: + bias_attr = False + elif init == 'google': + n = filter_size * filter_size * fan_out + param_attr = fluid.ParamAttr( + name=name + "_weights", + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=math.sqrt(2.0 / n))) + if use_bias == True: + bias_attr = fluid.ParamAttr( + name=name + "_offset", initializer=fluid.initializer.Constant(0.0)) + else: + bias_attr = False + + else: + param_attr = fluid.ParamAttr( + 
name=name + "_weights", + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=stddev)) + if use_bias == True: + bias_attr = fluid.ParamAttr( + name=name + "_offset", initializer=fluid.initializer.Constant(0.0)) + else: + bias_attr = False + return param_attr, bias_attr + +def cal_padding(img_size, stride, filter_size, dilation=1): + """Calculate padding size.""" + if img_size % stride == 0: + out_size = max(filter_size - stride, 0) + else: + out_size = max(filter_size - (img_size % stride), 0) + return out_size // 2, out_size - out_size // 2 + +def init_batch_norm_layer(name="batch_norm"): + param_attr = fluid.ParamAttr( + name=name + '_scale', initializer=fluid.initializer.Constant(1.0)) + bias_attr = fluid.ParamAttr( + name=name + '_offset', initializer=fluid.initializer.Constant(value=0.0)) + return param_attr, bias_attr + +def init_fc_layer(fout, name='fc'): + n = fout # fan-out + init_range = 1.0 / math.sqrt(n) + + param_attr = fluid.ParamAttr( + name=name + '_weights', initializer=fluid.initializer.UniformInitializer( + low=-init_range, high=init_range)) + bias_attr = fluid.ParamAttr( + name=name + '_offset', initializer=fluid.initializer.Constant(value=0.0)) + return param_attr, bias_attr + +def norm_layer(input, norm_type='batch_norm', name=None): + if norm_type == 'batch_norm': + param_attr = fluid.ParamAttr( + name=name + '_weights', initializer=fluid.initializer.Constant(1.0)) + bias_attr = fluid.ParamAttr( + name=name + '_offset', initializer=fluid.initializer.Constant(value=0.0)) + return fluid.layers.batch_norm( + input, + param_attr=param_attr, + bias_attr=bias_attr, + moving_mean_name=name + '_mean', + moving_variance_name=name + '_variance') + + elif norm_type == 'instance_norm': + helper = fluid.layer_helper.LayerHelper("instance_norm", **locals()) + dtype = helper.input_dtype() + epsilon = 1e-5 + mean = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True) + var = fluid.layers.reduce_mean( + fluid.layers.square(input - mean), 
dim=[2, 3], keep_dim=True) + if name is not None: + scale_name = name + "_scale" + offset_name = name + "_offset" + scale_param = fluid.ParamAttr( + name=scale_name, + initializer=fluid.initializer.Constant(1.0), + trainable=True) + offset_param = fluid.ParamAttr( + name=offset_name, + initializer=fluid.initializer.Constant(0.0), + trainable=True) + scale = helper.create_parameter( + attr=scale_param, shape=input.shape[1:2], dtype=dtype) + offset = helper.create_parameter( + attr=offset_param, shape=input.shape[1:2], dtype=dtype) + + tmp = fluid.layers.elementwise_mul(x=(input - mean), y=scale, axis=1) + tmp = tmp / fluid.layers.sqrt(var + epsilon) + tmp = fluid.layers.elementwise_add(tmp, offset, axis=1) + return tmp + else: + raise NotImplementedError("norm tyoe: [%s] is not support" % norm_type) + + +def conv2d(input, + num_filters=64, + filter_size=7, + stride=1, + stddev=0.02, + padding=0, + groups=None, + name="conv2d", + norm=None, + act=None, + relufactor=0.0, + use_bias=False, + padding_type=None, + initial="normal", + use_cudnn=True): + + if padding != 0 and padding_type != None: + warnings.warn( + 'padding value and padding type are set in the same time, and the final padding width and padding height are computed by padding_type' + ) + + param_attr, bias_attr = initial_type( + name=name, + input=input, + op_type='conv', + fan_out=num_filters, + init=initial, + use_bias=use_bias, + filter_size=filter_size, + stddev=stddev) + + def get_padding(filter_size, stride=1, dilation=1): + padding = ((stride - 1) + dilation * (filter_size - 1)) // 2 + return padding + + need_crop = False + if padding_type == "SAME": + top_padding, bottom_padding = cal_padding(input.shape[2], stride, + filter_size) + left_padding, right_padding = cal_padding(input.shape[2], stride, + filter_size) + height_padding = bottom_padding + width_padding = right_padding + if top_padding != bottom_padding or left_padding != right_padding: + height_padding = top_padding + stride + 
width_padding = left_padding + stride + need_crop = True + padding = [height_padding, width_padding] + elif padding_type == "VALID": + height_padding = 0 + width_padding = 0 + padding = [height_padding, width_padding] + elif padding_type == "DYNAMIC": + padding = get_padding(filter_size, stride) + else: + padding = padding + + conv = fluid.layers.conv2d( + input, + num_filters, + filter_size, + groups=groups, + name=name, + stride=stride, + padding=padding, + use_cudnn=use_cudnn, + param_attr=param_attr, + bias_attr=bias_attr) + + if need_crop: + conv = conv[:, :, 1:, 1:] + + if norm is not None: + conv = norm_layer(input=conv, norm_type=norm, name=name + "_norm") + if act == 'relu': + conv = fluid.layers.relu(conv, name=name + '_relu') + elif act == 'leaky_relu': + conv = fluid.layers.leaky_relu( + conv, alpha=relufactor, name=name + '_leaky_relu') + elif act == 'tanh': + conv = fluid.layers.tanh(conv, name=name + '_tanh') + elif act == 'sigmoid': + conv = fluid.layers.sigmoid(conv, name=name + '_sigmoid') + elif act == 'swish': + conv = fluid.layers.swish(conv, name=name + '_swish') + elif act == None: + conv = conv + else: + raise NotImplementedError("activation: [%s] is not support" %act) + + return conv \ No newline at end of file diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/lenet.py b/VisualFL/visualfl/algorithm/paddle_clas/models/lenet.py new file mode 100644 index 000000000..d869becc2 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/lenet.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid + +__all__ = ['LeNet'] + + +class LeNet(): + def __init__(self): + pass + + def net(self,input, class_dim=1000): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=input, + num_filters=6, + filter_size=5, + pool_size=2, + pool_stride=2, + act="relu", + ) + conv_pool_2 = 
fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + num_filters=16, + filter_size=5, + pool_size=2, + pool_stride=2, + act="relu", + ) + conv_pool_3 = fluid.nets.simple_img_conv_pool( + input=conv_pool_2, + num_filters=120, + filter_size=4, + pool_size=1, + pool_stride=1, + act="relu", + ) + fc1 = fluid.layers.fc( + input=conv_pool_3, size=64, act="relu" + ) + out = fluid.layers.fc( + input=fc1, size=class_dim + ) + + return out \ No newline at end of file diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/mobilenet_v1.py b/VisualFL/visualfl/algorithm/paddle_clas/models/mobilenet_v1.py new file mode 100755 index 000000000..5f14de811 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/mobilenet_v1.py @@ -0,0 +1,218 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + 'MobileNetV1', 'MobileNetV1_x0_25', 'MobileNetV1_x0_5', 'MobileNetV1_x1_0', + 'MobileNetV1_x0_75' +] + + +class MobileNetV1(): + def __init__(self, scale=1.0): + self.scale = scale + + def net(self, input, class_dim=1000): + scale = self.scale + # conv1: 112x112 + input = self.conv_bn_layer( + input, + filter_size=3, + channels=3, + num_filters=int(32 * scale), + stride=2, + padding=1, + name="conv1") + + # 56x56 + input = self.depthwise_separable( + input, + num_filters1=32, + num_filters2=64, + num_groups=32, + stride=1, + scale=scale, + name="conv2_1") + + input = self.depthwise_separable( + input, + num_filters1=64, + num_filters2=128, + num_groups=64, + stride=2, + scale=scale, + name="conv2_2") + + # 28x28 + input = self.depthwise_separable( + input, + num_filters1=128, + num_filters2=128, + num_groups=128, + stride=1, + scale=scale, + name="conv3_1") + + input = self.depthwise_separable( + input, + num_filters1=128, + num_filters2=256, + num_groups=128, + stride=2, + scale=scale, + name="conv3_2") + + # 14x14 + input = self.depthwise_separable( + input, + num_filters1=256, + num_filters2=256, + num_groups=256, + stride=1, + scale=scale, + name="conv4_1") + + input = self.depthwise_separable( + input, + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=2, + scale=scale, + name="conv4_2") + + # 14x14 + for i in range(5): + input = self.depthwise_separable( + input, + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + scale=scale, + name="conv5" + "_" + str(i + 1)) + # 7x7 + input = self.depthwise_separable( + input, + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=2, + scale=scale, + name="conv5_6") + + input = self.depthwise_separable( + input, + 
num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=1, + scale=scale, + name="conv6") + + input = fluid.layers.pool2d( + input=input, pool_type='avg', global_pooling=True) + + output = fluid.layers.fc(input=input, + size=class_dim, + param_attr=ParamAttr( + initializer=MSRA(), name="fc7_weights"), + bias_attr=ParamAttr(name="fc7_offset")) + return output + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + act='relu', + use_cudnn=True, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr( + initializer=MSRA(), name=name + "_weights"), + bias_attr=False) + bn_name = name + "_bn" + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def depthwise_separable(self, + input, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + name=None): + depthwise_conv = self.conv_bn_layer( + input=input, + filter_size=3, + num_filters=int(num_filters1 * scale), + stride=stride, + padding=1, + num_groups=int(num_groups * scale), + use_cudnn=False, + name=name + "_dw") + + pointwise_conv = self.conv_bn_layer( + input=depthwise_conv, + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0, + name=name + "_sep") + return pointwise_conv + + +def MobileNetV1_x0_25(): + model = MobileNetV1(scale=0.25) + return model + + +def MobileNetV1_x0_5(): + model = MobileNetV1(scale=0.5) + return model + + +def MobileNetV1_x1_0(): + model = MobileNetV1(scale=1.0) + return model + + +def MobileNetV1_x0_75(): + model = MobileNetV1(scale=0.75) + return model diff --git 
a/VisualFL/visualfl/algorithm/paddle_clas/models/mobilenet_v2.py b/VisualFL/visualfl/algorithm/paddle_clas/models/mobilenet_v2.py new file mode 100755 index 000000000..e9c2277a2 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/mobilenet_v2.py @@ -0,0 +1,230 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + 'MobileNetV2_x0_25', 'MobileNetV2_x0_5' + 'MobileNetV2_x0_75', 'MobileNetV2_x1_0', 'MobileNetV2_x1_5', + 'MobileNetV2_x2_0', 'MobileNetV2' +] + + +class MobileNetV2(): + def __init__(self, scale=1.0): + self.scale = scale + + def net(self, input, class_dim=1000): + scale = self.scale + bottleneck_params_list = [ + (1, 16, 1, 1), + (6, 24, 2, 2), + (6, 32, 3, 2), + (6, 64, 4, 2), + (6, 96, 3, 1), + (6, 160, 3, 2), + (6, 320, 1, 1), + ] + + #conv1 + input = self.conv_bn_layer( + input, + num_filters=int(32 * scale), + filter_size=3, + stride=2, + padding=1, + if_act=True, + name='conv1_1') + + # bottleneck sequences + i = 1 + in_c = int(32 * scale) + for layer_setting in bottleneck_params_list: + t, c, n, s = layer_setting + i += 1 + input = self.invresi_blocks( + input=input, + in_c=in_c, + t=t, + c=int(c * scale), + n=n, + s=s, + name='conv' + 
str(i)) + in_c = int(c * scale) + #last_conv + input = self.conv_bn_layer( + input=input, + num_filters=int(1280 * scale) if scale > 1.0 else 1280, + filter_size=1, + stride=1, + padding=0, + if_act=True, + name='conv9') + + input = fluid.layers.pool2d( + input=input, pool_type='avg', global_pooling=True) + + output = fluid.layers.fc(input=input, + size=class_dim, + param_attr=ParamAttr(name='fc10_weights'), + bias_attr=ParamAttr(name='fc10_offset')) + return output + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + if_act=True, + name=None, + use_cudnn=True): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + bn_name = name + '_bn' + bn = fluid.layers.batch_norm( + input=conv, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + if if_act: + return fluid.layers.relu6(bn) + else: + return bn + + def shortcut(self, input, data_residual): + return fluid.layers.elementwise_add(input, data_residual) + + def inverted_residual_unit(self, + input, + num_in_filter, + num_filters, + ifshortcut, + stride, + filter_size, + padding, + expansion_factor, + name=None): + num_expfilter = int(round(num_in_filter * expansion_factor)) + + channel_expand = self.conv_bn_layer( + input=input, + num_filters=num_expfilter, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name=name + '_expand') + + bottleneck_conv = self.conv_bn_layer( + input=channel_expand, + num_filters=num_expfilter, + filter_size=filter_size, + stride=stride, + padding=padding, + num_groups=num_expfilter, + if_act=True, + name=name + '_dwise', + use_cudnn=False) + + linear_out = 
self.conv_bn_layer( + input=bottleneck_conv, + num_filters=num_filters, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=False, + name=name + '_linear') + if ifshortcut: + out = self.shortcut(input=input, data_residual=linear_out) + return out + else: + return linear_out + + def invresi_blocks(self, input, in_c, t, c, n, s, name=None): + first_block = self.inverted_residual_unit( + input=input, + num_in_filter=in_c, + num_filters=c, + ifshortcut=False, + stride=s, + filter_size=3, + padding=1, + expansion_factor=t, + name=name + '_1') + + last_residual_block = first_block + last_c = c + + for i in range(1, n): + last_residual_block = self.inverted_residual_unit( + input=last_residual_block, + num_in_filter=last_c, + num_filters=c, + ifshortcut=True, + stride=1, + filter_size=3, + padding=1, + expansion_factor=t, + name=name + '_' + str(i + 1)) + return last_residual_block + + +def MobileNetV2_x0_25(): + model = MobileNetV2(scale=0.25) + return model + + +def MobileNetV2_x0_5(): + model = MobileNetV2(scale=0.5) + return model + + +def MobileNetV2_x0_75(): + model = MobileNetV2(scale=0.75) + return model + + +def MobileNetV2_x1_0(): + model = MobileNetV2(scale=1.0) + return model + + +def MobileNetV2_x1_5(): + model = MobileNetV2(scale=1.5) + return model + + +def MobileNetV2_x2_0(): + model = MobileNetV2(scale=2.0) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/mobilenet_v3.py b/VisualFL/visualfl/algorithm/paddle_clas/models/mobilenet_v3.py new file mode 100755 index 000000000..d2b5b1587 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/mobilenet_v3.py @@ -0,0 +1,300 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. 
+#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + 'MobileNetV3', 'MobileNetV3_small_x0_25', 'MobileNetV3_small_x0_5', + 'MobileNetV3_small_x0_75', 'MobileNetV3_small_x1_0', + 'MobileNetV3_small_x1_25', 'MobileNetV3_large_x0_25', + 'MobileNetV3_large_x0_5', 'MobileNetV3_large_x0_75', + 'MobileNetV3_large_x1_0', 'MobileNetV3_large_x1_25' +] + + +class MobileNetV3(): + def __init__(self, scale=1.0, model_name='small'): + self.scale = scale + self.inplanes = 16 + if model_name == "large": + self.cfg = [ + # k, exp, c, se, nl, s, + [3, 16, 16, False, 'relu', 1], + [3, 64, 24, False, 'relu', 2], + [3, 72, 24, False, 'relu', 1], + [5, 72, 40, True, 'relu', 2], + [5, 120, 40, True, 'relu', 1], + [5, 120, 40, True, 'relu', 1], + [3, 240, 80, False, 'hard_swish', 2], + [3, 200, 80, False, 'hard_swish', 1], + [3, 184, 80, False, 'hard_swish', 1], + [3, 184, 80, False, 'hard_swish', 1], + [3, 480, 112, True, 'hard_swish', 1], + [3, 672, 112, True, 'hard_swish', 1], + [5, 672, 160, True, 'hard_swish', 2], + [5, 960, 160, True, 'hard_swish', 1], + [5, 960, 160, True, 'hard_swish', 1], + ] + self.cls_ch_squeeze = 960 + self.cls_ch_expand = 1280 + elif model_name == "small": + self.cfg = [ + # k, exp, c, se, nl, s, + [3, 16, 16, True, 'relu', 2], + [3, 72, 24, False, 'relu', 2], + [3, 88, 24, False, 'relu', 1], + [5, 96, 40, True, 'hard_swish', 2], + [5, 240, 40, True, 
'hard_swish', 1], + [5, 240, 40, True, 'hard_swish', 1], + [5, 120, 48, True, 'hard_swish', 1], + [5, 144, 48, True, 'hard_swish', 1], + [5, 288, 96, True, 'hard_swish', 2], + [5, 576, 96, True, 'hard_swish', 1], + [5, 576, 96, True, 'hard_swish', 1], + ] + self.cls_ch_squeeze = 576 + self.cls_ch_expand = 1280 + else: + raise NotImplementedError("mode[" + model_name + + "_model] is not implemented!") + + def net(self, input, class_dim=1000): + scale = self.scale + inplanes = self.inplanes + cfg = self.cfg + cls_ch_squeeze = self.cls_ch_squeeze + cls_ch_expand = self.cls_ch_expand + + #conv1 + conv = self.conv_bn_layer( + input, + filter_size=3, + num_filters=int(scale * inplanes), + stride=2, + padding=1, + num_groups=1, + if_act=True, + act='hard_swish', + name='conv1') + i = 0 + for layer_cfg in cfg: + conv = self.residual_unit( + input=conv, + num_in_filter=inplanes, + num_mid_filter=int(scale * layer_cfg[1]), + num_out_filter=int(scale * layer_cfg[2]), + act=layer_cfg[4], + stride=layer_cfg[5], + filter_size=layer_cfg[0], + use_se=layer_cfg[3], + name='conv' + str(i + 2)) + inplanes = int(scale * layer_cfg[2]) + i += 1 + + conv = self.conv_bn_layer( + input=conv, + filter_size=1, + num_filters=int(scale * cls_ch_squeeze), + stride=1, + padding=0, + num_groups=1, + if_act=True, + act='hard_swish', + name='conv_last') + conv = fluid.layers.pool2d( + input=conv, pool_type='avg', global_pooling=True, use_cudnn=False) + conv = fluid.layers.conv2d( + input=conv, + num_filters=cls_ch_expand, + filter_size=1, + stride=1, + padding=0, + act=None, + param_attr=ParamAttr(name='last_1x1_conv_weights'), + bias_attr=False) + conv = self.hard_swish(conv) + drop = fluid.layers.dropout(x=conv, dropout_prob=0.2) + out = fluid.layers.fc(input=drop, + size=class_dim, + param_attr=ParamAttr(name='fc_weights'), + bias_attr=ParamAttr(name='fc_offset')) + return out + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + num_groups=1, + if_act=True, + 
act=None, + name=None, + use_cudnn=True): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + bn_name = name + '_bn' + bn = fluid.layers.batch_norm( + input=conv, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + if if_act: + if act == 'relu': + bn = fluid.layers.relu(bn) + elif act == 'hard_swish': + bn = self.hard_swish(bn) + return bn + + def hard_swish(self, x): + return x * fluid.layers.relu6(x + 3) / 6. + + def se_block(self, input, num_out_filter, ratio=4, name=None): + num_mid_filter = int(num_out_filter // ratio) + pool = fluid.layers.pool2d( + input=input, pool_type='avg', global_pooling=True, use_cudnn=False) + conv1 = fluid.layers.conv2d( + input=pool, + filter_size=1, + num_filters=num_mid_filter, + act='relu', + param_attr=ParamAttr(name=name + '_1_weights'), + bias_attr=ParamAttr(name=name + '_1_offset')) + conv2 = fluid.layers.conv2d( + input=conv1, + filter_size=1, + num_filters=num_out_filter, + act='hard_sigmoid', + param_attr=ParamAttr(name=name + '_2_weights'), + bias_attr=ParamAttr(name=name + '_2_offset')) + scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0) + return scale + + def residual_unit(self, + input, + num_in_filter, + num_mid_filter, + num_out_filter, + stride, + filter_size, + act=None, + use_se=False, + name=None): + + first_conv = (num_out_filter != num_mid_filter) + input_data = input + if first_conv: + input = self.conv_bn_layer( + input=input, + filter_size=1, + num_filters=num_mid_filter, + stride=1, + padding=0, + if_act=True, + act=act, + name=name + '_expand') + + conv1 = self.conv_bn_layer( + input=input, + filter_size=filter_size, + num_filters=num_mid_filter, + 
stride=stride, + padding=int((filter_size - 1) // 2), + if_act=True, + act=act, + num_groups=num_mid_filter, + use_cudnn=True, + name=name + '_depthwise') + if use_se: + conv1 = self.se_block( + input=conv1, num_out_filter=num_mid_filter, name=name + '_se') + + conv2 = self.conv_bn_layer( + input=conv1, + filter_size=1, + num_filters=num_out_filter, + stride=1, + padding=0, + if_act=False, + name=name + '_linear') + if num_in_filter != num_out_filter or stride != 1: + return conv2 + else: + return fluid.layers.elementwise_add(x=input_data, y=conv2, act=None) + + +def MobileNetV3_small_x0_25(): + model = MobileNetV3(model_name='small', scale=0.25) + return model + + +def MobileNetV3_small_x0_5(): + model = MobileNetV3(model_name='small', scale=0.5) + return model + + +def MobileNetV3_small_x0_75(): + model = MobileNetV3(model_name='small', scale=0.75) + return model + + +def MobileNetV3_small_x1_0(): + model = MobileNetV3(model_name='small', scale=1.0) + return model + + +def MobileNetV3_small_x1_25(): + model = MobileNetV3(model_name='small', scale=1.25) + return model + + +def MobileNetV3_large_x0_25(): + model = MobileNetV3(model_name='large', scale=0.25) + return model + + +def MobileNetV3_large_x0_5(): + model = MobileNetV3(model_name='large', scale=0.5) + return model + + +def MobileNetV3_large_x0_75(): + model = MobileNetV3(model_name='large', scale=0.75) + return model + + +def MobileNetV3_large_x1_0(): + model = MobileNetV3(model_name='large', scale=1.0) + return model + + +def MobileNetV3_large_x1_25(): + model = MobileNetV3(model_name='large', scale=1.25) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/model_libs.py b/VisualFL/visualfl/algorithm/paddle_clas/models/model_libs.py new file mode 100755 index 000000000..a0b97c504 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/model_libs.py @@ -0,0 +1,128 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. 
+# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle +import paddle.fluid as fluid +import contextlib + +bn_regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0) +name_scope = "" + +@contextlib.contextmanager +def scope(name): + global name_scope + bk = name_scope + name_scope = name_scope + name + '/' + yield + name_scope = bk + +def max_pool(input, kernel, stride, padding): + data = fluid.layers.pool2d(input, pool_size=kernel, pool_type='max', + pool_stride=stride, pool_padding=padding) + return data + +def group_norm(input, G, eps=1e-5, param_attr=None, bias_attr=None): + N, C, H, W = input.shape + if C % G != 0: + # print "group can not divide channle:", C, G + for d in range(10): + for t in [d, -d]: + if G + t <= 0: continue + if C % (G + t) == 0: + G = G + t + break + if C % G == 0: + # print "use group size:", G + break + assert C % G == 0 + x = fluid.layers.group_norm( + input, + groups=G, + param_attr=param_attr, + bias_attr=bias_attr, + name=name_scope + 'group_norm') + return x + +def bn(*args, **kargs): + with scope('BatchNorm'): + return fluid.layers.batch_norm( + *args, + epsilon=1e-3, + momentum=0.99, + param_attr=fluid.ParamAttr( + name=name_scope + 'gamma', regularizer=bn_regularizer), + bias_attr=fluid.ParamAttr( + name=name_scope + 'beta', regularizer=bn_regularizer), + moving_mean_name=name_scope + 'moving_mean', 
+ moving_variance_name=name_scope + 'moving_variance', + **kargs) + +def bn_relu(data): + return fluid.layers.relu(bn(data)) + +def relu(data): + return fluid.layers.relu(data) + +def conv(*args, **kargs): + kargs['param_attr'] = name_scope + 'weights' + if 'bias_attr' in kargs and kargs['bias_attr']: + kargs['bias_attr'] = fluid.ParamAttr( + name=name_scope + 'biases', + regularizer=None, + initializer=fluid.initializer.ConstantInitializer(value=0.0)) + else: + kargs['bias_attr'] = False + return fluid.layers.conv2d(*args, **kargs) + +def deconv(*args, **kargs): + kargs['param_attr'] = name_scope + 'weights' + if 'bias_attr' in kargs and kargs['bias_attr']: + kargs['bias_attr'] = name_scope + 'biases' + else: + kargs['bias_attr'] = False + return fluid.layers.conv2d_transpose(*args, **kargs) + +def seperate_conv(input, channel, stride, filter, dilation=1, act=None): + param_attr = fluid.ParamAttr( + name=name_scope + 'weights', + regularizer=fluid.regularizer.L2DecayRegularizer( + regularization_coeff=0.0), + initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.33)) + with scope('depthwise'): + input = conv( + input, + input.shape[1], + filter, + stride, + groups=input.shape[1], + padding=(filter // 2) * dilation, + dilation=dilation, + use_cudnn=False, + param_attr=param_attr) + input = bn(input) + if act: input = act(input) + + param_attr = fluid.ParamAttr( + name=name_scope + 'weights', + regularizer=None, + initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.06)) + with scope('pointwise'): + input = conv(input, channel, 1, 1, groups=1, padding=0, + param_attr=param_attr) + input = bn(input) + if act: input = act(input) + return input diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/paddle_clas_model_list.txt b/VisualFL/visualfl/algorithm/paddle_clas/models/paddle_clas_model_list.txt new file mode 100644 index 000000000..7124636e3 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/paddle_clas_model_list.txt @@ 
-0,0 +1,95 @@ +LeNet +AlexNet +VGG11 +VGG13 +VGG16 +VGG19 +ShuffleNetV2_x0_25 +ShuffleNetV2_x0_33 +ShuffleNetV2_x0_5 +ShuffleNetV2_x1_0 +ShuffleNetV2_x1_5 +ShuffleNetV2_x2_0 +SqueezeNet1_0 +SqueezeNet1_1 +InceptionV4 +Xception41 +Xception65 +Xception71 +ResNet18 +ResNet34 +ResNet50 +ResNet101 +ResNet152 +ResNet50_vc +ResNet101_vc +ResNet152_vc +ResNet18_vd +ResNet34_vd +ResNet50_vd +ResNet101_vd +ResNet152_vd +ResNet200_vd +SE_ResNet18_vd +SE_ResNet34_vd +SE_ResNet50_vd +SE_ResNet101_vd +SE_ResNet152_vd +SE_ResNet200_vd +SE_ResNeXt50_32x4d +SE_ResNeXt101_32x4d +SE_ResNeXt152_32x4d +SE_ResNeXt50_vd_32x4d +SE_ResNeXt101_vd_32x4d +SENet154_vd +DenseNet121 +DenseNet161 +DenseNet169 +DenseNet201 +DenseNet264 +DarkNet53 +ResNeXt50_64x4d +ResNeXt101_64x4d +ResNeXt152_64x4d +ResNeXt50_32x4d +ResNeXt101_32x4d +ResNeXt152_32x4d +ResNeXt50_vd_64x4d +ResNeXt101_vd_64x4d +ResNeXt152_vd_64x4d +ResNeXt50_vd_32x4d +ResNeXt101_vd_32x4d +ResNeXt152_vd_32x4d +Res2Net50_48w_2s +Res2Net50_26w_4s +Res2Net50_14w_8s +Res2Net50_26w_6s +Res2Net50_26w_8s +Res2Net101_26w_4s +Res2Net152_26w_4s +Res2Net50_vd_48w_2s +Res2Net50_vd_26w_4s +Res2Net50_vd_14w_8s +Res2Net50_vd_26w_6s +Res2Net50_vd_26w_8s +Res2Net101_vd_26w_4s +Res2Net152_vd_26w_4s +Res2Net200_vd_26w_4s +DPN68 +DPN92 +DPN98 +DPN107 +DPN131 +MobileNetV1_x0_25 +MobileNetV1_x0_5 +MobileNetV1_x1_0 +MobileNetV1_x0_75 +MobileNetV3_small_x0_25 +MobileNetV3_small_x0_5 +MobileNetV3_small_x0_75 +MobileNetV3_small_x1_0 +MobileNetV3_small_x1_25 +HRNet_W18_C +HRNet_W32_C +HRNet_W48_C +HRNet_W64_C \ No newline at end of file diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/res2net.py b/VisualFL/visualfl/algorithm/paddle_clas/models/res2net.py new file mode 100755 index 000000000..6237244ec --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/res2net.py @@ -0,0 +1,200 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. 
+# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import math +from paddle.fluid.param_attr import ParamAttr + +__all__ = ["Res2Net", "Res2Net50_48w_2s", "Res2Net50_26w_4s", "Res2Net50_14w_8s", "Res2Net50_26w_6s", "Res2Net50_26w_8s", + "Res2Net101_26w_4s", "Res2Net152_26w_4s"] + + +class Res2Net(): + + def __init__(self, layers=50, scales=4, width=26): + self.layers = layers + self.scales = scales + self.width = width + + def net(self, input, class_dim=1000): + layers = self.layers + supported_layers = [50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + basic_width = self.width * self.scales + num_filters1 = [basic_width * t for t in [1, 2, 4, 8]] + num_filters2 = [256 * t for t in [1, 2, 4, 8]] + + if layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + conv = self.conv_bn_layer( + input=input, num_filters=64, filter_size=7, stride=2, act='relu', name="conv1") + + + conv = fluid.layers.pool2d( + input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') + + for block in range(len(depth)): + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block+2) + "a" + else: + conv_name = "res" + 
str(block+2) + "b" + str(i) + else: + conv_name = "res" + str(block+2) + chr(97+i) + conv = self.bottleneck_block( + input=conv, + num_filters1=num_filters1[block], + num_filters2=num_filters2[block], + stride=2 if i==0 and block !=0 else 1, name=conv_name) + pool = fluid.layers.pool2d( + input=conv, pool_size=7, pool_stride=1, pool_type='avg', global_pooling=True) + + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc( + input=pool, + size=class_dim, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv),name='fc_weights'), + bias_attr=fluid.param_attr.ParamAttr(name='fc_offset')) + return out + + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1)//2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + + return fluid.layers.batch_norm(input=conv, + act=act, + param_attr=ParamAttr(name=bn_name+'_scale'), + bias_attr=ParamAttr(bn_name+'_offset'), + moving_mean_name=bn_name+'_mean', + moving_variance_name=bn_name+'_variance') + + + def shortcut(self, input, ch_out, stride, name): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + return self.conv_bn_layer(input, ch_out, 1, stride, name=name) + else: + return input + + + def bottleneck_block(self, input, num_filters1, num_filters2, stride, name): + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters1, + filter_size=1, + stride=1, + act='relu', + name=name+'_branch2a') + xs = fluid.layers.split(conv0, self.scales, 1) + ys = [] + for s in range(self.scales - 1): + if s == 0 or stride == 2: + ys.append(self.conv_bn_layer(input=xs[s], + num_filters=num_filters1//self.scales, + stride=stride, + 
filter_size=3, + act='relu', + name=name+'_branch2b_'+str(s+1))) + else: + ys.append(self.conv_bn_layer(input=xs[s]+ys[-1], + num_filters=num_filters1//self.scales, + stride=stride, + filter_size=3, + act='relu', + name=name+'_branch2b_'+str(s+1))) + if stride == 1: + ys.append(xs[-1]) + else: + ys.append(fluid.layers.pool2d(input=xs[-1], + pool_size=3, + pool_stride=stride, + pool_padding=1, + pool_type='avg')) + + conv1 = fluid.layers.concat(ys, axis=1) + conv2 = self.conv_bn_layer( + input=conv1, num_filters=num_filters2, filter_size=1, act=None, name=name+"_branch2c") + + short = self.shortcut(input, num_filters2, stride, name=name+"_branch1") + + return fluid.layers.elementwise_add(x=short, y=conv2, act='relu') + + + +def Res2Net50_48w_2s(): + model = Res2Net(layers=50, scales=2, width=48) + return model + + +def Res2Net50_26w_4s(): + model = Res2Net(layers=50, scales=4, width=26) + return model + + +def Res2Net50_14w_8s(): + model = Res2Net(layers=50, scales=8, width=14) + return model + + +def Res2Net50_26w_6s(): + model = Res2Net(layers=50, scales=6, width=26) + return model + + +def Res2Net50_26w_8s(): + model = Res2Net(layers=50, scales=8, width=26) + return model + + +def Res2Net101_26w_4s(): + model = Res2Net(layers=101, scales=4, width=26) + return model + + +def Res2Net152_26w_4s(): + model = Res2Net(layers=152, scales=4, width=26) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/res2net_vd.py b/VisualFL/visualfl/algorithm/paddle_clas/models/res2net_vd.py new file mode 100755 index 000000000..596ff6063 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/res2net_vd.py @@ -0,0 +1,250 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. 
+#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import math +from paddle.fluid.param_attr import ParamAttr +__all__ = ["Res2Net_vd", "Res2Net50_vd_48w_2s", "Res2Net50_vd_26w_4s", "Res2Net50_vd_14w_8s", "Res2Net50_vd_26w_6s", + "Res2Net50_vd_26w_8s", "Res2Net101_vd_26w_4s", "Res2Net152_vd_26w_4s", "Res2Net200_vd_26w_4s"] + + +class Res2Net_vd(): + + def __init__(self, layers=50, scales=4, width=26): + self.layers = layers + self.scales = scales + self.width = width + + def net(self, input, class_dim=1000): + layers = self.layers + supported_layers = [50, 101, 152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + basic_width = self.width * self.scales + num_filters1 = [basic_width * t for t in [1, 2, 4, 8]] + num_filters2 = [256 * t for t in [1, 2, 4, 8]] + if layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + conv = self.conv_bn_layer( + input=input, num_filters=32, filter_size=3, stride=2, act='relu', name='conv1_1') + conv = self.conv_bn_layer( + input=conv, num_filters=32, filter_size=3, stride=1, act='relu', name='conv1_2') + conv = self.conv_bn_layer( + input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_3') + + conv = fluid.layers.pool2d( + input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') + for 
block in range(len(depth)): + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block+2 )+ "a" + else: + conv_name = "res" + str(block+2) + "b" + str(i) + else: + conv_name = "res" + str(block+2) + chr(97+i) + conv = self.bottleneck_block( + input=conv, + num_filters1=num_filters1[block], + num_filters2=num_filters2[block], + stride=2 if i==0 and block!=0 else 1, + if_first=block==i==0, + name=conv_name) + pool = fluid.layers.pool2d( + input=conv, pool_size=7, pool_stride=1, pool_type='avg', global_pooling=True) + + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc( + input=pool, + size=class_dim, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), name='fc_weights'), + bias_attr=fluid.param_attr.ParamAttr(name='fc_offset')) + return out + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name+"_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + return fluid.layers.batch_norm(input=conv, + act=act, + param_attr=ParamAttr(name=bn_name+'_scale'), + bias_attr=ParamAttr(bn_name+'_offset'), + moving_mean_name=bn_name+'_mean', + moving_variance_name=bn_name+'_variance') + + def conv_bn_layer_new(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + pool = fluid.layers.pool2d(input=input, + pool_size=2, + pool_stride=2, + pool_padding=0, + pool_type='avg', + ceil_mode=True) + + conv = fluid.layers.conv2d( + input=pool, + num_filters=num_filters, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1)//2, + groups=groups, + act=None, + 
param_attr=ParamAttr(name=name+"_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + return fluid.layers.batch_norm(input=conv, + act=act, + param_attr=ParamAttr(name=bn_name+'_scale'), + bias_attr=ParamAttr(bn_name+'_offset'), + moving_mean_name=bn_name+'_mean', + moving_variance_name=bn_name+'_variance') + + + def shortcut(self, input, ch_out, stride, name, if_first=False): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + if if_first: + return self.conv_bn_layer(input, ch_out, 1, stride, name=name) + else: + return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name) + elif if_first: + return self.conv_bn_layer(input, ch_out, 1, stride, name=name) + else: + return input + + + def bottleneck_block(self, input, num_filters1, num_filters2, stride, name, if_first): + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters1, + filter_size=1, + stride=1, + act='relu', + name=name+'_branch2a') + + xs = fluid.layers.split(conv0, self.scales, 1) + ys = [] + for s in range(self.scales - 1): + if s == 0 or stride == 2: + ys.append(self.conv_bn_layer(input=xs[s], + num_filters=num_filters1//self.scales, + stride=stride, + filter_size=3, + act='relu', + name=name+'_branch2b_'+str(s+1))) + else: + ys.append(self.conv_bn_layer(input=xs[s]+ys[-1], + num_filters=num_filters1//self.scales, + stride=stride, + filter_size=3, + act='relu', + name=name+'_branch2b_'+str(s+1))) + + if stride == 1: + ys.append(xs[-1]) + else: + ys.append(fluid.layers.pool2d(input=xs[-1], + pool_size=3, + pool_stride=stride, + pool_padding=1, + pool_type='avg')) + + conv1 = fluid.layers.concat(ys, axis=1) + conv2 = self.conv_bn_layer( + input=conv1, num_filters=num_filters2, filter_size=1, act=None, name=name+"_branch2c") + + short = self.shortcut(input, num_filters2, stride, if_first=if_first, name=name+"_branch1") + + return fluid.layers.elementwise_add(x=short, y=conv2, act='relu') + + + + +def 
Res2Net50_vd_48w_2s(): + model = Res2Net_vd(layers=50, scales=2, width=48) + return model + + +def Res2Net50_vd_26w_4s(): + model = Res2Net_vd(layers=50, scales=4, width=26) + return model + + +def Res2Net50_vd_14w_8s(): + model = Res2Net_vd(layers=50, scales=8, width=14) + return model + + +def Res2Net50_vd_26w_6s(): + model = Res2Net_vd(layers=50, scales=6, width=26) + return model + + +def Res2Net50_vd_26w_8s(): + model = Res2Net_vd(layers=50, scales=8, width=26) + return model + + +def Res2Net101_vd_26w_4s(): + model = Res2Net_vd(layers=101, scales=4, width=26) + return model + + +def Res2Net152_vd_26w_4s(): + model = Res2Net_vd(layers=152, scales=4, width=26) + return model + + +def Res2Net200_vd_26w_4s(): + model = Res2Net_vd(layers=200, scales=4, width=26) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/resnet.py b/VisualFL/visualfl/algorithm/paddle_clas/models/resnet.py new file mode 100755 index 000000000..fcf453588 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/resnet.py @@ -0,0 +1,236 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + "ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152" +] + + +class ResNet(): + def __init__(self, layers=50): + self.layers = layers + + def net(self, input, class_dim=1000, data_format="NCHW"): + layers = self.layers + supported_layers = [18, 34, 50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + num_filters = [64, 128, 256, 512] + + conv = self.conv_bn_layer( + input=input, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + name="conv1", + data_format=data_format) + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max', + data_format=data_format) + if layers >= 50: + for block in range(len(depth)): + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + conv = self.bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + name=conv_name, + data_format=data_format) + + pool = fluid.layers.pool2d( + input=conv, pool_type='avg', global_pooling=True, data_format=data_format) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc( + input=pool, + size=class_dim, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv))) + else: + for block in range(len(depth)): + for i in 
range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + conv = self.basic_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + is_first=block == i == 0, + name=conv_name, + data_format=data_format) + + pool = fluid.layers.pool2d( + input=conv, pool_type='avg', global_pooling=True, data_format=data_format) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc( + input=pool, + size=class_dim, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv))) + return out + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None, + data_format='NCHW'): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False, + name=name + '.conv2d.output.1', + data_format=data_format) + + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + return fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name + '.output.1', + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + data_layout=data_format) + + def shortcut(self, input, ch_out, stride, is_first, name, data_format): + if data_format == 'NCHW': + ch_in = input.shape[1] + else: + ch_in = input.shape[-1] + if ch_in != ch_out or stride != 1 or is_first == True: + return self.conv_bn_layer(input, ch_out, 1, stride, name=name, data_format=data_format) + else: + return input + + def bottleneck_block(self, input, num_filters, stride, name, data_format): + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name=name + "_branch2a", + data_format=data_format) + conv1 = 
self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu', + name=name + "_branch2b", + data_format=data_format) + conv2 = self.conv_bn_layer( + input=conv1, + num_filters=num_filters * 4, + filter_size=1, + act=None, + name=name + "_branch2c", + data_format=data_format) + + short = self.shortcut( + input, + num_filters * 4, + stride, + is_first=False, + name=name + "_branch1", + data_format=data_format) + + return fluid.layers.elementwise_add( + x=short, y=conv2, act='relu', name=name + ".add.output.5") + + def basic_block(self, input, num_filters, stride, is_first, name, data_format): + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=3, + act='relu', + stride=stride, + name=name + "_branch2a", + data_format=data_format) + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + act=None, + name=name + "_branch2b", + data_format=data_format) + short = self.shortcut( + input, num_filters, stride, is_first, name=name + "_branch1", data_format=data_format) + return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') + + +def ResNet18(): + model = ResNet(layers=18) + return model + + +def ResNet34(): + model = ResNet(layers=34) + return model + + +def ResNet50(): + model = ResNet(layers=50) + return model + + +def ResNet101(): + model = ResNet(layers=101) + return model + + +def ResNet152(): + model = ResNet(layers=152) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/resnet_acnet.py b/VisualFL/visualfl/algorithm/paddle_clas/models/resnet_acnet.py new file mode 100755 index 000000000..6cc2f969f --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/resnet_acnet.py @@ -0,0 +1,332 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. 
+#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + "ResNetACNet", "ResNet18_ACNet", "ResNet34_ACNet", "ResNet50_ACNet", + "ResNet101_ACNet", "ResNet152_ACNet" +] + + +class ResNetACNet(object): + """ ACNet """ + + def __init__(self, layers=50, deploy=False): + """init""" + self.layers = layers + self.deploy = deploy + + def net(self, input, class_dim=1000): + """model""" + layers = self.layers + supported_layers = [18, 34, 50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + num_filters = [64, 128, 256, 512] + + conv = self.conv_bn_layer( + input=input, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + name="conv1") + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + if layers >= 50: + for block in range(len(depth)): + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + conv = self.bottleneck_block( + input=conv, + 
num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + name=conv_name) + else: + for block in range(len(depth)): + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + conv = self.basic_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + is_first=block == i == 0, + name=conv_name) + + pool = fluid.layers.pool2d( + input=conv, pool_size=7, pool_type='avg', global_pooling=True) + + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc( + input=pool, + size=class_dim, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv))) + return out + + def conv_bn_layer(self, **kwargs): + """ + conv_bn_layer + """ + if kwargs['filter_size'] == 1: + return self.conv_bn_layer_ori(**kwargs) + else: + return self.conv_bn_layer_ac(**kwargs) + + # conv bn+relu + def conv_bn_layer_ori(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + """ + standard convbn + used for 1x1 convbn in acnet + """ + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False, + name=name + '.conv2d.output.1') + + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + return fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name + '.output.1', + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', ) + + # conv bn+relu + def conv_bn_layer_ac(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + """ ACNet conv bn """ + padding = (filter_size - 1) // 2 + + square_conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + 
filter_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + act=act if self.deploy else None, + param_attr=ParamAttr(name=name + "_acsquare_weights"), + bias_attr=ParamAttr(name=name + "_acsquare_bias") + if self.deploy else False, + name=name + '.acsquare.conv2d.output.1') + + if self.deploy: + return square_conv + else: + ver_conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=(filter_size, 1), + stride=stride, + padding=(padding, 0), + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_acver_weights"), + bias_attr=False, + name=name + '.acver.conv2d.output.1') + + hor_conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=(1, filter_size), + stride=stride, + padding=(0, padding), + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_achor_weights"), + bias_attr=False, + name=name + '.achor.conv2d.output.1') + + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + + square_bn = fluid.layers.batch_norm( + input=square_conv, + act=None, + name=bn_name + '.acsquare.output.1', + param_attr=ParamAttr(name=bn_name + '_acsquare_scale'), + bias_attr=ParamAttr(bn_name + '_acsquare_offset'), + moving_mean_name=bn_name + '_acsquare_mean', + moving_variance_name=bn_name + '_acsquare_variance', ) + + ver_bn = fluid.layers.batch_norm( + input=ver_conv, + act=None, + name=bn_name + '.acver.output.1', + param_attr=ParamAttr(name=bn_name + '_acver_scale'), + bias_attr=ParamAttr(bn_name + '_acver_offset'), + moving_mean_name=bn_name + '_acver_mean', + moving_variance_name=bn_name + '_acver_variance', ) + + hor_bn = fluid.layers.batch_norm( + input=hor_conv, + act=None, + name=bn_name + '.achor.output.1', + param_attr=ParamAttr(name=bn_name + '_achor_scale'), + bias_attr=ParamAttr(bn_name + '_achor_offset'), + moving_mean_name=bn_name + '_achor_mean', + moving_variance_name=bn_name + '_achor_variance', ) + + return fluid.layers.elementwise_add( + 
x=square_bn, y=ver_bn + hor_bn, act=act) + + def shortcut(self, input, ch_out, stride, is_first, name): + """ shortcut """ + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1 or is_first == True: + return self.conv_bn_layer( + input=input, + num_filters=ch_out, + filter_size=1, + stride=stride, + name=name) + else: + return input + + def bottleneck_block(self, input, num_filters, stride, name): + """" bottleneck_block """ + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name=name + "_branch2a") + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + conv2 = self.conv_bn_layer( + input=conv1, + num_filters=num_filters * 4, + filter_size=1, + act=None, + name=name + "_branch2c") + + short = self.shortcut( + input, + num_filters * 4, + stride, + is_first=False, + name=name + "_branch1") + + return fluid.layers.elementwise_add( + x=short, y=conv2, act='relu', name=name + ".add.output.5") + + def basic_block(self, input, num_filters, stride, is_first, name): + """ basic_block """ + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=3, + act='relu', + stride=stride, + name=name + "_branch2a") + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + act=None, + name=name + "_branch2b") + short = self.shortcut( + input, num_filters, stride, is_first, name=name + "_branch1") + return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') + + +def ResNet18_ACNet(deploy=False): + """ResNet18 + ACNet""" + model = ResNetACNet(layers=18, deploy=deploy) + return model + + +def ResNet34_ACNet(deploy=False): + """ResNet34 + ACNet""" + model = ResNetACNet(layers=34, deploy=deploy) + return model + + +def ResNet50_ACNet(deploy=False): + """ResNet50 + ACNet""" + model = ResNetACNet(layers=50, deploy=deploy) + return model + + +def 
ResNet101_ACNet(deploy=False): + """ResNet101 + ACNet""" + model = ResNetACNet(layers=101, deploy=deploy) + return model + + +def ResNet152_ACNet(deploy=False): + """ResNet152 + ACNet""" + model = ResNetACNet(layers=152, deploy=deploy) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/resnet_vc.py b/VisualFL/visualfl/algorithm/paddle_clas/models/resnet_vc.py new file mode 100755 index 000000000..0088030c4 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/resnet_vc.py @@ -0,0 +1,192 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
# Reconstructed from: VisualFL/visualfl/algorithm/paddle_clas/models/resnet_vc.py
# NOTE(review): relies on the module-level imports declared at the top of this
# file: `import math`, `import paddle.fluid as fluid`,
# `from paddle.fluid.param_attr import ParamAttr`.

__all__ = ["ResNet", "ResNet50_vc", "ResNet101_vc", "ResNet152_vc"]

# Default training hyper-parameters for ImageNet-style classification.
train_parameters = {
    "input_size": [3, 224, 224],
    "input_mean": [0.485, 0.456, 0.406],
    "input_std": [0.229, 0.224, 0.225],
    "learning_strategy": {
        "name": "piecewise_decay",
        "batch_size": 256,
        "epochs": [30, 60, 90],
        "steps": [0.1, 0.01, 0.001, 0.0001],
    },
}


class ResNet:
    """ResNet-vc backbone: the 7x7 stem is replaced by three stacked 3x3 convs."""

    def __init__(self, layers=50):
        # Depth of the network (50, 101 or 152); validated in `net`.
        self.params = train_parameters
        self.layers = layers

    def net(self, input, class_dim=1000):
        """Build the classification network and return the final FC output."""
        depth_by_layers = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
        assert self.layers in depth_by_layers, \
            "supported layers are {} but input layer is {}".format(
                sorted(depth_by_layers), self.layers)
        depth = depth_by_layers[self.layers]
        num_filters = [64, 128, 256, 512]

        # Deep stem: 3x3/s2 -> 3x3 -> 3x3 instead of a single 7x7 convolution.
        conv = self.conv_bn_layer(input=input, num_filters=32, filter_size=3,
                                  stride=2, act='relu', name='conv1_1')
        conv = self.conv_bn_layer(input=conv, num_filters=32, filter_size=3,
                                  stride=1, act='relu', name='conv1_2')
        conv = self.conv_bn_layer(input=conv, num_filters=64, filter_size=3,
                                  stride=1, act='relu', name='conv1_3')
        conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2,
                                   pool_padding=1, pool_type='max')

        for stage, blocks in enumerate(depth):
            for idx in range(blocks):
                # Stage-3 blocks of the deeper nets use the "a"/"bN" suffix
                # scheme so parameter names match released pretrained weights.
                if self.layers in [101, 152] and stage == 2:
                    suffix = "a" if idx == 0 else "b" + str(idx)
                else:
                    suffix = chr(97 + idx)
                conv = self.bottleneck_block(
                    input=conv,
                    num_filters=num_filters[stage],
                    stride=2 if idx == 0 and stage != 0 else 1,
                    name="res" + str(stage + 2) + suffix)

        pool = fluid.layers.pool2d(input=conv, pool_type='avg',
                                   global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        return fluid.layers.fc(
            input=pool,
            size=class_dim,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv)))

    def conv_bn_layer(self, input, num_filters, filter_size, stride=1,
                      groups=1, act=None, name=None):
        """Conv2d (no bias) + batch-norm; the activation is fused into the BN op."""
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False,
            name=name + '.conv2d.output.1')
        # BN parameter names follow the pretrained checkpoint convention.
        bn_name = "bn_" + name if name == "conv1" else "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            name=bn_name + '.output.1',
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance', )

    def shortcut(self, input, ch_out, stride, name):
        """Identity when shapes agree, otherwise a 1x1 projection conv."""
        if input.shape[1] == ch_out and stride == 1:
            return input
        return self.conv_bn_layer(input, ch_out, 1, stride, name=name)

    def bottleneck_block(self, input, num_filters, stride, name):
        """Standard bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand (x4) + shortcut."""
        branch = self.conv_bn_layer(input=input, num_filters=num_filters,
                                    filter_size=1, act='relu',
                                    name=name + "_branch2a")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters,
                                    filter_size=3, stride=stride, act='relu',
                                    name=name + "_branch2b")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters * 4,
                                    filter_size=1, act=None,
                                    name=name + "_branch2c")
        short = self.shortcut(input, num_filters * 4, stride,
                              name=name + "_branch1")
        return fluid.layers.elementwise_add(
            x=short, y=branch, act='relu', name=name + ".add.output.5")


def ResNet50_vc():
    """ResNet-50 with the vc (3x3 deep-stem) modification."""
    return ResNet(layers=50)


def ResNet101_vc():
    """ResNet-101 with the vc (3x3 deep-stem) modification."""
    return ResNet(layers=101)


def ResNet152_vc():
    """ResNet-152 with the vc (3x3 deep-stem) modification."""
    return ResNet(layers=152)
# Reconstructed from: VisualFL/visualfl/algorithm/paddle_clas/models/resnet_vd.py
# NOTE(review): relies on the module-level imports declared at the top of this
# file: `import math`, `import paddle.fluid as fluid`,
# `from paddle.fluid.param_attr import ParamAttr`.


class ResNet:
    """ResNet-vd backbone.

    The vd tricks: an optional deep stem (three 3x3 convs, ``is_3x3=True``)
    and avg-pool-then-1x1-conv downsampling shortcuts (``conv_bn_layer_new``)
    on all but the very first residual block.
    """

    def __init__(self, layers=50, is_3x3=False):
        self.layers = layers
        self.is_3x3 = is_3x3

    def net(self, input, class_dim=1000):
        """Build the classification network and return the final FC output."""
        depth_by_layers = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
            200: [3, 12, 48, 3],
        }
        assert self.layers in depth_by_layers, \
            "supported layers are {} but input layer is {}".format(
                sorted(depth_by_layers), self.layers)
        depth = depth_by_layers[self.layers]
        num_filters = [64, 128, 256, 512]

        if not self.is_3x3:
            # NOTE(review): this branch calls conv_bn_layer with name=None,
            # which would fail on `name + "_weights"`. All factories below pass
            # is_3x3=True so it is never exercised — confirm before relying on it.
            conv = self.conv_bn_layer(input=input, num_filters=64,
                                      filter_size=7, stride=2, act='relu')
        else:
            conv = self.conv_bn_layer(input=input, num_filters=32,
                                      filter_size=3, stride=2, act='relu',
                                      name='conv1_1')
            conv = self.conv_bn_layer(input=conv, num_filters=32,
                                      filter_size=3, stride=1, act='relu',
                                      name='conv1_2')
            conv = self.conv_bn_layer(input=conv, num_filters=64,
                                      filter_size=3, stride=1, act='relu',
                                      name='conv1_3')

        conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2,
                                   pool_padding=1, pool_type='max')

        use_bottleneck = self.layers >= 50
        for stage, blocks in enumerate(depth):
            for idx in range(blocks):
                stride = 2 if idx == 0 and stage != 0 else 1
                if use_bottleneck:
                    # Deep nets name stage-3 blocks "a"/"bN" to match
                    # pretrained parameter names.
                    if self.layers in [101, 152, 200] and stage == 2:
                        suffix = "a" if idx == 0 else "b" + str(idx)
                    else:
                        suffix = chr(97 + idx)
                    conv = self.bottleneck_block(
                        input=conv,
                        num_filters=num_filters[stage],
                        stride=stride,
                        if_first=stage == idx == 0,
                        name="res" + str(stage + 2) + suffix)
                else:
                    conv = self.basic_block(
                        input=conv,
                        num_filters=num_filters[stage],
                        stride=stride,
                        if_first=stage == idx == 0,
                        name="res" + str(stage + 2) + chr(97 + idx))

        pool = fluid.layers.pool2d(input=conv, pool_type='avg',
                                   global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        return fluid.layers.fc(
            input=pool,
            size=class_dim,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv)))

    def conv_bn_layer(self, input, num_filters, filter_size, stride=1,
                      groups=1, act=None, name=None):
        """Conv2d (no bias) + batch-norm; activation fused into the BN op."""
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        bn_name = "bn_" + name if name == "conv1" else "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def conv_bn_layer_new(self, input, num_filters, filter_size, stride=1,
                          groups=1, act=None, name=None):
        """vd downsampling path: 2x2 avg-pool first, then a stride-1 conv + BN."""
        pooled = fluid.layers.pool2d(
            input=input,
            pool_size=2,
            pool_stride=2,
            pool_padding=0,
            pool_type='avg',
            ceil_mode=True)
        conv = fluid.layers.conv2d(
            input=pooled,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=1,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        bn_name = "bn_" + name if name == "conv1" else "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def shortcut(self, input, ch_out, stride, name, if_first=False):
        """Projection shortcut; the first block uses a plain conv, later
        downsampling blocks use the avg-pool variant."""
        if input.shape[1] == ch_out and stride == 1:
            if if_first:
                return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
            return input
        if if_first:
            return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
        return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name)

    def bottleneck_block(self, input, num_filters, stride, name, if_first):
        """Bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand (x4) + vd shortcut."""
        branch = self.conv_bn_layer(input=input, num_filters=num_filters,
                                    filter_size=1, act='relu',
                                    name=name + "_branch2a")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters,
                                    filter_size=3, stride=stride, act='relu',
                                    name=name + "_branch2b")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters * 4,
                                    filter_size=1, act=None,
                                    name=name + "_branch2c")
        short = self.shortcut(input, num_filters * 4, stride,
                              if_first=if_first, name=name + "_branch1")
        return fluid.layers.elementwise_add(x=short, y=branch, act='relu')

    def basic_block(self, input, num_filters, stride, name, if_first):
        """Two-conv residual block used by the 18/34-layer variants."""
        branch = self.conv_bn_layer(input=input, num_filters=num_filters,
                                    filter_size=3, act='relu', stride=stride,
                                    name=name + "_branch2a")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters,
                                    filter_size=3, act=None,
                                    name=name + "_branch2b")
        short = self.shortcut(input, num_filters, stride,
                              if_first=if_first, name=name + "_branch1")
        return fluid.layers.elementwise_add(x=short, y=branch, act='relu')


def ResNet18_vd():
    """ResNet-18 with vd modifications (deep stem enabled)."""
    return ResNet(layers=18, is_3x3=True)


def ResNet34_vd():
    """ResNet-34 with vd modifications (deep stem enabled)."""
    return ResNet(layers=34, is_3x3=True)


def ResNet50_vd():
    """ResNet-50 with vd modifications (deep stem enabled)."""
    return ResNet(layers=50, is_3x3=True)


def ResNet101_vd():
    """ResNet-101 with vd modifications (deep stem enabled)."""
    return ResNet(layers=101, is_3x3=True)


def ResNet152_vd():
    """ResNet-152 with vd modifications (deep stem enabled)."""
    return ResNet(layers=152, is_3x3=True)


def ResNet200_vd():
    """ResNet-200 with vd modifications (deep stem enabled)."""
    return ResNet(layers=200, is_3x3=True)
# Reconstructed from: VisualFL/visualfl/algorithm/paddle_clas/models/resnext.py
# NOTE(review): relies on the module-level imports declared at the top of this
# file: `import math`, `import paddle.fluid as fluid`,
# `from paddle.fluid.param_attr import ParamAttr`.

__all__ = [
    "ResNeXt", "ResNeXt50_64x4d", "ResNeXt101_64x4d", "ResNeXt152_64x4d",
    "ResNeXt50_32x4d", "ResNeXt101_32x4d", "ResNeXt152_32x4d"
]


class ResNeXt:
    """ResNeXt backbone: bottleneck blocks with grouped 3x3 convolutions.

    Fix vs. the previous revision: ``bottleneck_block`` declared a
    ``cardinality`` parameter and then immediately shadowed it with
    ``self.cardinality``, silently discarding the argument. It now honours
    the parameter; the only call site passes ``self.cardinality``, so
    behaviour is unchanged.
    """

    def __init__(self, layers=50, cardinality=64):
        self.layers = layers
        # Number of groups in the 3x3 convolution of every bottleneck.
        self.cardinality = cardinality

    def net(self, input, class_dim=1000):
        """Build the classification network and return the final FC output."""
        depth_by_layers = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
        assert self.layers in depth_by_layers, \
            "supported layers are {} but input layer is {}".format(
                sorted(depth_by_layers), self.layers)
        depth = depth_by_layers[self.layers]

        # Per-stage output widths for 64x4d and 32x4d configurations.
        num_filters1 = [256, 512, 1024, 2048]
        num_filters2 = [128, 256, 512, 1024]

        conv = self.conv_bn_layer(input=input, num_filters=64, filter_size=7,
                                  stride=2, act='relu', name="res_conv1")
        conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2,
                                   pool_padding=1, pool_type='max')

        for stage, blocks in enumerate(depth):
            for idx in range(blocks):
                # Stage-3 blocks of the deeper nets use the "a"/"bN" suffix
                # scheme so parameter names match pretrained weights.
                if self.layers in [101, 152] and stage == 2:
                    suffix = "a" if idx == 0 else "b" + str(idx)
                else:
                    suffix = chr(97 + idx)
                conv = self.bottleneck_block(
                    input=conv,
                    num_filters=num_filters1[stage]
                    if self.cardinality == 64 else num_filters2[stage],
                    stride=2 if idx == 0 and stage != 0 else 1,
                    cardinality=self.cardinality,
                    name="res" + str(stage + 2) + suffix)

        pool = fluid.layers.pool2d(input=conv, pool_type='avg',
                                   global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        return fluid.layers.fc(
            input=pool,
            size=class_dim,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name='fc_weights'),
            bias_attr=fluid.param_attr.ParamAttr(name='fc_offset'))

    def conv_bn_layer(self, input, num_filters, filter_size, stride=1,
                      groups=1, act=None, name=None):
        """Conv2d (no bias) + batch-norm; activation fused into the BN op."""
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False,
            name=name + '.conv2d.output.1')
        bn_name = "bn_" + name if name == "conv1" else "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            name=bn_name + '.output.1',
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance', )

    def shortcut(self, input, ch_out, stride, name):
        """Identity when shapes agree, otherwise a 1x1 projection conv."""
        if input.shape[1] == ch_out and stride == 1:
            return input
        return self.conv_bn_layer(input, ch_out, 1, stride, name=name)

    def bottleneck_block(self, input, num_filters, stride, cardinality, name):
        """Grouped bottleneck: 1x1 -> 3x3 (groups=cardinality) -> 1x1 + shortcut."""
        branch = self.conv_bn_layer(input=input, num_filters=num_filters,
                                    filter_size=1, act='relu',
                                    name=name + "_branch2a")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters,
                                    filter_size=3, stride=stride,
                                    groups=cardinality, act='relu',
                                    name=name + "_branch2b")
        # 64x4d keeps the width; 32x4d doubles it on expansion.
        out_channels = num_filters if cardinality == 64 else num_filters * 2
        branch = self.conv_bn_layer(input=branch, num_filters=out_channels,
                                    filter_size=1, act=None,
                                    name=name + "_branch2c")
        short = self.shortcut(input, out_channels, stride,
                              name=name + "_branch1")
        return fluid.layers.elementwise_add(
            x=short, y=branch, act='relu', name=name + ".add.output.5")


def ResNeXt50_64x4d():
    """ResNeXt-50 with cardinality 64."""
    return ResNeXt(layers=50, cardinality=64)


def ResNeXt50_32x4d():
    """ResNeXt-50 with cardinality 32."""
    return ResNeXt(layers=50, cardinality=32)


def ResNeXt101_64x4d():
    """ResNeXt-101 with cardinality 64."""
    return ResNeXt(layers=101, cardinality=64)


def ResNeXt101_32x4d():
    """ResNeXt-101 with cardinality 32."""
    return ResNeXt(layers=101, cardinality=32)


def ResNeXt152_64x4d():
    """ResNeXt-152 with cardinality 64."""
    return ResNeXt(layers=152, cardinality=64)


def ResNeXt152_32x4d():
    """ResNeXt-152 with cardinality 32."""
    return ResNeXt(layers=152, cardinality=32)
# Reconstructed from: VisualFL/visualfl/algorithm/paddle_clas/models/resnext101_wsl.py
# NOTE(review): relies on the module-level imports declared at the top of this
# file: `import math`, `import paddle.fluid as fluid`,
# `from paddle.fluid.param_attr import ParamAttr`.

__all__ = [
    "ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl", "ResNeXt101_32x32d_wsl",
    "ResNeXt101_32x48d_wsl", "Fix_ResNeXt101_32x48d_wsl"
]


class ResNeXt101_wsl:
    """ResNeXt-101 WSL backbone; parameter names follow the torchvision-style
    `layerN.i.*` convention used by the released WSL checkpoints."""

    def __init__(self, layers=101, cardinality=32, width=48):
        self.layers = layers
        self.cardinality = cardinality
        self.width = width

    def net(self, input, class_dim=1000):
        """Build the classification network and return the final FC output."""
        depth = [3, 4, 23, 3]
        base_width = self.cardinality * self.width
        num_filters = [base_width * factor for factor in [1, 2, 4, 8]]

        conv = self.conv_bn_layer(input=input, num_filters=64, filter_size=7,
                                  stride=2, act='relu', name="conv1")
        conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2,
                                   pool_padding=1, pool_type='max')

        for stage, blocks in enumerate(depth):
            for idx in range(blocks):
                conv = self.bottleneck_block(
                    input=conv,
                    num_filters=num_filters[stage],
                    stride=2 if idx == 0 and stage != 0 else 1,
                    cardinality=self.cardinality,
                    name='layer' + str(stage + 1) + "." + str(idx))

        pool = fluid.layers.pool2d(input=conv, pool_type='avg',
                                   global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        return fluid.layers.fc(
            input=pool,
            size=class_dim,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name='fc.weight'),
            bias_attr=fluid.param_attr.ParamAttr(name='fc.bias'))

    def conv_bn_layer(self, input, num_filters, filter_size, stride=1,
                      groups=1, act=None, name=None):
        """Conv2d (no bias) + batch-norm with torchvision-style parameter names."""
        # Downsample branches store their conv as `<name>.0` and BN as `<prefix>.1`.
        conv_name = name + '.0' if "downsample" in name else name
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=conv_name + ".weight"),
            bias_attr=False)
        if "downsample" in name:
            bn_name = name[:9] + 'downsample' + '.1'
        elif "conv1" == name:
            bn_name = 'bn' + name[-1]
        else:
            # Block index may be one or two digits ("layer1.0" vs "layer1.10"),
            # so the prefix length depends on whether chars 7:9 are digits.
            prefix = name[:10] if name[7:9].isdigit() else name[:9]
            bn_name = prefix + 'bn' + name[-1]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '.weight'),
            bias_attr=ParamAttr(bn_name + '.bias'),
            moving_mean_name=bn_name + '.running_mean',
            moving_variance_name=bn_name + '.running_var', )

    def shortcut(self, input, ch_out, stride, name):
        """Identity when shapes agree, otherwise a 1x1 projection conv."""
        if input.shape[1] == ch_out and stride == 1:
            return input
        return self.conv_bn_layer(input, ch_out, 1, stride, name=name)

    def bottleneck_block(self, input, num_filters, stride, cardinality, name):
        """Grouped bottleneck; expansion width derives from self.width."""
        # NOTE(review): the `cardinality` argument is shadowed by
        # self.cardinality, exactly as in the original code — kept to preserve
        # behaviour (the call site always passes self.cardinality anyway).
        cardinality = self.cardinality
        width = self.width
        branch = self.conv_bn_layer(input=input, num_filters=num_filters,
                                    filter_size=1, act='relu',
                                    name=name + ".conv1")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters,
                                    filter_size=3, stride=stride,
                                    groups=cardinality, act='relu',
                                    name=name + ".conv2")
        out_channels = num_filters // (width // 8)
        branch = self.conv_bn_layer(input=branch, num_filters=out_channels,
                                    filter_size=1, act=None,
                                    name=name + ".conv3")
        short = self.shortcut(input, out_channels, stride,
                              name=name + ".downsample")
        return fluid.layers.elementwise_add(x=short, y=branch, act='relu')


def ResNeXt101_32x8d_wsl():
    """ResNeXt-101 32x8d trained with weakly supervised learning."""
    return ResNeXt101_wsl(cardinality=32, width=8)


def ResNeXt101_32x16d_wsl():
    """ResNeXt-101 32x16d trained with weakly supervised learning."""
    return ResNeXt101_wsl(cardinality=32, width=16)


def ResNeXt101_32x32d_wsl():
    """ResNeXt-101 32x32d trained with weakly supervised learning."""
    return ResNeXt101_wsl(cardinality=32, width=32)


def ResNeXt101_32x48d_wsl():
    """ResNeXt-101 32x48d trained with weakly supervised learning."""
    return ResNeXt101_wsl(cardinality=32, width=48)


def Fix_ResNeXt101_32x48d_wsl():
    """Fix-resolution variant of ResNeXt-101 32x48d WSL (same architecture)."""
    return ResNeXt101_wsl(cardinality=32, width=48)
# Reconstructed from: VisualFL/visualfl/algorithm/paddle_clas/models/resnext_vd.py
# NOTE(review): relies on the module-level imports declared at the top of this
# file: `import math`, `import paddle.fluid as fluid`,
# `from paddle.fluid.param_attr import ParamAttr`.

__all__ = [
    "ResNeXt", "ResNeXt50_vd_64x4d", "ResNeXt101_vd_64x4d",
    "ResNeXt152_vd_64x4d", "ResNeXt50_vd_32x4d", "ResNeXt101_vd_32x4d",
    "ResNeXt152_vd_32x4d"
]


class ResNeXt:
    """ResNeXt-vd backbone: grouped bottlenecks plus the vd deep stem and
    avg-pool downsampling shortcut."""

    def __init__(self, layers=50, is_3x3=False, cardinality=64):
        self.layers = layers
        self.is_3x3 = is_3x3
        self.cardinality = cardinality

    def net(self, input, class_dim=1000):
        """Build the classification network and return the final FC output."""
        depth_by_layers = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
        assert self.layers in depth_by_layers, \
            "supported layers are {} but input layer is {}".format(
                sorted(depth_by_layers), self.layers)
        depth = depth_by_layers[self.layers]
        num_filters1 = [256, 512, 1024, 2048]
        num_filters2 = [128, 256, 512, 1024]

        if not self.is_3x3:
            # NOTE(review): passes name=None into conv_bn_layer, which would
            # fail on `name + "_weights"`; every factory below sets
            # is_3x3=True, so this branch is never taken — confirm before use.
            conv = self.conv_bn_layer(input=input, num_filters=64,
                                      filter_size=7, stride=2, act='relu')
        else:
            conv = self.conv_bn_layer(input=input, num_filters=32,
                                      filter_size=3, stride=2, act='relu',
                                      name='conv1_1')
            conv = self.conv_bn_layer(input=conv, num_filters=32,
                                      filter_size=3, stride=1, act='relu',
                                      name='conv1_2')
            conv = self.conv_bn_layer(input=conv, num_filters=64,
                                      filter_size=3, stride=1, act='relu',
                                      name='conv1_3')

        conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2,
                                   pool_padding=1, pool_type='max')

        for stage, blocks in enumerate(depth):
            for idx in range(blocks):
                # NOTE(review): 200 appears here although supported depths are
                # [50, 101, 152]; harmless, kept as-is from the original.
                if self.layers in [101, 152, 200] and stage == 2:
                    suffix = "a" if idx == 0 else "b" + str(idx)
                else:
                    suffix = chr(97 + idx)
                conv = self.bottleneck_block(
                    input=conv,
                    num_filters=num_filters1[stage]
                    if self.cardinality == 64 else num_filters2[stage],
                    stride=2 if idx == 0 and stage != 0 else 1,
                    cardinality=self.cardinality,
                    # NOTE(review): whole first stage (not just the first
                    # block) counts as "first" here, unlike resnet_vd.
                    if_first=stage == 0,
                    name="res" + str(stage + 2) + suffix)

        pool = fluid.layers.pool2d(input=conv, pool_type='avg',
                                   global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        return fluid.layers.fc(
            input=pool,
            size=class_dim,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name='fc_weights'),
            bias_attr=fluid.param_attr.ParamAttr(name='fc_offset'))

    def conv_bn_layer(self, input, num_filters, filter_size, stride=1,
                      groups=1, act=None, name=None):
        """Conv2d (no bias) + batch-norm; activation fused into the BN op."""
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        bn_name = "bn_" + name if name == "conv1" else "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def conv_bn_layer_new(self, input, num_filters, filter_size, stride=1,
                          groups=1, act=None, name=None):
        """vd downsampling path: 2x2 avg-pool first, then a stride-1 conv + BN."""
        pooled = fluid.layers.pool2d(
            input=input,
            pool_size=2,
            pool_stride=2,
            pool_padding=0,
            pool_type='avg',
            ceil_mode=True)
        conv = fluid.layers.conv2d(
            input=pooled,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=1,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        bn_name = "bn_" + name if name == "conv1" else "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def shortcut(self, input, ch_out, stride, name, if_first=False):
        """Identity on matching shapes; otherwise project, using the plain conv
        for the first stage and the avg-pool variant afterwards."""
        if input.shape[1] == ch_out and stride == 1:
            return input
        if if_first:
            return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
        return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name)

    def bottleneck_block(self, input, num_filters, stride, cardinality, name,
                         if_first):
        """Grouped bottleneck: 1x1 -> 3x3 (groups=cardinality) -> 1x1 + vd shortcut."""
        branch = self.conv_bn_layer(input=input, num_filters=num_filters,
                                    filter_size=1, act='relu',
                                    name=name + "_branch2a")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters,
                                    filter_size=3, stride=stride, act='relu',
                                    groups=cardinality,
                                    name=name + "_branch2b")
        # 64x4d keeps the width; 32x4d doubles it on expansion.
        out_channels = num_filters if cardinality == 64 else num_filters * 2
        branch = self.conv_bn_layer(input=branch, num_filters=out_channels,
                                    filter_size=1, act=None,
                                    name=name + "_branch2c")
        short = self.shortcut(input, out_channels, stride,
                              if_first=if_first, name=name + "_branch1")
        return fluid.layers.elementwise_add(x=short, y=branch, act='relu')


def ResNeXt50_vd_64x4d():
    """ResNeXt-50-vd with cardinality 64."""
    return ResNeXt(layers=50, is_3x3=True)


def ResNeXt50_vd_32x4d():
    """ResNeXt-50-vd with cardinality 32."""
    return ResNeXt(layers=50, cardinality=32, is_3x3=True)


def ResNeXt101_vd_64x4d():
    """ResNeXt-101-vd with cardinality 64."""
    return ResNeXt(layers=101, is_3x3=True)


def ResNeXt101_vd_32x4d():
    """ResNeXt-101-vd with cardinality 32."""
    return ResNeXt(layers=101, cardinality=32, is_3x3=True)


def ResNeXt152_vd_64x4d():
    """ResNeXt-152-vd with cardinality 64."""
    return ResNeXt(layers=152, is_3x3=True)


def ResNeXt152_vd_32x4d():
    """ResNeXt-152-vd with cardinality 32."""
    return ResNeXt(layers=152, cardinality=32, is_3x3=True)
# Reconstructed from: VisualFL/visualfl/algorithm/paddle_clas/models/se_resnet_vd.py
# NOTE(review): relies on the module-level imports declared at the top of this
# file: `import math`, `import paddle.fluid as fluid`,
# `from paddle.fluid.param_attr import ParamAttr`.

__all__ = ["SE_ResNet_vd", "SE_ResNet18_vd", "SE_ResNet34_vd",
           "SE_ResNet50_vd", "SE_ResNet101_vd", "SE_ResNet152_vd",
           "SE_ResNet200_vd"]


class SE_ResNet_vd:
    """SE-ResNet-vd backbone: ResNet-vd blocks with a squeeze-and-excitation
    channel-recalibration branch appended to every residual block."""

    def __init__(self, layers=50, is_3x3=False):
        self.layers = layers
        self.is_3x3 = is_3x3

    def net(self, input, class_dim=1000):
        """Build the classification network and return the final FC output."""
        depth_by_layers = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
            200: [3, 12, 48, 3],
        }
        assert self.layers in depth_by_layers, \
            "supported layers are {} but input layer is {}".format(
                sorted(depth_by_layers), self.layers)
        depth = depth_by_layers[self.layers]
        num_filters = [64, 128, 256, 512]
        reduction_ratio = 16  # SE bottleneck reduction factor

        if not self.is_3x3:
            # NOTE(review): passes name=None into conv_bn_layer, which would
            # fail on `name + "_weights"`; all factories below use
            # is_3x3=True, so this branch is never exercised.
            conv = self.conv_bn_layer(input=input, num_filters=64,
                                      filter_size=7, stride=2, act='relu')
        else:
            conv = self.conv_bn_layer(input=input, num_filters=32,
                                      filter_size=3, stride=2, act='relu',
                                      name='conv1_1')
            conv = self.conv_bn_layer(input=conv, num_filters=32,
                                      filter_size=3, stride=1, act='relu',
                                      name='conv1_2')
            conv = self.conv_bn_layer(input=conv, num_filters=64,
                                      filter_size=3, stride=1, act='relu',
                                      name='conv1_3')

        conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2,
                                   pool_padding=1, pool_type='max')

        use_bottleneck = self.layers >= 50
        for stage, blocks in enumerate(depth):
            for idx in range(blocks):
                stride = 2 if idx == 0 and stage != 0 else 1
                if use_bottleneck:
                    if self.layers in [101, 152, 200] and stage == 2:
                        suffix = "a" if idx == 0 else "b" + str(idx)
                    else:
                        suffix = chr(97 + idx)
                    conv = self.bottleneck_block(
                        input=conv,
                        num_filters=num_filters[stage],
                        stride=stride,
                        if_first=stage == idx == 0,
                        reduction_ratio=reduction_ratio,
                        name="res" + str(stage + 2) + suffix)
                else:
                    conv = self.basic_block(
                        input=conv,
                        num_filters=num_filters[stage],
                        stride=stride,
                        if_first=stage == idx == 0,
                        reduction_ratio=reduction_ratio,
                        name="res" + str(stage + 2) + chr(97 + idx))

        pool = fluid.layers.pool2d(input=conv, pool_size=7, pool_type='avg',
                                   global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        return fluid.layers.fc(
            input=pool,
            size=class_dim,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name='fc6_weights'),
            bias_attr=ParamAttr(name='fc6_offset'))

    def conv_bn_layer(self, input, num_filters, filter_size, stride=1,
                      groups=1, act=None, name=None):
        """Conv2d (no bias) + batch-norm; activation fused into the BN op."""
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        bn_name = "bn_" + name if name == "conv1" else "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def conv_bn_layer_new(self, input, num_filters, filter_size, stride=1,
                          groups=1, act=None, name=None):
        """vd downsampling path: 2x2 avg-pool first, then a stride-1 conv + BN."""
        pooled = fluid.layers.pool2d(
            input=input,
            pool_size=2,
            pool_stride=2,
            pool_padding=0,
            pool_type='avg',
            ceil_mode=True)
        conv = fluid.layers.conv2d(
            input=pooled,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=1,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        bn_name = "bn_" + name if name == "conv1" else "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def shortcut(self, input, ch_out, stride, name, if_first=False):
        """Projection shortcut: plain conv for the very first block, avg-pool
        variant for later downsampling blocks, identity otherwise."""
        if input.shape[1] == ch_out and stride == 1:
            if if_first:
                return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
            return input
        if if_first:
            return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
        return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name)

    def bottleneck_block(self, input, num_filters, stride, name, if_first,
                         reduction_ratio):
        """Bottleneck (1x1 -> 3x3 -> 1x1 x4) with an SE branch on its output."""
        branch = self.conv_bn_layer(input=input, num_filters=num_filters,
                                    filter_size=1, act='relu',
                                    name=name + "_branch2a")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters,
                                    filter_size=3, stride=stride, act='relu',
                                    name=name + "_branch2b")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters * 4,
                                    filter_size=1, act=None,
                                    name=name + "_branch2c")
        scale = self.squeeze_excitation(
            input=branch,
            num_channels=num_filters * 4,
            reduction_ratio=reduction_ratio,
            name='fc_' + name)
        short = self.shortcut(input, num_filters * 4, stride,
                              if_first=if_first, name=name + "_branch1")
        return fluid.layers.elementwise_add(x=short, y=scale, act='relu')

    def basic_block(self, input, num_filters, stride, name, if_first,
                    reduction_ratio):
        """Two-conv residual block (18/34-layer nets) with an SE branch."""
        branch = self.conv_bn_layer(input=input, num_filters=num_filters,
                                    filter_size=3, act='relu', stride=stride,
                                    name=name + "_branch2a")
        branch = self.conv_bn_layer(input=branch, num_filters=num_filters,
                                    filter_size=3, act=None,
                                    name=name + "_branch2b")
        scale = self.squeeze_excitation(
            input=branch,
            num_channels=num_filters,
            reduction_ratio=reduction_ratio,
            name='fc_' + name)
        short = self.shortcut(input, num_filters, stride,
                              if_first=if_first, name=name + "_branch1")
        return fluid.layers.elementwise_add(x=short, y=scale, act='relu')

    def squeeze_excitation(self, input, num_channels, reduction_ratio,
                           name=None):
        """SE branch: global avg-pool -> FC/relu (squeeze) -> FC/sigmoid
        (excite) -> channel-wise rescale of `input`."""
        # pool_size=0 is ignored because global_pooling=True; kept as-is.
        pool = fluid.layers.pool2d(
            input=input, pool_size=0, pool_type='avg', global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        squeeze = fluid.layers.fc(
            input=pool,
            size=num_channels // reduction_ratio,
            act='relu',
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name=name + '_sqz_weights'),
            bias_attr=ParamAttr(name=name + '_sqz_offset'))
        stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
        excitation = fluid.layers.fc(
            input=squeeze,
            size=num_channels,
            act='sigmoid',
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name=name + '_exc_weights'),
            bias_attr=ParamAttr(name=name + '_exc_offset'))
        return fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)


def SE_ResNet18_vd():
    """SE-ResNet-18-vd (deep stem enabled)."""
    return SE_ResNet_vd(layers=18, is_3x3=True)


def SE_ResNet34_vd():
    """SE-ResNet-34-vd (deep stem enabled)."""
    return SE_ResNet_vd(layers=34, is_3x3=True)


def SE_ResNet50_vd():
    """SE-ResNet-50-vd (deep stem enabled)."""
    return SE_ResNet_vd(layers=50, is_3x3=True)


def SE_ResNet101_vd():
    """SE-ResNet-101-vd (deep stem enabled)."""
    return SE_ResNet_vd(layers=101, is_3x3=True)


def SE_ResNet152_vd():
    """SE-ResNet-152-vd (deep stem enabled)."""
    return SE_ResNet_vd(layers=152, is_3x3=True)


def SE_ResNet200_vd():
    """SE-ResNet-200-vd (deep stem enabled)."""
    return SE_ResNet_vd(layers=200, is_3x3=True)
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + "SE_ResNeXt", "SE_ResNeXt50_32x4d", "SE_ResNeXt101_32x4d", + "SE_ResNeXt152_32x4d" +] + + +class SE_ResNeXt(): + def __init__(self, layers=50): + self.layers = layers + + def net(self, input, class_dim=1000): + layers = self.layers + supported_layers = [50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + if layers == 50: + cardinality = 32 + reduction_ratio = 16 + depth = [3, 4, 6, 3] + num_filters = [128, 256, 512, 1024] + + conv = self.conv_bn_layer( + input=input, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + name='conv1', ) + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max', + use_cudnn=False) + elif layers == 101: + cardinality = 32 + reduction_ratio = 16 + depth = [3, 4, 23, 3] + num_filters = [128, 256, 512, 1024] + + conv = self.conv_bn_layer( + input=input, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + name="conv1", ) + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max', + use_cudnn=False) + elif layers == 152: + cardinality = 64 + reduction_ratio = 16 + depth = [3, 8, 36, 3] + num_filters = [128, 256, 512, 1024] + + conv = self.conv_bn_layer( + input=input, + num_filters=64, + filter_size=3, + stride=2, + act='relu', + name='conv1') + conv = self.conv_bn_layer( + input=conv, + num_filters=64, + filter_size=3, + stride=1, + act='relu', + name='conv2') + conv = self.conv_bn_layer( + input=conv, + num_filters=128, + filter_size=3, + stride=1, + act='relu', + name='conv3') + conv = fluid.layers.pool2d( + input=conv, pool_size=3, pool_stride=2, pool_padding=1, \ + pool_type='max', 
use_cudnn=False) + n = 1 if layers == 50 or layers == 101 else 3 + for block in range(len(depth)): + n += 1 + for i in range(depth[block]): + conv = self.bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + cardinality=cardinality, + reduction_ratio=reduction_ratio, + name=str(n) + '_' + str(i + 1)) + + pool = fluid.layers.pool2d( + input=conv, pool_type='avg', global_pooling=True, use_cudnn=False) + drop = fluid.layers.dropout(x=pool, dropout_prob=0.5) + stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0) + out = fluid.layers.fc( + input=drop, + size=class_dim, + param_attr=ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name='fc6_weights'), + bias_attr=ParamAttr(name='fc6_offset')) + return out + + def shortcut(self, input, ch_out, stride, name): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + filter_size = 1 + return self.conv_bn_layer( + input, ch_out, filter_size, stride, name='conv' + name + '_prj') + else: + return input + + def bottleneck_block(self, + input, + num_filters, + stride, + cardinality, + reduction_ratio, + name=None): + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name='conv' + name + '_x1') + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + groups=cardinality, + act='relu', + name='conv' + name + '_x2') + conv2 = self.conv_bn_layer( + input=conv1, + num_filters=num_filters * 2, + filter_size=1, + act=None, + name='conv' + name + '_x3') + scale = self.squeeze_excitation( + input=conv2, + num_channels=num_filters * 2, + reduction_ratio=reduction_ratio, + name='fc' + name) + + short = self.shortcut(input, num_filters * 2, stride, name=name) + + return fluid.layers.elementwise_add(x=short, y=scale, act='relu') + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + conv = 
fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False, + param_attr=ParamAttr(name=name + '_weights'), ) + bn_name = name + "_bn" + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def squeeze_excitation(self, + input, + num_channels, + reduction_ratio, + name=None): + pool = fluid.layers.pool2d( + input=input, pool_type='avg', global_pooling=True, use_cudnn=False) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + squeeze = fluid.layers.fc( + input=pool, + size=num_channels // reduction_ratio, + act='relu', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_sqz_weights'), + bias_attr=ParamAttr(name=name + '_sqz_offset')) + stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) + excitation = fluid.layers.fc( + input=squeeze, + size=num_channels, + act='sigmoid', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_exc_weights'), + bias_attr=ParamAttr(name=name + '_exc_offset')) + scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) + return scale + + +def SE_ResNeXt50_32x4d(): + model = SE_ResNeXt(layers=50) + return model + + +def SE_ResNeXt101_32x4d(): + model = SE_ResNeXt(layers=101) + return model + + +def SE_ResNeXt152_32x4d(): + model = SE_ResNeXt(layers=152) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/se_resnext_vd.py b/VisualFL/visualfl/algorithm/paddle_clas/models/se_resnext_vd.py new file mode 100755 index 000000000..a278aebe3 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/se_resnext_vd.py @@ -0,0 +1,328 @@ +#copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + "SE_ResNeXt_vd", "SE_ResNeXt50_vd_32x4d", "SE_ResNeXt101_vd_32x4d", "SENet154_vd" +] + + +class SE_ResNeXt_vd(): + def __init__(self, layers=50): + self.layers = layers + + def net(self, input, class_dim=1000): + layers = self.layers + supported_layers = [50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + if layers == 50: + cardinality = 32 + reduction_ratio = 16 + depth = [3, 4, 6, 3] + num_filters = [128, 256, 512, 1024] + + conv = self.conv_bn_layer( + input=input, + num_filters=64, + filter_size=3, + stride=2, + act='relu', + name='conv1_1') + conv = self.conv_bn_layer( + input=conv, + num_filters=64, + filter_size=3, + stride=1, + act='relu', + name='conv1_2') + conv = self.conv_bn_layer( + input=conv, + num_filters=128, + filter_size=3, + stride=1, + act='relu', + name='conv1_3') + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + elif layers == 101: + cardinality = 32 + reduction_ratio = 16 + depth = [3, 4, 23, 3] + num_filters = [128, 256, 512, 1024] + + conv = self.conv_bn_layer( + input=input, + num_filters=64, + 
filter_size=3, + stride=2, + act='relu', + name='conv1_1') + conv = self.conv_bn_layer( + input=conv, + num_filters=64, + filter_size=3, + stride=1, + act='relu', + name='conv1_2') + conv = self.conv_bn_layer( + input=conv, + num_filters=128, + filter_size=3, + stride=1, + act='relu', + name='conv1_3') + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + elif layers == 152: + cardinality = 64 + reduction_ratio = 16 + depth = [3, 8, 36, 3] + num_filters = [256, 512, 1024, 2048] + + conv = self.conv_bn_layer( + input=input, + num_filters=64, + filter_size=3, + stride=2, + act='relu', + name='conv1_1') + conv = self.conv_bn_layer( + input=conv, + num_filters=64, + filter_size=3, + stride=1, + act='relu', + name='conv1_2') + conv = self.conv_bn_layer( + input=conv, + num_filters=128, + filter_size=3, + stride=1, + act='relu', + name='conv1_3') + conv = fluid.layers.pool2d( + input=conv, pool_size=3, pool_stride=2, pool_padding=1, \ + pool_type='max') + n = 1 if layers == 50 or layers == 101 else 3 + for block in range(len(depth)): + n += 1 + for i in range(depth[block]): + conv = self.bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + cardinality=cardinality, + reduction_ratio=reduction_ratio, + if_first=block == 0, + name=str(n) + '_' + str(i + 1)) + + pool = fluid.layers.pool2d( + input=conv, pool_type='avg', global_pooling=True) + if layers == 152: + pool = fluid.layers.dropout(x=pool, dropout_prob=0.2) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc( + input=pool, + size=class_dim, + param_attr=ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name='fc6_weights'), + bias_attr=ParamAttr(name='fc6_offset')) + + return out + + def shortcut(self, input, ch_out, stride, name, if_first=False): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + filter_size = 1 + if if_first: + return 
self.conv_bn_layer( + input, + ch_out, + filter_size, + stride, + name='conv' + name + '_prj') + else: + return self.conv_bn_layer_new( + input, + ch_out, + filter_size, + stride, + name='conv' + name + '_prj') + else: + return input + + def bottleneck_block(self, + input, + num_filters, + stride, + cardinality, + reduction_ratio, + if_first, + name=None): + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name='conv' + name + '_x1') + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + groups=cardinality, + act='relu', + name='conv' + name + '_x2') + if cardinality == 64: + num_filters = num_filters // 2 + conv2 = self.conv_bn_layer( + input=conv1, + num_filters=num_filters * 2, + filter_size=1, + act=None, + name='conv' + name + '_x3') + scale = self.squeeze_excitation( + input=conv2, + num_channels=num_filters * 2, + reduction_ratio=reduction_ratio, + name='fc' + name) + + short = self.shortcut( + input, num_filters * 2, stride, if_first=if_first, name=name) + + return fluid.layers.elementwise_add(x=short, y=scale, act='relu') + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False, + param_attr=ParamAttr(name=name + '_weights'), ) + bn_name = name + "_bn" + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def conv_bn_layer_new(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + pool = fluid.layers.pool2d( + input=input, + pool_size=2, + pool_stride=2, + pool_padding=0, + 
pool_type='avg', + ceil_mode=True) + + conv = fluid.layers.conv2d( + input=pool, + num_filters=num_filters, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + bn_name = name + "_bn" + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def squeeze_excitation(self, + input, + num_channels, + reduction_ratio, + name=None): + pool = fluid.layers.pool2d( + input=input, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + squeeze = fluid.layers.fc( + input=pool, + size=num_channels // reduction_ratio, + act='relu', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_sqz_weights'), + bias_attr=ParamAttr(name=name + '_sqz_offset')) + stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) + excitation = fluid.layers.fc( + input=squeeze, + size=num_channels, + act='sigmoid', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=name + '_exc_weights'), + bias_attr=ParamAttr(name=name + '_exc_offset')) + scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) + return scale + + +def SE_ResNeXt50_vd_32x4d(): + model = SE_ResNeXt_vd(layers=50) + return model + + +def SE_ResNeXt101_vd_32x4d(): + model = SE_ResNeXt_vd(layers=101) + return model + + +def SENet154_vd(): + model = SE_ResNeXt_vd(layers=152) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/shufflenet_v2.py b/VisualFL/visualfl/algorithm/paddle_clas/models/shufflenet_v2.py new file mode 100755 index 000000000..89f1eb888 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/shufflenet_v2.py @@ -0,0 +1,307 @@ +#copyright (c) 2019 PaddlePaddle 
Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + 'ShuffleNetV2_x0_25', 'ShuffleNetV2_x0_33', 'ShuffleNetV2_x0_5', + 'ShuffleNetV2_x1_0', 'ShuffleNetV2_x1_5', 'ShuffleNetV2_x2_0', + 'ShuffleNetV2' +] + + +class ShuffleNetV2(): + def __init__(self, scale=1.0): + self.scale = scale + + def net(self, input, class_dim=1000): + scale = self.scale + stage_repeats = [4, 8, 4] + + if scale == 0.25: + stage_out_channels = [-1, 24, 24, 48, 96, 512] + elif scale == 0.33: + stage_out_channels = [-1, 24, 32, 64, 128, 512] + elif scale == 0.5: + stage_out_channels = [-1, 24, 48, 96, 192, 1024] + elif scale == 1.0: + stage_out_channels = [-1, 24, 116, 232, 464, 1024] + elif scale == 1.5: + stage_out_channels = [-1, 24, 176, 352, 704, 1024] + elif scale == 2.0: + stage_out_channels = [-1, 24, 224, 488, 976, 2048] + else: + raise NotImplementedError("This scale size:[" + str(scale) + + "] is not implemented!") + #conv1 + + input_channel = stage_out_channels[1] + conv1 = self.conv_bn_layer( + input=input, + filter_size=3, + num_filters=input_channel, + padding=1, + stride=2, + name='stage1_conv') + pool1 = fluid.layers.pool2d( + input=conv1, + pool_size=3, + pool_stride=2, + pool_padding=1, + 
pool_type='max') + conv = pool1 + # bottleneck sequences + for idxstage in range(len(stage_repeats)): + numrepeat = stage_repeats[idxstage] + output_channel = stage_out_channels[idxstage + 2] + for i in range(numrepeat): + if i == 0: + conv = self.inverted_residual_unit( + input=conv, + num_filters=output_channel, + stride=2, + benchmodel=2, + name=str(idxstage + 2) + '_' + str(i + 1)) + else: + conv = self.inverted_residual_unit( + input=conv, + num_filters=output_channel, + stride=1, + benchmodel=1, + name=str(idxstage + 2) + '_' + str(i + 1)) + + conv_last = self.conv_bn_layer( + input=conv, + filter_size=1, + num_filters=stage_out_channels[-1], + padding=0, + stride=1, + name='conv5') + pool_last = fluid.layers.pool2d( + input=conv_last, + pool_size=7, + pool_stride=1, + pool_padding=0, + pool_type='avg') + + output = fluid.layers.fc(input=pool_last, + size=class_dim, + param_attr=ParamAttr( + initializer=MSRA(), name='fc6_weights'), + bias_attr=ParamAttr(name='fc6_offset')) + return output + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + num_groups=1, + use_cudnn=True, + if_act=True, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr( + initializer=MSRA(), name=name + '_weights'), + bias_attr=False) + out = int((input.shape[2] - 1) / float(stride) + 1) + bn_name = name + '_bn' + if if_act: + return fluid.layers.batch_norm( + input=conv, + act='relu', + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + else: + return fluid.layers.batch_norm( + input=conv, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + 
moving_variance_name=bn_name + '_variance') + + def channel_shuffle(self, x, groups): + batchsize, num_channels, height, width = x.shape[0], x.shape[ + 1], x.shape[2], x.shape[3] + channels_per_group = num_channels // groups + + # reshape + x = fluid.layers.reshape( + x=x, shape=[batchsize, groups, channels_per_group, height, width]) + + x = fluid.layers.transpose(x=x, perm=[0, 2, 1, 3, 4]) + + # flatten + x = fluid.layers.reshape( + x=x, shape=[batchsize, num_channels, height, width]) + + return x + + def inverted_residual_unit(self, + input, + num_filters, + stride, + benchmodel, + name=None): + assert stride in [1, 2], \ + "supported stride are {} but your stride is {}".format([1,2], stride) + + oup_inc = num_filters // 2 + inp = input.shape[1] + + if benchmodel == 1: + x1, x2 = fluid.layers.split( + input, + num_or_sections=[input.shape[1] // 2, input.shape[1] // 2], + dim=1) + + conv_pw = self.conv_bn_layer( + input=x2, + num_filters=oup_inc, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name='stage_' + name + '_conv1') + + conv_dw = self.conv_bn_layer( + input=conv_pw, + num_filters=oup_inc, + filter_size=3, + stride=stride, + padding=1, + num_groups=oup_inc, + if_act=False, + use_cudnn=False, + name='stage_' + name + '_conv2') + + conv_linear = self.conv_bn_layer( + input=conv_dw, + num_filters=oup_inc, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name='stage_' + name + '_conv3') + + out = fluid.layers.concat([x1, conv_linear], axis=1) + + else: + #branch1 + conv_dw_1 = self.conv_bn_layer( + input=input, + num_filters=inp, + filter_size=3, + stride=stride, + padding=1, + num_groups=inp, + if_act=False, + use_cudnn=False, + name='stage_' + name + '_conv4') + + conv_linear_1 = self.conv_bn_layer( + input=conv_dw_1, + num_filters=oup_inc, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name='stage_' + name + '_conv5') + + #branch2 + conv_pw_2 = self.conv_bn_layer( + 
input=input, + num_filters=oup_inc, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name='stage_' + name + '_conv1') + + conv_dw_2 = self.conv_bn_layer( + input=conv_pw_2, + num_filters=oup_inc, + filter_size=3, + stride=stride, + padding=1, + num_groups=oup_inc, + if_act=False, + use_cudnn=False, + name='stage_' + name + '_conv2') + + conv_linear_2 = self.conv_bn_layer( + input=conv_dw_2, + num_filters=oup_inc, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name='stage_' + name + '_conv3') + out = fluid.layers.concat([conv_linear_1, conv_linear_2], axis=1) + + return self.channel_shuffle(out, 2) + + +def ShuffleNetV2_x0_25(): + model = ShuffleNetV2(scale=0.25) + return model + + +def ShuffleNetV2_x0_33(): + model = ShuffleNetV2(scale=0.33) + return model + + +def ShuffleNetV2_x0_5(): + model = ShuffleNetV2(scale=0.5) + return model + + +def ShuffleNetV2_x1_0(): + model = ShuffleNetV2(scale=1.0) + return model + + +def ShuffleNetV2_x1_5(): + model = ShuffleNetV2(scale=1.5) + return model + + +def ShuffleNetV2_x2_0(): + model = ShuffleNetV2(scale=2.0) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/shufflenet_v2_swish.py b/VisualFL/visualfl/algorithm/paddle_clas/models/shufflenet_v2_swish.py new file mode 100755 index 000000000..bbd0330a5 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/shufflenet_v2_swish.py @@ -0,0 +1,293 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + 'ShuffleNetV2_x0_5_swish', 'ShuffleNetV2_x1_0_swish', + 'ShuffleNetV2_x1_5_swish', 'ShuffleNetV2_x2_0_swish', 'ShuffleNetV2_swish' +] + + +class ShuffleNetV2_swish(): + def __init__(self, scale=1.0): + self.scale = scale + + def net(self, input, class_dim=1000): + scale = self.scale + stage_repeats = [4, 8, 4] + + if scale == 0.5: + stage_out_channels = [-1, 24, 48, 96, 192, 1024] + elif scale == 1.0: + stage_out_channels = [-1, 24, 116, 232, 464, 1024] + elif scale == 1.5: + stage_out_channels = [-1, 24, 176, 352, 704, 1024] + elif scale == 2.0: + stage_out_channels = [-1, 24, 224, 488, 976, 2048] + else: + raise ValueError("""{} groups is not supported for + 1x1 Grouped Convolutions""".format(num_groups)) + + #conv1 + + input_channel = stage_out_channels[1] + conv1 = self.conv_bn_layer( + input=input, + filter_size=3, + num_filters=input_channel, + padding=1, + stride=2, + name='stage1_conv') + pool1 = fluid.layers.pool2d( + input=conv1, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + conv = pool1 + # bottleneck sequences + for idxstage in range(len(stage_repeats)): + numrepeat = stage_repeats[idxstage] + output_channel = stage_out_channels[idxstage + 2] + for i in range(numrepeat): + if i == 0: + conv = self.inverted_residual_unit( + input=conv, + num_filters=output_channel, + stride=2, + benchmodel=2, + name=str(idxstage + 2) + '_' + str(i + 1)) + else: + conv = self.inverted_residual_unit( + input=conv, + num_filters=output_channel, + stride=1, + benchmodel=1, + name=str(idxstage + 2) + '_' + str(i + 1)) + + conv_last = self.conv_bn_layer( + input=conv, + filter_size=1, + 
num_filters=stage_out_channels[-1], + padding=0, + stride=1, + name='conv5') + pool_last = fluid.layers.pool2d( + input=conv_last, + pool_size=7, + pool_stride=1, + pool_padding=0, + pool_type='avg') + + output = fluid.layers.fc(input=pool_last, + size=class_dim, + param_attr=ParamAttr( + initializer=MSRA(), name='fc6_weights'), + bias_attr=ParamAttr(name='fc6_offset')) + return output + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + num_groups=1, + use_cudnn=True, + if_act=True, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr( + initializer=MSRA(), name=name + '_weights'), + bias_attr=False) + out = int((input.shape[2] - 1) / float(stride) + 1) + bn_name = name + '_bn' + if if_act: + return fluid.layers.batch_norm( + input=conv, + act='swish', + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + else: + return fluid.layers.batch_norm( + input=conv, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def channel_shuffle(self, x, groups): + batchsize, num_channels, height, width = x.shape[0], x.shape[ + 1], x.shape[2], x.shape[3] + channels_per_group = num_channels // groups + + # reshape + x = fluid.layers.reshape( + x=x, shape=[batchsize, groups, channels_per_group, height, width]) + + x = fluid.layers.transpose(x=x, perm=[0, 2, 1, 3, 4]) + + # flatten + x = fluid.layers.reshape( + x=x, shape=[batchsize, num_channels, height, width]) + + return x + + def inverted_residual_unit(self, + input, + num_filters, + stride, + benchmodel, + name=None): + assert stride in [1, 2], \ + "supported 
stride are {} but your stride is {}".format([1,2], stride) + + oup_inc = num_filters // 2 + inp = input.shape[1] + + if benchmodel == 1: + x1, x2 = fluid.layers.split( + input, + num_or_sections=[input.shape[1] // 2, input.shape[1] // 2], + dim=1) + + conv_pw = self.conv_bn_layer( + input=x2, + num_filters=oup_inc, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name='stage_' + name + '_conv1') + + conv_dw = self.conv_bn_layer( + input=conv_pw, + num_filters=oup_inc, + filter_size=3, + stride=stride, + padding=1, + num_groups=oup_inc, + if_act=False, + use_cudnn=False, + name='stage_' + name + '_conv2') + + conv_linear = self.conv_bn_layer( + input=conv_dw, + num_filters=oup_inc, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name='stage_' + name + '_conv3') + + out = fluid.layers.concat([x1, conv_linear], axis=1) + + else: + #branch1 + conv_dw_1 = self.conv_bn_layer( + input=input, + num_filters=inp, + filter_size=3, + stride=stride, + padding=1, + num_groups=inp, + if_act=False, + use_cudnn=False, + name='stage_' + name + '_conv4') + + conv_linear_1 = self.conv_bn_layer( + input=conv_dw_1, + num_filters=oup_inc, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name='stage_' + name + '_conv5') + + #branch2 + conv_pw_2 = self.conv_bn_layer( + input=input, + num_filters=oup_inc, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name='stage_' + name + '_conv1') + + conv_dw_2 = self.conv_bn_layer( + input=conv_pw_2, + num_filters=oup_inc, + filter_size=3, + stride=stride, + padding=1, + num_groups=oup_inc, + if_act=False, + use_cudnn=False, + name='stage_' + name + '_conv2') + + conv_linear_2 = self.conv_bn_layer( + input=conv_dw_2, + num_filters=oup_inc, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name='stage_' + name + '_conv3') + out = fluid.layers.concat([conv_linear_1, conv_linear_2], axis=1) + + return self.channel_shuffle(out, 
2) + + +def ShuffleNetV2_x0_5_swish(): + model = ShuffleNetV2_swish(scale=0.5) + return model + + +def ShuffleNetV2_x1_0_swish(): + model = ShuffleNetV2_swish(scale=1.0) + return model + + +def ShuffleNetV2_x1_5_swish(): + model = ShuffleNetV2_swish(scale=1.5) + return model + + +def ShuffleNetV2_x2_0_swish(): + model = ShuffleNetV2_swish(scale=2.0) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/squeezenet.py b/VisualFL/visualfl/algorithm/paddle_clas/models/squeezenet.py new file mode 100755 index 000000000..a6dc5b3e4 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/squeezenet.py @@ -0,0 +1,132 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import math +from paddle.fluid.param_attr import ParamAttr + +__all__ = ["SqueezeNet", "SqueezeNet1_0", "SqueezeNet1_1"] + + +class SqueezeNet(): + def __init__(self, version='1.0'): + self.version = version + + def net(self, input, class_dim=1000): + version = self.version + assert version in ['1.0', '1.1'], \ + "supported version are {} but input version is {}".format(['1.0', '1.1'], version) + if version == '1.0': + conv = fluid.layers.conv2d( + input, + num_filters=96, + filter_size=7, + stride=2, + act='relu', + param_attr=fluid.param_attr.ParamAttr(name="conv1_weights"), + bias_attr=ParamAttr(name='conv1_offset')) + conv = fluid.layers.pool2d( + conv, pool_size=3, pool_stride=2, pool_type='max') + conv = self.make_fire(conv, 16, 64, 64, name='fire2') + conv = self.make_fire(conv, 16, 64, 64, name='fire3') + conv = self.make_fire(conv, 32, 128, 128, name='fire4') + conv = fluid.layers.pool2d( + conv, pool_size=3, pool_stride=2, pool_type='max') + conv = self.make_fire(conv, 32, 128, 128, name='fire5') + conv = self.make_fire(conv, 48, 192, 192, name='fire6') + conv = self.make_fire(conv, 48, 192, 192, name='fire7') + conv = self.make_fire(conv, 64, 256, 256, name='fire8') + conv = fluid.layers.pool2d( + conv, pool_size=3, pool_stride=2, pool_type='max') + conv = self.make_fire(conv, 64, 256, 256, name='fire9') + else: + conv = fluid.layers.conv2d( + input, + num_filters=64, + filter_size=3, + stride=2, + padding=1, + act='relu', + param_attr=fluid.param_attr.ParamAttr(name="conv1_weights"), + bias_attr=ParamAttr(name='conv1_offset')) + conv = fluid.layers.pool2d( + conv, pool_size=3, pool_stride=2, pool_type='max') + conv = self.make_fire(conv, 16, 64, 64, name='fire2') + conv = self.make_fire(conv, 16, 64, 64, name='fire3') + conv = fluid.layers.pool2d( + conv, pool_size=3, pool_stride=2, 
pool_type='max') + conv = self.make_fire(conv, 32, 128, 128, name='fire4') + conv = self.make_fire(conv, 32, 128, 128, name='fire5') + conv = fluid.layers.pool2d( + conv, pool_size=3, pool_stride=2, pool_type='max') + conv = self.make_fire(conv, 48, 192, 192, name='fire6') + conv = self.make_fire(conv, 48, 192, 192, name='fire7') + conv = self.make_fire(conv, 64, 256, 256, name='fire8') + conv = self.make_fire(conv, 64, 256, 256, name='fire9') + conv = fluid.layers.dropout(conv, dropout_prob=0.5) + conv = fluid.layers.conv2d( + conv, + num_filters=class_dim, + filter_size=1, + act='relu', + param_attr=fluid.param_attr.ParamAttr(name="conv10_weights"), + bias_attr=ParamAttr(name='conv10_offset')) + conv = fluid.layers.pool2d(conv, pool_type='avg', global_pooling=True) + out = fluid.layers.flatten(conv) + return out + + def make_fire_conv(self, + input, + num_filters, + filter_size, + padding=0, + name=None): + conv = fluid.layers.conv2d( + input, + num_filters=num_filters, + filter_size=filter_size, + padding=padding, + act='relu', + param_attr=fluid.param_attr.ParamAttr(name=name + "_weights"), + bias_attr=ParamAttr(name=name + '_offset')) + return conv + + def make_fire(self, + input, + squeeze_channels, + expand1x1_channels, + expand3x3_channels, + name=None): + conv = self.make_fire_conv( + input, squeeze_channels, 1, name=name + '_squeeze1x1') + conv_path1 = self.make_fire_conv( + conv, expand1x1_channels, 1, name=name + '_expand1x1') + conv_path2 = self.make_fire_conv( + conv, expand3x3_channels, 3, 1, name=name + '_expand3x3') + out = fluid.layers.concat([conv_path1, conv_path2], axis=1) + return out + + +def SqueezeNet1_0(): + model = SqueezeNet(version='1.0') + return model + + +def SqueezeNet1_1(): + model = SqueezeNet(version='1.1') + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/vgg.py b/VisualFL/visualfl/algorithm/paddle_clas/models/vgg.py new file mode 100755 index 000000000..d58efd7f4 --- /dev/null +++ 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.fluid as fluid

__all__ = ["VGGNet", "VGG11", "VGG13", "VGG16", "VGG19"]


class VGGNet():
    """VGG classification network (11/13/16/19-layer variants)."""

    def __init__(self, layers=16):
        # Depth selector; validated in net() against vgg_spec.
        self.layers = layers

    def net(self, input, class_dim=1000):
        """Build the VGG forward graph and return raw (un-activated) logits.

        Args:
            input: input image variable.
            class_dim: number of output classes for the final FC layer.
        """
        layers = self.layers
        # Number of 3x3 conv layers in each of the five stages, per depth.
        vgg_spec = {
            11: ([1, 1, 2, 2, 2]),
            13: ([2, 2, 2, 2, 2]),
            16: ([2, 2, 3, 3, 3]),
            19: ([2, 2, 4, 4, 4])
        }
        assert layers in vgg_spec.keys(), \
            "supported layers are {} but input layer is {}".format(vgg_spec.keys(), layers)

        stage_widths = [64, 128, 256, 512, 512]
        feature = input
        for stage_idx, (width, repeats) in enumerate(
                zip(stage_widths, vgg_spec[layers])):
            feature = self.conv_block(
                feature, width, repeats, name="conv{}_".format(stage_idx + 1))

        fc_dim = 4096
        fc_name = ["fc6", "fc7", "fc8"]
        # Two hidden FC layers (relu + dropout), identical except for names.
        for hidden_idx in range(2):
            feature = fluid.layers.fc(
                input=feature,
                size=fc_dim,
                act='relu',
                param_attr=fluid.param_attr.ParamAttr(
                    name=fc_name[hidden_idx] + "_weights"),
                bias_attr=fluid.param_attr.ParamAttr(
                    name=fc_name[hidden_idx] + "_offset"))
            feature = fluid.layers.dropout(x=feature, dropout_prob=0.5)

        # Final classifier FC has no activation (logits).
        out = fluid.layers.fc(
            input=feature,
            size=class_dim,
            param_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_weights"),
            bias_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_offset"))

        return out

    def conv_block(self, input, num_filter, groups, name=None):
        """Stack `groups` 3x3 stride-1 conv+relu layers, then a 2x2/s2 max-pool."""
        feature = input
        for layer_idx in range(groups):
            feature = fluid.layers.conv2d(
                input=feature,
                num_filters=num_filter,
                filter_size=3,
                stride=1,
                padding=1,
                act='relu',
                param_attr=fluid.param_attr.ParamAttr(
                    name=name + str(layer_idx + 1) + "_weights"),
                bias_attr=False)
        return fluid.layers.pool2d(
            input=feature, pool_size=2, pool_type='max', pool_stride=2)


def VGG11():
    return VGGNet(layers=11)


def VGG13():
    return VGGNet(layers=13)


def VGG16():
    return VGGNet(layers=16)


def VGG19():
    return VGGNet(layers=19)
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import math +import sys +from paddle.fluid.param_attr import ParamAttr + +__all__ = ['Xception', 'Xception41', 'Xception65', 'Xception71'] + + +class Xception(object): + """Xception""" + + def __init__(self, entry_flow_block_num=3, middle_flow_block_num=8): + self.entry_flow_block_num = entry_flow_block_num + self.middle_flow_block_num = middle_flow_block_num + return + + def net(self, input, class_dim=1000): + conv = self.entry_flow(input, self.entry_flow_block_num) + conv = self.middle_flow(conv, self.middle_flow_block_num) + conv = self.exit_flow(conv, class_dim) + + return conv + + def entry_flow(self, input, block_num=3): + '''xception entry_flow''' + name = "entry_flow" + conv = self.conv_bn_layer( + input=input, + num_filters=32, + filter_size=3, + stride=2, + act='relu', + name=name + "_conv1") + conv = self.conv_bn_layer( + input=conv, + num_filters=64, + filter_size=3, + stride=1, + act='relu', + name=name + "_conv2") + + if block_num == 3: + relu_first = [False, True, True] + num_filters = [128, 256, 728] + stride = [2, 2, 2] + elif block_num == 5: + relu_first = [False, True, True, True, True] + num_filters = [128, 256, 256, 728, 728] + stride = [2, 1, 2, 1, 2] + else: + sys.exit(-1) + + for block in range(block_num): + curr_name = "{}_{}".format(name, block) + conv = self.entry_flow_bottleneck_block( + conv, + num_filters=num_filters[block], + name=curr_name, + stride=stride[block], + relu_first=relu_first[block]) + + return conv + + def entry_flow_bottleneck_block(self, + input, + num_filters, + name, + stride=2, + relu_first=False): + '''entry_flow_bottleneck_block''' + short = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=1, + stride=stride, + padding=0, + act=None, + param_attr=ParamAttr(name + "_branch1_weights"), + bias_attr=False) + + conv0 = input + if 
relu_first: + conv0 = fluid.layers.relu(conv0) + + conv1 = self.separable_conv( + conv0, num_filters, stride=1, name=name + "_branch2a_weights") + + conv2 = fluid.layers.relu(conv1) + conv2 = self.separable_conv( + conv2, num_filters, stride=1, name=name + "_branch2b_weights") + + pool = fluid.layers.pool2d( + input=conv2, + pool_size=3, + pool_stride=stride, + pool_padding=1, + pool_type='max') + + return fluid.layers.elementwise_add(x=short, y=pool) + + def middle_flow(self, input, block_num=8): + '''xception middle_flow''' + num_filters = 728 + conv = input + for block in range(block_num): + name = "middle_flow_{}".format(block) + conv = self.middle_flow_bottleneck_block(conv, num_filters, name) + + return conv + + def middle_flow_bottleneck_block(self, input, num_filters, name): + '''middle_flow_bottleneck_block''' + conv0 = fluid.layers.relu(input) + conv0 = self.separable_conv( + conv0, + num_filters=num_filters, + stride=1, + name=name + "_branch2a_weights") + + conv1 = fluid.layers.relu(conv0) + conv1 = self.separable_conv( + conv1, + num_filters=num_filters, + stride=1, + name=name + "_branch2b_weights") + + conv2 = fluid.layers.relu(conv1) + conv2 = self.separable_conv( + conv2, + num_filters=num_filters, + stride=1, + name=name + "_branch2c_weights") + + return fluid.layers.elementwise_add(x=input, y=conv2) + + def exit_flow(self, input, class_dim): + '''xception exit flow''' + name = "exit_flow" + num_filters1 = 728 + num_filters2 = 1024 + conv0 = self.exit_flow_bottleneck_block( + input, num_filters1, num_filters2, name=name + "_1") + + conv1 = self.separable_conv( + conv0, num_filters=1536, stride=1, name=name + "_2") + conv1 = fluid.layers.relu(conv1) + + conv2 = self.separable_conv( + conv1, num_filters=2048, stride=1, name=name + "_3") + conv2 = fluid.layers.relu(conv2) + + pool = fluid.layers.pool2d( + input=conv2, pool_type='avg', global_pooling=True) + + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc( + input=pool, + 
size=class_dim, + param_attr=fluid.param_attr.ParamAttr( + name='fc_weights', + initializer=fluid.initializer.Uniform(-stdv, stdv)), + bias_attr=fluid.param_attr.ParamAttr(name='fc_offset')) + + return out + + def exit_flow_bottleneck_block(self, input, num_filters1, num_filters2, + name): + '''entry_flow_bottleneck_block''' + short = fluid.layers.conv2d( + input=input, + num_filters=num_filters2, + filter_size=1, + stride=2, + padding=0, + act=None, + param_attr=ParamAttr(name + "_branch1_weights"), + bias_attr=False) + + conv0 = fluid.layers.relu(input) + conv1 = self.separable_conv( + conv0, num_filters1, stride=1, name=name + "_branch2a_weights") + + conv2 = fluid.layers.relu(conv1) + conv2 = self.separable_conv( + conv2, num_filters2, stride=1, name=name + "_branch2b_weights") + + pool = fluid.layers.pool2d( + input=conv2, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + + return fluid.layers.elementwise_add(x=short, y=pool) + + def separable_conv(self, input, num_filters, stride=1, name=None): + """separable_conv""" + pointwise_conv = self.conv_bn_layer( + input=input, + filter_size=1, + num_filters=num_filters, + stride=1, + name=name + "_sep") + + depthwise_conv = self.conv_bn_layer( + input=pointwise_conv, + filter_size=3, + num_filters=num_filters, + stride=stride, + groups=num_filters, + use_cudnn=False, + name=name + "_dw") + + return depthwise_conv + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + use_cudnn=True, + name=None): + """conv_bn_layer""" + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False, + use_cudnn=use_cudnn) + + bn_name = "bn_" + name + + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name 
+ '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + +def Xception41(): + model = Xception(entry_flow_block_num=3, middle_flow_block_num=8) + return model + + +def Xception65(): + model = Xception(entry_flow_block_num=3, middle_flow_block_num=16) + return model + + +def Xception71(): + model = Xception(entry_flow_block_num=5, middle_flow_block_num=16) + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/models/xception_deeplab.py b/VisualFL/visualfl/algorithm/paddle_clas/models/xception_deeplab.py new file mode 100755 index 000000000..a8138ed9f --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/models/xception_deeplab.py @@ -0,0 +1,272 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import contextlib +import paddle +import math +import paddle.fluid as fluid +from .model_libs import scope, name_scope +from .model_libs import bn, bn_relu, relu +from .model_libs import conv +from .model_libs import seperate_conv + +__all__ = ['Xception41_deeplab', 'Xception65_deeplab', 'Xception71_deeplab'] + +def check_data(data, number): + if type(data) == int: + return [data] * number + assert len(data) == number + return data + +def check_stride(s, os): + if s <= os: + return True + else: + return False + +def check_points(count, points): + if points is None: + return False + else: + if isinstance(points, list): + return (True if count in points else False) + else: + return (True if count == points else False) + +class Xception(): + def __init__(self, backbone="xception_65"): + self.bottleneck_params = self.gen_bottleneck_params(backbone) + self.backbone = backbone + + def gen_bottleneck_params(self, backbone='xception_65'): + if backbone == 'xception_65': + bottleneck_params = { + "entry_flow": (3, [2, 2, 2], [128, 256, 728]), + "middle_flow": (16, 1, 728), + "exit_flow": (2, [2, 1],[[728, 1024, 1024], [1536, 1536, 2048]]) + } + elif backbone == 'xception_41': + bottleneck_params = { + "entry_flow": (3, [2, 2, 2], [128, 256, 728]), + "middle_flow": (8, 1, 728), + "exit_flow": (2, [2, 1],[[728, 1024, 1024], [1536, 1536, 2048]]) + } + elif backbone == 'xception_71': + bottleneck_params = { + "entry_flow": (5, [2, 1, 2, 1, 2], [128, 256, 256, 728, 728]), + "middle_flow": (16, 1, 728), + "exit_flow": (2, [2, 1],[[728, 1024, 1024], [1536, 1536, 2048]]) + } + else: + raise Exception("xception backbont only support xception_41/xception_65/xception_71") + return bottleneck_params + + def net(self, + input, + output_stride=32, + class_dim=1000, + end_points=None, + decode_points=None): + self.stride = 2 + self.block_point = 0 + self.output_stride = 
output_stride + self.decode_points = decode_points + self.short_cuts = dict() + with scope(self.backbone): + # Entry flow + data = self.entry_flow(input) + if check_points(self.block_point, end_points): + return data, self.short_cuts + + # Middle flow + data = self.middle_flow(data) + if check_points(self.block_point, end_points): + return data, self.short_cuts + + # Exit flow + data = self.exit_flow(data) + if check_points(self.block_point, end_points): + return data, self.short_cuts + + data = fluid.layers.reduce_mean(data, [2, 3], keep_dim=True) + data = fluid.layers.dropout(data, 0.5) + stdv = 1.0 / math.sqrt(data.shape[1] * 1.0) + with scope("logit"): + out = fluid.layers.fc(input=data, size=class_dim, + param_attr=fluid.param_attr.ParamAttr(name='fc_weights', + initializer=fluid.initializer.Uniform(-stdv, stdv)), + bias_attr=fluid.param_attr.ParamAttr(name='fc_bias')) + + return out + + def entry_flow(self, data): + param_attr = fluid.ParamAttr( + name=name_scope + 'weights', + regularizer=None, + initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.09)) + with scope("entry_flow"): + with scope("conv1"): + data = bn_relu(conv(data, 32, 3, stride=2, padding=1, param_attr=param_attr)) + with scope("conv2"): + data = bn_relu(conv(data, 64, 3, stride=1, padding=1, param_attr=param_attr)) + + # get entry flow params + block_num = self.bottleneck_params["entry_flow"][0] + strides = self.bottleneck_params["entry_flow"][1] + chns = self.bottleneck_params["entry_flow"][2] + strides = check_data(strides, block_num) + chns = check_data(chns, block_num) + + # params to control your flow + s = self.stride + block_point = self.block_point + output_stride = self.output_stride + with scope("entry_flow"): + for i in range(block_num): + block_point = block_point + 1 + with scope("block" + str(i + 1)): + stride = strides[i] if check_stride(s*strides[i], output_stride) else 1 + data, short_cuts = self.xception_block(data, chns[i], [1, 1, stride]) + s = s * stride + if 
check_points(block_point, self.decode_points): + self.short_cuts[block_point] = short_cuts[1] + + self.stride = s + self.block_point = block_point + return data + + def middle_flow(self, data): + block_num = self.bottleneck_params["middle_flow"][0] + strides = self.bottleneck_params["middle_flow"][1] + chns = self.bottleneck_params["middle_flow"][2] + strides = check_data(strides, block_num) + chns = check_data(chns, block_num) + + # params to control your flow + s = self.stride + block_point = self.block_point + output_stride = self.output_stride + with scope("middle_flow"): + for i in range(block_num): + block_point = block_point + 1 + with scope("block" + str(i + 1)): + stride = strides[i] if check_stride(s*strides[i], output_stride) else 1 + data, short_cuts = self.xception_block(data, chns[i], [1, 1, strides[i]], skip_conv=False) + s = s * stride + if check_points(block_point, self.decode_points): + self.short_cuts[block_point] = short_cuts[1] + + self.stride = s + self.block_point = block_point + return data + + def exit_flow(self, data): + block_num = self.bottleneck_params["exit_flow"][0] + strides = self.bottleneck_params["exit_flow"][1] + chns = self.bottleneck_params["exit_flow"][2] + strides = check_data(strides, block_num) + chns = check_data(chns, block_num) + + assert(block_num==2) + # params to control your flow + s = self.stride + block_point = self.block_point + output_stride = self.output_stride + with scope("exit_flow"): + with scope('block1'): + block_point += 1 + stride = strides[0] if check_stride(s*strides[0], output_stride) else 1 + data, short_cuts = self.xception_block(data, chns[0], [1, 1, stride]) + s = s * stride + if check_points(block_point, self.decode_points): + self.short_cuts[block_point] = short_cuts[1] + with scope('block2'): + block_point += 1 + stride = strides[1] if check_stride(s*strides[1], output_stride) else 1 + data, short_cuts = self.xception_block( + data, chns[1], [1, 1, stride], + dilation=2, + has_skip=False, + 
activation_fn_in_separable_conv=True) + s = s * stride + if check_points(block_point, self.decode_points): + self.short_cuts[block_point] = short_cuts[1] + + self.stride = s + self.block_point = block_point + return data + + def xception_block(self, + input, + channels, + strides=1, + filters=3, + dilation=1, + skip_conv=True, + has_skip=True, + activation_fn_in_separable_conv=False): + repeat_number = 3 + channels = check_data(channels, repeat_number) + filters = check_data(filters, repeat_number) + strides = check_data(strides, repeat_number) + data = input + results = [] + for i in range(repeat_number): + with scope('separable_conv' + str(i + 1)): + if not activation_fn_in_separable_conv: + data = relu(data) + data = seperate_conv( + data, + channels[i], + strides[i], + filters[i], + dilation=dilation) + else: + data = seperate_conv( + data, + channels[i], + strides[i], + filters[i], + dilation=dilation, + act=relu) + results.append(data) + if not has_skip: + return data, results + if skip_conv: + param_attr = fluid.ParamAttr( + name=name_scope + 'weights', + regularizer=None, + initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.09)) + with scope('shortcut'): + skip = bn( + conv(input, channels[-1], 1, strides[-1], groups=1, + padding=0, param_attr=param_attr)) + else: + skip = input + return data + skip, results + +def Xception41_deeplab(): + model = Xception("xception_41") + return model + +def Xception65_deeplab(): + model = Xception("xception_65") + return model + +def Xception71_deeplab(): + model = Xception("xception_71") + return model diff --git a/VisualFL/visualfl/algorithm/paddle_clas/test/__init__.py b/VisualFL/visualfl/algorithm/paddle_clas/test/__init__.py new file mode 100644 index 000000000..3da16e031 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/test/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/VisualFL/visualfl/algorithm/paddle_clas/test/data_loader_test.py b/VisualFL/visualfl/algorithm/paddle_clas/test/data_loader_test.py new file mode 100644 index 000000000..d1e553d80 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_clas/test/data_loader_test.py @@ -0,0 +1,55 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + + +from __future__ import print_function + +from visualfl.utils import data_loader +import unittest + + +class TestDataLoader(unittest.TestCase): + def check_reader(self, reader): + sum = 0 + label = 0 + size = 224 * 224 * 3 + for l in reader(): + self.assertEqual(l[0].size, size) + if l[1] > label: + label = l[1] + sum += 1 + return sum, label + + def test_train(self): + instances, max_label_value = self.check_reader( + data_loader.train()) + self.assertEqual(instances, 7169) + self.assertEqual(max_label_value, 101) + + def test_test(self): + instances, max_label_value = self.check_reader( + data_loader.test()) + self.assertEqual(instances, 1020) + self.assertEqual(max_label_value, 101) + + def test_valid(self): + instances, max_label_value = self.check_reader( + data_loader.valid()) + self.assertEqual(instances, 1020) + self.assertEqual(max_label_value, 101) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/VisualFL/visualfl/algorithm/paddle_detection/README.md b/VisualFL/visualfl/algorithm/paddle_detection/README.md new file mode 100755 index 000000000..3404763bb --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/README.md @@ -0,0 +1,12 @@ +## Note + +This folder implement detection algorithms based on [PaddleDetection project](https://github.com/PaddlePaddle/PaddleDetection). + +The original detection programs are compiled into single server program and multiple trainer program by paddle_fl. +Thus, all algorithms supported by [PaddleDetection project](https://github.com/PaddlePaddle/PaddleDetection) has their `federated counterpart`. + +While, there are plenty test should be task before we could claim it. 
class EmptyOptimizer(object):
    """No-op optimizer.

    Stands in where an optimizer object is required but no parameter
    update should happen (e.g. on non-training roles), so callers can
    invoke minimize() unconditionally.
    """

    def minimize(self, loss):
        """Accept a loss and do nothing; returns None."""
        pass
import os
from visualfl import __basedir__
from ppdet.core.workspace import create, load_config


def merger_algorithm_config(algorithm_config, data_name=None):
    """Load the default PaddleDetection config for an architecture and
    overlay the user-supplied training options.

    Args:
        algorithm_config: dict with keys "architecture", "max_iter",
            "inner_step", "num_classes", "base_lr", "batch_size".
        data_name: dataset directory applied to both train and eval readers.

    Returns:
        The merged ppdet config object.
    """
    architecture = algorithm_config["architecture"]
    # Default configs live under configs/<family>/<architecture>.yml,
    # where <family> is the prefix before the first underscore.
    config_dir = os.path.join(__basedir__, 'algorithm', 'paddle_detection',
                              'configs', architecture.split('_')[0])
    cfg = load_config(os.path.join(config_dir, f'{architecture}.yml'))

    cfg["max_iters"] = algorithm_config["max_iter"]
    cfg["inner_step"] = algorithm_config["inner_step"]

    # YOLO-family heads have no background class; other architectures
    # reserve one extra class slot for background.
    num_class = algorithm_config["num_classes"]
    if architecture.startswith(("yolo", "ppyolo")):
        cfg["num_classes"] = num_class
    else:
        cfg["num_classes"] = num_class + 1

    cfg["LearningRate"].base_lr = algorithm_config["base_lr"]
    # cfg.TrainReader["inputs_def"]["image_shape"] = algorithm_config["image_shape"]
    cfg.TrainReader["batch_size"] = algorithm_config["batch_size"]
    cfg.TrainReader["dataset"].dataset_dir = data_name
    cfg.EvalReader["dataset"].dataset_dir = data_name

    return cfg
diff --git a/VisualFL/visualfl/algorithm/paddle_detection/configs/ppyolo/ppyolo_r18vd.yml b/VisualFL/visualfl/algorithm/paddle_detection/configs/ppyolo/ppyolo_r18vd.yml new file mode 100755 index 000000000..c3ab5a8ff --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/configs/ppyolo/ppyolo_r18vd.yml @@ -0,0 +1,127 @@ +architecture: YOLOv3 +use_gpu: false +max_iters: 250000 +log_iter: 20 +save_dir: output +snapshot_iter: 10000 +metric: VOC +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_vd_pretrained.tar +weights: output/ppyolo_tiny/model_final +num_classes: 4 +use_fine_grained_loss: false +use_ema: true + + +YOLOv3: + backbone: ResNet + yolo_head: YOLOv3Head + use_fine_grained_loss: false + +ResNet: + norm_type: sync_bn + freeze_at: 0 + freeze_norm: false + norm_decay: 0. + depth: 18 + feature_maps: [4, 5] + variant: d + +YOLOv3Head: + anchor_masks: [[3, 4, 5], [0, 1, 2]] + anchors: [[10, 14], [23, 27], [37, 58], + [81, 82], [135, 169], [344, 319]] + norm_decay: 0. + conv_block_num: 0 + scale_x_y: 1.05 + yolo_loss: YOLOv3Loss + nms: MatrixNMS + +YOLOv3Loss: + ignore_thresh: 0.7 + scale_x_y: 1.05 + label_smooth: false + use_fine_grained_loss: false + iou_loss: IouLoss + +IouLoss: + loss_weight: 2.5 + max_height: 608 + max_width: 608 + +MatrixNMS: + background_label: -1 + keep_top_k: 100 + normalized: false + score_threshold: 0.01 + post_threshold: 0.01 + +LearningRate: + base_lr: 0.004 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 150000 + - 200000 + - !LinearWarmup + start_factor: 0. 
+ steps: 4000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +_READER_: 'ppyolo_reader.yml' +TrainReader: + inputs_def: + fields: ['image', 'gt_bbox', 'gt_class', 'gt_score'] + num_max_boxes: 50 + dataset: + !VOCDataSet + anno_path: train.txt + dataset_dir: dataset/fruit + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + with_mixup: True + - !MixupImage + alpha: 1.5 + beta: 1.5 + - !ColorDistort {} + - !RandomExpand + fill_value: [123.675, 116.28, 103.53] + - !RandomCrop {} + - !RandomFlipImage + is_normalized: false + - !NormalizeBox {} + - !PadBox + num_max_boxes: 50 + - !BboxXYXY2XYWH {} + batch_transforms: + - !RandomShape + sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] + random_inter: True + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + # Gt2YoloTarget is only used when use_fine_grained_loss set as true, + # this operator will be deleted automatically if use_fine_grained_loss + # is set as false + - !Gt2YoloTarget + anchor_masks: [[3, 4, 5], [0, 1, 2]] + anchors: [[10, 14], [23, 27], [37, 58], + [81, 82], [135, 169], [344, 319]] + downsample_ratios: [32, 16] + batch_size: 32 + + diff --git a/VisualFL/visualfl/algorithm/paddle_detection/configs/ppyolo/ppyolo_reader.yml b/VisualFL/visualfl/algorithm/paddle_detection/configs/ppyolo/ppyolo_reader.yml new file mode 100644 index 000000000..d36f8aa80 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/configs/ppyolo/ppyolo_reader.yml @@ -0,0 +1,102 @@ +TrainReader: + inputs_def: + fields: ['image', 'gt_bbox', 'gt_class', 'gt_score'] + num_max_boxes: 50 + dataset: + !VOCDataSet + anno_path: train.txt + dataset_dir: dataset/fruit + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + with_mixup: True + - !MixupImage + alpha: 1.5 + beta: 1.5 + - 
!ColorDistort {} + - !RandomExpand + ratio: 2.0 + fill_value: [123.675, 116.28, 103.53] + - !RandomCrop {} + - !RandomFlipImage + is_normalized: false + - !NormalizeBox {} + - !PadBox + num_max_boxes: 50 + - !BboxXYXY2XYWH {} + batch_transforms: + - !RandomShape + sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] + random_inter: True + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + # Gt2YoloTarget is only used when use_fine_grained_loss set as true, + # this operator will be deleted automatically if use_fine_grained_loss + # is set as false + - !Gt2YoloTarget + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + anchors: [[10, 13], [16, 30], [33, 23], + [30, 61], [62, 45], [59, 119], + [116, 90], [156, 198], [373, 326]] + downsample_ratios: [32, 16, 8] + batch_size: 24 + + +EvalReader: + inputs_def: + fields: ['image', 'im_size', 'im_id'] + num_max_boxes: 50 + dataset: + !VOCDataSet + anno_path: val.txt + dataset_dir: dataset/fruit + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 608 + interp: 2 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !PadBox + num_max_boxes: 50 + - !Permute + to_bgr: false + channel_first: True + batch_size: 8 + + +TestReader: + inputs_def: + image_shape: [3, 608, 608] + fields: ['image', 'im_size', 'im_id'] + dataset: + !ImageFolder + anno_path: test.txt + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 608 + interp: 2 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + batch_size: 1 diff --git a/VisualFL/visualfl/algorithm/paddle_detection/configs/ssd/ssd_vgg16_300_voc.yml 
b/VisualFL/visualfl/algorithm/paddle_detection/configs/ssd/ssd_vgg16_300_voc.yml new file mode 100644 index 000000000..93fd89e40 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/configs/ssd/ssd_vgg16_300_voc.yml @@ -0,0 +1,145 @@ +architecture: SSD +use_gpu: false +max_iters: 120001 +snapshot_iter: 10000 +log_iter: 20 +metric: VOC +map_type: 11point +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_caffe_pretrained.tar +save_dir: output +weights: output/ssd_vgg16_300_voc/model_final +# 20(label_class) + 1(background) +num_classes: 4 + +SSD: + backbone: VGG + multi_box_head: MultiBoxHead + output_decoder: + background_label: 0 + keep_top_k: 200 + nms_eta: 1.0 + nms_threshold: 0.45 + nms_top_k: 400 + score_threshold: 0.01 + +VGG: + depth: 16 + with_extra_blocks: true + normalizations: [20., -1, -1, -1, -1, -1] + +MultiBoxHead: + base_size: 300 + aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]] + min_ratio: 20 + max_ratio: 90 + min_sizes: [30.0, 60.0, 111.0, 162.0, 213.0, 264.0] + max_sizes: [60.0, 111.0, 162.0, 213.0, 264.0, 315.0] + steps: [8, 16, 32, 64, 100, 300] + offset: 0.5 + flip: true + min_max_aspect_ratios_order: true + kernel_size: 3 + pad: 1 + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [80000, 100000] + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +TrainReader: + inputs_def: + image_shape: [3, 300, 300] + fields: ['image', 'gt_bbox', 'gt_class'] + dataset: + !VOCDataSet + dataset_dir: dataset/fruit + anno_path: train.txt + use_default_label: false + sample_transforms: + - !DecodeImage + to_rgb: true + - !RandomDistort + brightness_lower: 0.875 + brightness_upper: 1.125 + is_order: true + - !RandomExpand + fill_value: [104, 117, 123] + - !RandomCrop + allow_no_crop: true + - !NormalizeBox {} + - !ResizeImage + interp: 1 + target_size: 300 + use_cv2: false + - !RandomFlipImage + 
is_normalized: true + - !Permute + to_bgr: false + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [1, 1, 1] + batch_size: 8 + shuffle: true + +EvalReader: + inputs_def: + image_shape: [3, 300, 300] + fields: ['image', 'gt_bbox', 'gt_class', 'im_shape', 'im_id', 'is_difficult'] + dataset: + !VOCDataSet + anno_path: val.txt + dataset_dir: dataset/fruit + use_default_label: false + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !ResizeImage + interp: 1 + target_size: 300 + use_cv2: false + - !Permute + to_bgr: false + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [1, 1, 1] + batch_size: 32 + + +TestReader: + inputs_def: + image_shape: [3,300,300] + fields: ['image', 'im_id', 'im_shape'] + dataset: + !ImageFolder + anno_path: test.txt + use_default_label: false + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !ResizeImage + interp: 1 + max_size: 0 + target_size: 300 + use_cv2: true + - !Permute + to_bgr: false + - !NormalizeImage + is_scale: false + mean: [104, 117, 123] + std: [1, 1, 1] + batch_size: 1 diff --git a/VisualFL/visualfl/algorithm/paddle_detection/configs/ssd/ssd_vgg16_512_voc.yml b/VisualFL/visualfl/algorithm/paddle_detection/configs/ssd/ssd_vgg16_512_voc.yml new file mode 100644 index 000000000..a5735d908 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/configs/ssd/ssd_vgg16_512_voc.yml @@ -0,0 +1,148 @@ +architecture: SSD +use_gpu: false +max_iters: 120000 +snapshot_iter: 10000 +log_iter: 20 +metric: VOC +map_type: 11point +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_caffe_pretrained.tar +save_dir: output +weights: output/ssd_vgg16_512_voc/model_final +# 20(label_class) + 1(background) +num_classes: 4 + +SSD: + backbone: VGG + multi_box_head: MultiBoxHead + output_decoder: + background_label: 0 + keep_top_k: 200 + nms_eta: 1.0 + nms_threshold: 0.45 + nms_top_k: 400 + score_threshold: 0.01 + 
+VGG: + depth: 16 + with_extra_blocks: true + normalizations: [20., -1, -1, -1, -1, -1, -1] + extra_block_filters: [[256, 512, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 1, 1, 4]] + + +MultiBoxHead: + base_size: 512 + aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]] + min_ratio: 20 + max_ratio: 90 + min_sizes: [20.0, 51.0, 133.0, 215.0, 296.0, 378.0, 460.0] + max_sizes: [51.0, 133.0, 215.0, 296.0, 378.0, 460.0, 542.0] + steps: [8, 16, 32, 64, 128, 256, 512] + offset: 0.5 + flip: true + kernel_size: 3 + pad: 1 + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [80000, 100000] + - !LinearWarmup + start_factor: 0.3333333333333333 + steps: 500 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +TrainReader: + inputs_def: + image_shape: [3, 512, 512] + fields: ['image', 'gt_bbox', 'gt_class'] + dataset: + !VOCDataSet + dataset_dir: dataset/fruit + anno_path: train.txt + use_default_label: false + sample_transforms: + - !DecodeImage + to_rgb: true + - !RandomDistort + brightness_lower: 0.875 + brightness_upper: 1.125 + is_order: true + - !RandomExpand + fill_value: [123, 117, 104] + - !RandomCrop + allow_no_crop: true + - !NormalizeBox {} + - !ResizeImage + interp: 1 + target_size: 512 + use_cv2: false + - !RandomFlipImage + is_normalized: true + - !Permute + to_bgr: false + - !NormalizeImage + is_scale: false + mean: [123, 117, 104] + std: [1, 1, 1] + batch_size: 8 + shuffle: true + +EvalReader: + inputs_def: + image_shape: [3, 512, 512] + fields: ['image', 'gt_bbox', 'gt_class', 'im_shape', 'im_id', 'is_difficult'] + dataset: + !VOCDataSet + anno_path: val.txt + dataset_dir: dataset/fruit + use_default_label: false + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !ResizeImage + interp: 1 + target_size: 512 + use_cv2: false + - !Permute + to_bgr: false + 
- !NormalizeImage + is_scale: false + mean: [123, 117, 104] + std: [1, 1, 1] + batch_size: 32 + +TestReader: + inputs_def: + image_shape: [3,512,512] + fields: ['image', 'im_id', 'im_shape'] + dataset: + !ImageFolder + anno_path: test.txt + use_default_label: false + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !ResizeImage + interp: 1 + max_size: 0 + target_size: 512 + use_cv2: true + - !Permute + to_bgr: false + - !NormalizeImage + is_scale: false + mean: [123, 117, 104] + std: [1, 1, 1] + batch_size: 1 diff --git a/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_darknet_voc.yml b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_darknet_voc.yml new file mode 100644 index 000000000..b87089430 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_darknet_voc.yml @@ -0,0 +1,89 @@ +architecture: YOLOv3 +use_gpu: false +max_iters: 70000 +log_iter: 20 +save_dir: output +snapshot_iter: 2000 +metric: VOC +map_type: 11point +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar +weights: output/yolov3_darknet_voc/model_final +num_classes: 4 +use_fine_grained_loss: false + +YOLOv3: + backbone: DarkNet + yolo_head: YOLOv3Head + +DarkNet: + norm_type: sync_bn + norm_decay: 0. + depth: 53 + +YOLOv3Head: + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + anchors: [[10, 13], [16, 30], [33, 23], + [30, 61], [62, 45], [59, 119], + [116, 90], [156, 198], [373, 326]] + norm_decay: 0. + yolo_loss: YOLOv3Loss + nms: + background_label: -1 + keep_top_k: 100 + nms_threshold: 0.45 + nms_top_k: 1000 + normalized: false + score_threshold: 0.01 + +YOLOv3Loss: + ignore_thresh: 0.7 + label_smooth: false + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 55000 + - 62000 + - !LinearWarmup + start_factor: 0. 
+ steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +_READER_: 'yolov3_reader.yml' +TrainReader: + inputs_def: + fields: ['image', 'gt_bbox', 'gt_class', 'gt_score'] + num_max_boxes: 50 + dataset: + !VOCDataSet + dataset_dir: dataset/fruit + anno_path: train.txt + use_default_label: false + with_background: false + +EvalReader: + inputs_def: + fields: ['image', 'im_size', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult'] + num_max_boxes: 50 + dataset: + !VOCDataSet + dataset_dir: dataset/fruit + anno_path: val.txt + use_default_label: false + with_background: false + +TestReader: + dataset: + !ImageFolder + use_default_label: false + with_background: false diff --git a/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_mobilenet_v1_voc.yml b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_mobilenet_v1_voc.yml new file mode 100644 index 000000000..84c552cf0 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_mobilenet_v1_voc.yml @@ -0,0 +1,87 @@ +architecture: YOLOv3 +use_gpu: false +max_iters: 70000 +log_iter: 20 +save_dir: output +snapshot_iter: 2000 +metric: VOC +map_type: 11point +pretrain_weights: http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar +weights: output/yolov3_mobilenet_v1_voc/model_final +num_classes: 4 +use_fine_grained_loss: false + +YOLOv3: + backbone: MobileNet + yolo_head: YOLOv3Head + +MobileNet: + norm_type: sync_bn + norm_decay: 0. + conv_group_scale: 1 + with_extra_blocks: false + +YOLOv3Head: + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + anchors: [[10, 13], [16, 30], [33, 23], + [30, 61], [62, 45], [59, 119], + [116, 90], [156, 198], [373, 326]] + norm_decay: 0. 
+ yolo_loss: YOLOv3Loss + nms: + background_label: -1 + keep_top_k: 100 + nms_threshold: 0.45 + nms_top_k: 1000 + normalized: false + score_threshold: 0.01 + +YOLOv3Loss: + ignore_thresh: 0.7 + label_smooth: false + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 55000 + - 62000 + - !LinearWarmup + start_factor: 0. + steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +_READER_: 'yolov3_reader.yml' +TrainReader: + dataset: + !VOCDataSet + dataset_dir: dataset/fruit + anno_path: train.txt + use_default_label: false + with_background: false + +EvalReader: + inputs_def: + fields: ['image', 'im_size', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult'] + num_max_boxes: 50 + dataset: + !VOCDataSet + dataset_dir: dataset/fruit + anno_path: val.txt + use_default_label: false + with_background: false + +TestReader: + dataset: + !ImageFolder + use_default_label: false + with_background: false diff --git a/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_r34_voc.yml b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_r34_voc.yml new file mode 100644 index 000000000..aa0b3d166 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_r34_voc.yml @@ -0,0 +1,89 @@ +architecture: YOLOv3 +use_gpu: false +max_iters: 70000 +log_iter: 20 +save_dir: output +snapshot_iter: 2000 +metric: VOC +map_type: 11point +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar +weights: output/yolov3_r34_voc/model_final +num_classes: 4 +use_fine_grained_loss: false + +YOLOv3: + backbone: ResNet + yolo_head: YOLOv3Head + +ResNet: + norm_type: sync_bn + freeze_at: 0 + freeze_norm: false + norm_decay: 0. 
+ depth: 34 + feature_maps: [3, 4, 5] + +YOLOv3Head: + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + anchors: [[10, 13], [16, 30], [33, 23], + [30, 61], [62, 45], [59, 119], + [116, 90], [156, 198], [373, 326]] + norm_decay: 0. + yolo_loss: YOLOv3Loss + nms: + background_label: -1 + keep_top_k: 100 + nms_threshold: 0.45 + nms_top_k: 1000 + normalized: false + score_threshold: 0.01 + +YOLOv3Loss: + ignore_thresh: 0.7 + label_smooth: false + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 55000 + - 62000 + - !LinearWarmup + start_factor: 0. + steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +_READER_: 'yolov3_reader.yml' +TrainReader: + dataset: + !VOCDataSet + dataset_dir: dataset/fruit + anno_path: train.txt + use_default_label: false + with_background: false + +EvalReader: + inputs_def: + fields: ['image', 'im_size', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult'] + num_max_boxes: 50 + dataset: + !VOCDataSet + dataset_dir: dataset/fruit + anno_path: val.txt + use_default_label: false + with_background: false + +TestReader: + dataset: + !ImageFolder + use_default_label: false + with_background: false diff --git a/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_reader.yml b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_reader.yml new file mode 100755 index 000000000..131660932 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov3/yolov3_reader.yml @@ -0,0 +1,80 @@ +TrainReader: + inputs_def: + image_shape: [3, 608, 608] + fields: ['image', 'gt_bbox', 'gt_class', 'gt_score'] + num_max_boxes: 50 + dataset: + !VOCDataSet + dataset_dir: fruit + anno_path: train.txt + with_background: false + use_default_label: false + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeBox {} + - !ExpandImage + max_ratio: 4.0 + mean: [123.675, 116.28, 103.53] + 
prob: 0.5 + - !RandomInterpImage + max_size: 0 + target_size: 608 + - !RandomFlipImage + is_normalized: true + prob: 0.5 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: true + is_channel_first: false + - !PadBox + num_max_boxes: 50 + - !BboxXYXY2XYWH {} + batch_transforms: + - !RandomShape + sizes: [608] + - !Permute + channel_first: true + to_bgr: false + batch_size: 1 + shuffle: true + mixup_epoch: -1 + +EvalReader: + batch_size: 1 + inputs_def: + image_shape: [3, 608, 608] + fields: ['image', 'im_size', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult'] + num_max_boxes: 50 + dataset: + !VOCDataSet + dataset_dir: fruit + anno_path: val.txt + use_default_label: false + with_background: false + +TestReader: + inputs_def: + image_shape: [3, 608, 608] + fields: ['image', 'im_size', 'im_id'] + dataset: + !ImageFolder + anno_path: test.txt + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 608 + interp: 2 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + batch_size: 1 diff --git a/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov4/README.md b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov4/README.md new file mode 100755 index 000000000..9394975af --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov4/README.md @@ -0,0 +1,62 @@ +# YOLO v4 模型 + +## 内容 +- [简介](#简介) +- [模型库与基线](#模型库与基线) +- [未来工作](#未来工作) +- [如何贡献代码](#如何贡献代码) + +## 简介 + +[YOLO v4](https://arxiv.org/abs/2004.10934)的Paddle实现版本,要求使用PaddlePaddle2.0.0及以上版本或适当的develop版本 + +目前转换了[darknet](https://github.com/AlexeyAB/darknet)中YOLO v4的权重,可以直接对图片进行预测,在[test-dev2019](http://cocodataset.org/#detection-2019)中精度为43.5%。另外,支持VOC数据集上finetune,精度达到85.5% + +目前支持YOLO v4的多个模块: + +- mish激活函数 +- PAN模块 +- SPP模块 +- ciou loss +- label_smooth +- grid_sensitive + 
+目前支持YOLO系列的Anchor聚类算法 +``` bash +python tools/anchor_cluster.py -c ${config} -m ${method} -s ${size} +``` +主要参数配置参考下表 +| 参数 | 用途 | 默认值 | 备注 | +|:------:|:------:|:------:|:------:| +| -c/--config | 模型的配置文件 | 无默认值 | 必须指定 | +| -n/--n | 聚类的簇数 | 9 | Anchor的数目 | +| -s/--size | 图片的输入尺寸 | None | 若指定,则使用指定的尺寸,如果不指定, 则尝试从配置文件中读取图片尺寸 | +| -m/--method | 使用的Anchor聚类方法 | v2 | 目前只支持yolov2/v5的聚类算法 | +| -i/--iters | kmeans聚类算法的迭代次数 | 1000 | kmeans算法收敛或者达到迭代次数后终止 | +| -gi/--gen_iters | 遗传算法的迭代次数 | 1000 | 该参数只用于yolov5的Anchor聚类算法 | +| -t/--thresh| Anchor尺度的阈值 | 0.25 | 该参数只用于yolov5的Anchor聚类算法 | + +## 模型库 +下表中展示了当前支持的网络结构。 + +| | GPU个数 | 测试集 | 骨干网络 | 精度 | 模型下载 | 配置文件 | +|:------------------------:|:-------:|:------:|:--------------------------:|:------------------------:| :---------:| :-----: | +| YOLO v4 | - |test-dev2019 | CSPDarkNet53 | 43.5 |[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/yolov4/yolov4_cspdarknet.yml) | +| YOLO v4 VOC | 2 | VOC2007 | CSPDarkNet53 | 85.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/yolov4/yolov4_cspdarknet_voc.yml) | + +**注意:** + +- 由于原版YOLO v4使用coco trainval2014进行训练,训练样本中包含部分评估样本,若使用val集会导致精度虚高,因此使用coco test集对模型进行评估。 +- YOLO v4模型仅支持coco test集评估和图片预测,由于test集不包含目标框的真实标注,评估时会将预测结果保存在json文件中,请将结果提交至[cocodataset](http://cocodataset.org/#detection-2019)上查看最终精度指标。 +- coco测试集使用test2017,下载请参考[coco2017](http://cocodataset.org/#download) + + +## 未来工作 + +1. mish激活函数优化 +2. 
mosaic数据预处理实现 + + + +## 如何贡献代码 +我们非常欢迎您可以为PaddleDetection提供代码,您可以提交PR供我们review;也十分感谢您的反馈,可以提交相应issue,我们会及时解答。 diff --git a/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov4/yolov4_cspdarknet_voc.yml b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov4/yolov4_cspdarknet_voc.yml new file mode 100644 index 000000000..c88008a70 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/configs/yolov4/yolov4_cspdarknet_voc.yml @@ -0,0 +1,165 @@ +architecture: YOLOv4 +use_gpu: false +max_iters: 140000 +log_iter: 20 +save_dir: output +snapshot_iter: 1000 +metric: VOC +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams +weights: output/yolov4_cspdarknet_voc/model_final +num_classes: 3 +use_fine_grained_loss: false + +YOLOv4: + backbone: CSPDarkNet + yolo_head: YOLOv4Head + +CSPDarkNet: + norm_type: sync_bn + norm_decay: 0. + depth: 53 + +YOLOv4Head: + anchors: [[12, 16], [19, 36], [40, 28], [36, 75], [76, 55], + [72, 146], [142, 110], [192, 243], [459, 401]] + anchor_masks: [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + nms: + background_label: -1 + keep_top_k: -1 + nms_threshold: 0.45 + nms_top_k: -1 + normalized: true + score_threshold: 0.001 + downsample: [8,16,32] + scale_x_y: [1.2, 1.1, 1.05] + +YOLOv3Loss: + ignore_thresh: 0.7 + label_smooth: true + downsample: [8,16,32] + scale_x_y: [1.2, 1.1, 1.05] + iou_loss: IouLoss + match_score: true + +IouLoss: + loss_weight: 0.07 + max_height: 608 + max_width: 608 + ciou_term: true + loss_square: true + +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 110000 + - 130000 + - !LinearWarmup + start_factor: 0. + steps: 1000 + +OptimizerBuilder: + clip_grad_by_norm: 10. 
+ optimizer: + momentum: 0.949 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +_READER_: '../yolov3/yolov3_reader.yml' +TrainReader: + inputs_def: + fields: ['image', 'gt_bbox', 'gt_class', 'gt_score'] + num_max_boxes: 50 + dataset: + !VOCDataSet + anno_path: train.txt + dataset_dir: dataset/fruit + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ColorDistort {} + - !RandomExpand + fill_value: [123.675, 116.28, 103.53] + - !RandomCrop {} + - !RandomFlipImage + is_normalized: false + - !NormalizeBox {} + - !PadBox + num_max_boxes: 50 + - !BboxXYXY2XYWH {} + batch_transforms: + - !RandomShape + sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] + random_inter: True + - !NormalizeImage + mean: [0.,0.,0.] + std: [1.,1.,1.] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + # Gt2YoloTarget is only used when use_fine_grained_loss set as true, + # this operator will be deleted automatically if use_fine_grained_loss + # is set as false + - !Gt2YoloTarget + anchor_masks: [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + anchors: [[12, 16], [19, 36], [40, 28], + [36, 75], [76, 55], [72, 146], + [142, 110], [192, 243], [459, 401]] + downsample_ratios: [8, 16, 32] + batch_size: 4 + shuffle: true + +EvalReader: + inputs_def: + fields: ['image', 'im_size', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult'] + num_max_boxes: 90 + dataset: + !VOCDataSet + anno_path: val.txt + dataset_dir: dataset/fruit + use_default_label: false + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 608 + interp: 1 + - !NormalizeImage + mean: [0., 0., 0.] + std: [1., 1., 1.] 
+ is_scale: True + is_channel_first: false + - !PadBox + num_max_boxes: 90 + - !Permute + to_bgr: false + channel_first: True + batch_size: 4 + +TestReader: + dataset: + !ImageFolder + use_default_label: false + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 608 + interp: 1 + - !NormalizeImage + mean: [0., 0., 0.] + std: [1., 1., 1.] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True diff --git a/VisualFL/visualfl/algorithm/paddle_detection/export_serving_model.py b/VisualFL/visualfl/algorithm/paddle_detection/export_serving_model.py new file mode 100644 index 000000000..ca9de42b9 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/export_serving_model.py @@ -0,0 +1,133 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os, sys +# add python path of PadleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +import paddle +from paddle import fluid + +from ppdet.core.workspace import load_config, merge_config, create +from ppdet.utils.cli import ArgsParser +from ppdet.utils.check import check_config, check_version, enable_static_mode +import ppdet.utils.checkpoint as checkpoint +import yaml,json +from visualfl.db.task_dao import TaskDao +from visualfl.utils.consts import TaskResultType + +import logging +from ppdet.utils.export_utils import dump_infer_config, prune_feed_vars +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def save_serving_model(FLAGS, exe, feed_vars, test_fetches, infer_prog): + cfg_name = os.path.basename(FLAGS.config).split('.')[0] + save_dir = os.path.join(FLAGS.output_dir, cfg_name) + feed_var_names = [var.name for var in feed_vars.values()] + fetch_list = sorted(test_fetches.items(), key=lambda i: i[0]) + target_vars = [var[1] for var in fetch_list] + feed_var_names = prune_feed_vars(feed_var_names, target_vars, infer_prog) + serving_client = os.path.join(FLAGS.output_dir, 'serving_client') + serving_server = os.path.join(FLAGS.output_dir, 'serving_server') + logger.info( + "Export serving model to {}, client side: {}, server side: {}. 
input: {}, output: " + "{}...".format(FLAGS.output_dir, serving_client, serving_server, + feed_var_names, [str(var.name) for var in target_vars])) + feed_dict = {x: infer_prog.global_block().var(x) for x in feed_var_names} + fetch_dict = {x.name: x for x in target_vars} + import paddle_serving_client.io as serving_io + serving_client = os.path.join(save_dir, 'serving_client') + serving_server = os.path.join(save_dir, 'serving_server') + serving_io.save_model( + client_config_folder=serving_client, + server_model_folder=serving_server, + feed_var_dict=feed_dict, + fetch_var_dict=fetch_dict, + main_program=infer_prog) + + +def main(): + cfg = load_config(FLAGS.config) + merge_config(FLAGS.opt) + check_config(cfg) + check_version() + + main_arch = cfg.architecture + + dataset = cfg.TestReader['dataset'] + task_result = TaskDao(FLAGS.task_id).get_task_result(TaskResultType.LABEL) + if task_result: + dataset.anno_path = json.loads(task_result.result).get("label_path") + + # Use CPU for exporting inference model instead of GPU + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + model = create(main_arch) + + startup_prog = fluid.Program() + infer_prog = fluid.Program() + with fluid.program_guard(infer_prog, startup_prog): + with fluid.unique_name.guard(): + inputs_def = cfg['TestReader']['inputs_def'] + inputs_def['use_dataloader'] = False + feed_vars, _ = model.build_inputs(**inputs_def) + test_fetches = model.test(feed_vars) + infer_prog = infer_prog.clone(True) + + exe.run(startup_prog) + checkpoint.load_params(exe, infer_prog, cfg.weights) + + save_serving_model(FLAGS, exe, feed_vars, test_fetches, infer_prog) + dump_infer_config(FLAGS, cfg) + + +if __name__ == '__main__': + enable_static_mode() + parser = ArgsParser() + parser.add_argument( + "--task_id", + type=str, + default=None) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Directory for storing the output model files.") + + FLAGS = parser.parse_args() + main() diff 
--git a/VisualFL/visualfl/algorithm/paddle_detection/fl_master.py b/VisualFL/visualfl/algorithm/paddle_detection/fl_master.py new file mode 100644 index 000000000..f405387fd --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/fl_master.py @@ -0,0 +1,133 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import json +import logging +import click +from paddle import fluid +from visualfl.algorithm.paddle_detection._empty_optimizer import ( + EmptyOptimizer, +) +from paddle_fl.core.master.job_generator import JobGenerator +from paddle_fl.core.strategy.fl_strategy_base import ( + FedAvgStrategy, +) +from ppdet.core.workspace import load_config, create +from ppdet.utils.check import check_version, check_config +from visualfl.algorithm.paddle_detection._merge_config import merger_algorithm_config + + +class Model(object): + def __init__(self): + self.feeds = None + self.startup_program = None + self.loss = None + + def build_program(self, config): + + with open(config) as f: + algorithm_config_dict = json.load(f) + + cfg = merger_algorithm_config(algorithm_config_dict) + check_config(cfg) + check_version() + + lr_builder = create("LearningRate") + optimizer_builder = create("OptimizerBuilder") + + # build program + self.startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, self.startup_program): + with fluid.unique_name.guard(): + model = create(cfg.architecture) + + inputs_def = cfg["TrainReader"]["inputs_def"] + # can't compile with dataloader now. 
+ inputs_def["use_dataloader"] = False + feed_vars, _ = model.build_inputs(**inputs_def) + + train_fetches = model.train(feed_vars) + loss = train_fetches["loss"] + lr = lr_builder() + optimizer = optimizer_builder(lr) + optimizer.minimize(loss) + + self.loss = loss + self.feeds = feed_vars + + +@click.command() +@click.option("--ps-endpoint", type=str, required=True) +@click.option( + "-c", + "--config", + type=click.Path(file_okay=True, dir_okay=False, exists=True), + required=True, +) +@click.option( + "--algorithm-config", type=click.Path(exists=True, file_okay=True, dir_okay=False) +) +def fl_master(algorithm_config, ps_endpoint, config): + logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s-%(levelname)s: %(message)s" + ) + logger = logging.getLogger(__name__) # noqa: F841 + with open(config) as f: + config_json = json.load(f) + worker_num = config_json["worker_num"] + + model = Model() + model.build_program(algorithm_config) + + job_generator = JobGenerator() + job_generator.set_losses([model.loss]) + job_generator.set_optimizer(EmptyOptimizer()) # optimizer defined in Model + job_generator.set_startup_program(model.startup_program) + job_generator.set_infer_feed_and_target_names( + [name for name in model.feeds], [model.loss.name] + ) + job_generator.set_feeds(model.feeds.values()) + + strategy = FedAvgStrategy() + strategy.fed_avg = True + strategy.inner_step = 1 + + endpoints = [ps_endpoint] + output = "compile" + job_generator.generate_fl_job( + strategy, server_endpoints=endpoints, worker_num=worker_num, output=output + ) + + +if __name__ == "__main__": + fl_master() diff --git a/VisualFL/visualfl/algorithm/paddle_detection/fl_trainer.py b/VisualFL/visualfl/algorithm/paddle_detection/fl_trainer.py new file mode 100644 index 000000000..0c602baf3 --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/fl_trainer.py @@ -0,0 +1,260 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 The FedVision Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import traceback +import click +from visualfl import __logs_dir__ +from visualfl.paddle_fl.trainer._trainer import FedAvgTrainer +from visualfl import get_data_dir +from visualfl.db.task_dao import TaskDao +from visualdl import LogWriter,LogReader +from visualfl.utils.consts import TaskStatus,ComponentName,TaskResultType +from visualfl.utils.tools import * +from visualfl.algorithm.paddle_detection._merge_config import merger_algorithm_config + +@click.command() +@click.option("--job-id", type=str, required=True) +@click.option("--task-id", type=str, required=True) +@click.option("--scheduler-ep", type=str, required=True) +@click.option("--trainer-id", type=int, required=True) +@click.option("--trainer-ep", type=str, required=True) +@click.option( + "--main-program", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--startup-program", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--send-program", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--recv-program", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--feed-names", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--target-names", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--strategy", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--feeds", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--config", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--algorithm-config", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +def fl_trainer( + job_id: str, + task_id: str, + 
trainer_id: int, + trainer_ep: str, + scheduler_ep: str, + main_program, + startup_program, + send_program, + recv_program, + feed_names, + target_names, + strategy, + feeds, + config, + algorithm_config, +): + import numpy as np + import paddle.fluid as fluid + from visualfl.utils import data_loader + + from ppdet.data import create_reader + from ppdet.utils import checkpoint + from ppdet.utils.check import check_config, check_version + + logging.basicConfig( + filename="trainer.log", + filemode="w", + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%d-%M-%Y %H:%M:%S", + level=logging.DEBUG, + ) + + try: + with open(config) as f: + config_json = json.load(f) + max_iter = config_json["max_iter"] + device = config_json.get("device", "cpu") + use_vdl = config_json.get("use_vdl", False) + resume_checkpoint = config_json.get("resume", False) + save_model_dir = "model" + save_checkpoint_dir = "checkpoint" + + + logging.debug(f"training program begin") + trainer = FedAvgTrainer(scheduler_ep=scheduler_ep, trainer_ep=trainer_ep) + logging.debug(f"job program loading") + trainer.load_job( + main_program=main_program, + startup_program=startup_program, + send_program=send_program, + recv_program=recv_program, + feed_names=feed_names, + target_names=target_names, + strategy=strategy, + ) + logging.debug(f"job program loaded") + place = fluid.CPUPlace() if device != "cuda" else fluid.CUDAPlace(0) + + logging.debug(f"trainer starting with place {place}") + trainer.start(place) + logging.debug(f"trainer stared") + + with open(algorithm_config) as f: + algorithm_config_json = json.load(f) + + download_url = algorithm_config_json.get("download_url") + data_name = algorithm_config_json.get("data_name") + + data_dir = data_loader.job_download(download_url, job_id, get_data_dir()) + labelpath = os.path.join(data_dir, "label_list.txt") + TaskDao(task_id).save_task_result({"label_path":labelpath}, ComponentName.DETECTION, TaskResultType.LABEL) + cfg = 
merger_algorithm_config(algorithm_config_json,os.path.basename(data_dir)) + check_config(cfg) + check_version() + + logging.debug(f"loading data") + feed_list = trainer.load_feed_list(feeds) + feeder = fluid.DataFeeder(feed_list=feed_list, place=place) + logging.debug(f"data loader ready") + + epoch_id = -1 + vdl_loss_step = 0 + # vdl_mAP_step = 0 + TaskDao(task_id).init_task_progress(max_iter) + TaskDao(task_id).start_task() + if resume_checkpoint: + try: + epoch_id = TaskDao(task_id).get_task_progress() + # vdl_loss_step = checkpoint.global_step() + # epoch_id = round(vdl_loss_step / max_iter) + checkpoint.load_checkpoint(trainer.exe, trainer._main_program, f"checkpoint/{epoch_id}") + logging.debug(f"use_checkpoint epoch_id: {epoch_id}") + except Exception as e: + logging.error(f"task id {task_id} train error {e}") + + # redirect dataset path to VisualFL/data + cfg.TrainReader["dataset"].dataset_dir = os.path.join( + get_data_dir(), cfg.TrainReader["dataset"].dataset_dir + ) + + data_loader = create_reader( + cfg.TrainReader, max_iter, cfg, devices_num=1, num_trainers=1 + ) + logging.error(f"{cfg.TrainReader['dataset']}") + + if use_vdl: + vdl_writer = LogWriter("vdl_log") + + while epoch_id < max_iter: + epoch_id += 1 + if not trainer.scheduler_agent.join(epoch_id): + logging.debug(f"not join, waiting next round") + continue + + logging.debug(f"epoch {epoch_id} start train") + + for step_id, data in enumerate(data_loader()): + outs = trainer.run(feeder.feed(data), fetch=trainer._target_names) + if use_vdl: + stats = { + k: np.array(v).mean() for k, v in zip(trainer._target_names, outs) + } + for loss_name, loss_value in stats.items(): + vdl_writer.add_scalar(loss_name, loss_value, vdl_loss_step) + save_data_to_db(task_id, loss_name, loss_value,vdl_loss_step,ComponentName.DETECTION) + vdl_loss_step += 1 + logging.debug(f"step: {vdl_loss_step}, outs: {outs}") + + # save model + logging.debug(f"saving model at {epoch_id}-th epoch") + 
trainer.save_model(os.path.join(save_model_dir,str(epoch_id))) + + # info scheduler + trainer.scheduler_agent.finish() + checkpoint.save(trainer.exe, trainer._main_program, os.path.join(save_checkpoint_dir,str(epoch_id))) + TaskDao(task_id).add_task_progress(1) + + TaskDao(task_id).update_task_status(TaskStatus.SUCCESS) + TaskDao(task_id).finish_task_progress() + TaskDao(task_id).update_serving_model(type=TaskResultType.LOSS) + logging.debug(f"reach max iter, finish training") + except Exception as e: + logging.error(f"task id {task_id} train error: {e}") + TaskDao(task_id).update_task_status(TaskStatus.ERROR,str(e)) + raise Exception(f"train error as task id {task_id} ") + + +if __name__ == "__main__": + fl_trainer() diff --git a/VisualFL/visualfl/algorithm/paddle_detection/infer.py b/VisualFL/visualfl/algorithm/paddle_detection/infer.py new file mode 100644 index 000000000..26a9109ec --- /dev/null +++ b/VisualFL/visualfl/algorithm/paddle_detection/infer.py @@ -0,0 +1,322 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os, sys +# add python path of PadleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +import json +import glob +import numpy as np +import six +from PIL import Image, ImageOps + +from paddle import fluid + +from ppdet.core.workspace import load_config, merge_config, create + +from ppdet.utils.eval_utils import parse_fetches +from ppdet.utils.cli import ArgsParser +from ppdet.utils.check import check_gpu, check_version, check_config, enable_static_mode +from ppdet.utils.visualizer import visualize_results +import ppdet.utils.checkpoint as checkpoint +from ppdet.data.reader import create_reader +from visualfl.db.task_dao import TaskDao +from visualfl.utils.consts import TaskResultType,ComponentName + +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def get_save_image_name(output_dir, image_path): + """ + Get save image name from source image path. 
+ """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + image_name = os.path.split(image_path)[-1] + name, ext = os.path.splitext(image_name) + return os.path.join(output_dir, "{}".format(name)) + ext + + +def get_test_images(infer_dir, infer_img): + """ + Get image path list in TEST mode + """ + assert infer_img is not None or infer_dir is not None, \ + "--infer_img or --infer_dir should be set" + assert infer_img is None or os.path.isfile(infer_img), \ + "{} is not a file".format(infer_img) + assert infer_dir is None or os.path.isdir(infer_dir), \ + "{} is not a directory".format(infer_dir) + + # infer_img has a higher priority + if infer_img and os.path.isfile(infer_img): + return [infer_img] + + images = set() + infer_dir = os.path.abspath(infer_dir) + assert os.path.isdir(infer_dir), \ + "infer_dir {} is not a directory".format(infer_dir) + exts = ['jpg', 'jpeg', 'png', 'bmp'] + exts += [ext.upper() for ext in exts] + for ext in exts: + images.update(glob.glob('{}/*.{}'.format(infer_dir, ext))) + images = list(images) + + assert len(images) > 0, "no image found in {}".format(infer_dir) + logging.info("Found {} inference images in total.".format(len(images))) + + return images + + +def main(): + cfg = load_config(FLAGS.config) + + merge_config({'use_gpu':FLAGS.use_gpu,'weights':FLAGS.weights}) + check_config(cfg) + # check if set use_gpu=True in paddlepaddle cpu version + check_gpu(cfg.use_gpu) + # check if paddlepaddle version is satisfied + check_version() + + main_arch = cfg.architecture + + model = TaskDao(FLAGS.task_id).get_task_result(TaskResultType.INFER) + infer_result = {} + if model: + infer_result = json.loads(model.result) + infer_result.update({"status": "running"}) + TaskDao(task_id=FLAGS.task_id).save_task_result(infer_result, ComponentName.DETECTION, type=TaskResultType.INFER) + + dataset = cfg.TestReader['dataset'] + test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) + dataset.set_images(test_images) + task_result = 
TaskDao(FLAGS.task_id).get_task_result(TaskResultType.LABEL) + if task_result: + dataset.anno_path = json.loads(task_result.result).get("label_path") + + place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + model = create(main_arch) + + startup_prog = fluid.Program() + infer_prog = fluid.Program() + with fluid.program_guard(infer_prog, startup_prog): + with fluid.unique_name.guard(): + inputs_def = cfg['TestReader']['inputs_def'] + inputs_def['iterable'] = True + feed_vars, loader = model.build_inputs(**inputs_def) + test_fetches = model.test(feed_vars) + infer_prog = infer_prog.clone(True) + + reader = create_reader(cfg.TestReader, devices_num=1) + loader.set_sample_list_generator(reader, place) + + exe.run(startup_prog) + if cfg.weights: + checkpoint.load_params(exe, infer_prog, cfg.weights) + + # parse infer fetches + assert cfg.metric in ['COCO', 'VOC', 'OID', 'WIDERFACE'], \ + "unknown metric type {}".format(cfg.metric) + extra_keys = [] + if cfg['metric'] in ['COCO', 'OID']: + extra_keys = ['im_info', 'im_id', 'im_shape'] + if cfg['metric'] == 'VOC' or cfg['metric'] == 'WIDERFACE': + extra_keys = ['im_id', 'im_shape'] + keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys) + + # parse dataset category + if cfg.metric == 'COCO': + from ppdet.utils.coco_eval import bbox2out, mask2out, segm2out, get_category_info + if cfg.metric == 'OID': + from ppdet.utils.oid_eval import bbox2out, get_category_info + if cfg.metric == "VOC": + from ppdet.utils.voc_eval import bbox2out, get_category_info + if cfg.metric == "WIDERFACE": + from ppdet.utils.widerface_eval_utils import bbox2out, lmk2out, get_category_info + + anno_file = dataset.anno_path + with_background = dataset.with_background + use_default_label = dataset.use_default_label + + clsid2catid, catid2name = get_category_info(anno_file, with_background, + use_default_label) + + # whether output bbox is normalized in model output layer + is_bbox_normalized 
= False + if hasattr(model, 'is_bbox_normalized') and \ + callable(model.is_bbox_normalized): + is_bbox_normalized = model.is_bbox_normalized() + + # use VisualDL to log image + if FLAGS.use_vdl: + assert six.PY3, "VisualDL requires Python >= 3.5" + from visualdl import LogWriter + vdl_writer = LogWriter(FLAGS.vdl_log_dir) + vdl_image_step = 0 + vdl_image_frame = 0 # each frame can display ten pictures at most. + + + image_infers = [] + imid2path = dataset.get_imid2path() + for iter_id, data in enumerate(loader()): + outs = exe.run(infer_prog, + feed=data, + fetch_list=values, + return_numpy=False) + res = { + k: (np.array(v), v.recursive_sequence_lengths()) + for k, v in zip(keys, outs) + } + logging.info('Infer iter {}'.format(iter_id)) + if 'TTFNet' in cfg.architecture: + res['bbox'][1].append([len(res['bbox'][0])]) + if 'CornerNet' in cfg.architecture: + from ppdet.utils.post_process import corner_post_process + post_config = getattr(cfg, 'PostProcess', None) + corner_post_process(res, post_config, cfg.num_classes) + + bbox_results = None + mask_results = None + segm_results = None + lmk_results = None + if 'bbox' in res: + bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized) + if 'mask' in res: + mask_results = mask2out([res], clsid2catid, + model.mask_head.resolution) + if 'segm' in res: + segm_results = segm2out([res], clsid2catid) + if 'landmark' in res: + lmk_results = lmk2out([res], is_bbox_normalized) + + # visualize result + im_ids = res['im_id'][0] + for im_id in im_ids: + image_path = imid2path[int(im_id)] + image = Image.open(image_path).convert('RGB') + image = ImageOps.exif_transpose(image) + + # use VisualDL to log original image + if FLAGS.use_vdl: + original_image_np = np.array(image) + vdl_writer.add_image( + "original/frame_{}".format(vdl_image_frame), + original_image_np, vdl_image_step) + + image = visualize_results(image, + int(im_id), catid2name, + FLAGS.draw_threshold, bbox_results, + mask_results, segm_results, lmk_results) + 
+ # use VisualDL to log image with bbox + if FLAGS.use_vdl: + infer_image_np = np.array(image) + vdl_writer.add_image("bbox/frame_{}".format(vdl_image_frame), + infer_image_np, vdl_image_step) + vdl_image_step += 1 + if vdl_image_step % 10 == 0: + vdl_image_step = 0 + vdl_image_frame += 1 + + save_name = get_save_image_name(FLAGS.output_dir, image_path) + logging.info("Detection bbox results save in {}".format(save_name)) + image.save(save_name, quality=95) + # xmin, ymin, w, h + for bbox in bbox_results: + category_id = bbox["category_id"] + bbox["category_name"] = catid2name[category_id] + bbox_dict ={"image":os.path.basename(image_path),"bbox_results":bbox_results} + image_infers.append(bbox_dict) + infer_result["result"] = image_infers + infer_result["status"] = "finish" + TaskDao(task_id=FLAGS.task_id).save_task_result(infer_result,ComponentName.DETECTION, type=TaskResultType.INFER) + + + +if __name__ == '__main__': + def str2bool(v): + return v.lower() in ("true", "t", "1") + enable_static_mode() + parser = ArgsParser() + parser.add_argument( + "--task_id", + type=str, + default=None) + parser.add_argument( + "--infer_dir", + type=str, + default=None, + help="Directory for images to perform inference on.") + parser.add_argument( + "--infer_img", + type=str, + default=None, + help="Image path, has higher priority over --infer_dir") + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Directory for storing the output visualization files.") + parser.add_argument( + "--weights", + type=str, + default=None, + help="weights path") + parser.add_argument( + "--draw_threshold", + type=float, + default=0.5, + help="Threshold to reserve the result for visualization.") + parser.add_argument( + "--use_gpu", + type=str2bool, + default=False, + help="whether to use gpu.") + parser.add_argument( + "--use_vdl", + type=bool, + default=False, + help="whether to record the data to VisualDL.") + parser.add_argument( + '--vdl_log_dir', + type=str, + 
default="vdl_log_dir/image", + help='VisualDL logging directory for image.') + FLAGS = parser.parse_args() + main() diff --git a/VisualFL/visualfl/client/__init__.py b/VisualFL/visualfl/client/__init__.py new file mode 100644 index 000000000..3da16e031 --- /dev/null +++ b/VisualFL/visualfl/client/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/VisualFL/visualfl/client/apply.py b/VisualFL/visualfl/client/apply.py new file mode 100644 index 000000000..dbc760899 --- /dev/null +++ b/VisualFL/visualfl/client/apply.py @@ -0,0 +1,96 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +import urllib.parse +from pathlib import Path + +import click +import aiohttp +import asyncio +import json + +import yaml + +from visualfl import extensions + + +@click.group() +def cli(): + ... 
+ + +def post(endpoint, path, json_data): + async def post_co(): + url = urllib.parse.urljoin(f"http://{endpoint}", path) + async with aiohttp.ClientSession() as session: + async with session.post( + url, json=json_data + ) as resp: + print(resp.status) + print(json.dumps(await resp.json(), indent=2)) + resp.raise_for_status() + + loop = asyncio.get_event_loop() + loop.run_until_complete(post_co()) + + + +@cli.command() +@click.option( + "--config", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--endpoint", + type=str, + required=True, +) +def apply(endpoint, config): + base = Path(config) + with base.open("r") as f: + config_json = yaml.load(f, yaml.Loader) + job_id = config_json.get("job_id") + task_id = config_json.get("task_id") + job_type = config_json.get("job_type") + role = config_json.get("role") + member_id = config_json.get("member_id") + callback_url = config_json.get("callback_url") + env = config_json.get("env") + data_set = config_json.get("data_set") + algorithm_config = config_json.get("algorithm_config") + + extensions.get_job_schema_validator(job_type).validate(env) + + post( + endpoint, + "apply", + dict( + job_id=job_id, + task_id=task_id, + job_type=job_type, + role=role, + member_id=member_id, + env=env, + data_set=data_set, + algorithm_config=algorithm_config, + callback_url=callback_url + ) + ) + + +if __name__ == "__main__": + cli() diff --git a/VisualFL/visualfl/client/cluster_manager.py b/VisualFL/visualfl/client/cluster_manager.py new file mode 100644 index 000000000..b2ceeb24e --- /dev/null +++ b/VisualFL/visualfl/client/cluster_manager.py @@ -0,0 +1,51 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +import asyncio + +import click + + +@click.command(name="start-manager") +@click.option("--port", type=int, required=True, help="cluster manager address") +def start_manager(port): + """ + start manager + """ + from visualfl.utils import logger + + logger.set_logger("manager") + from visualfl.manager import ClusterManager + + loop = asyncio.get_event_loop() + manager = ClusterManager( + port=port, + ) + try: + loop.run_until_complete(manager.start()) + click.echo(f"cluster manager start") + loop.run_forever() + except KeyboardInterrupt: + click.echo("keyboard interrupted") + + finally: + loop.run_until_complete(manager.stop()) + click.echo(f"cluster manager server stop") + loop.close() + + +if __name__ == "__main__": + start_manager() diff --git a/VisualFL/visualfl/client/cluster_worker.py b/VisualFL/visualfl/client/cluster_worker.py new file mode 100644 index 000000000..cfe937afb --- /dev/null +++ b/VisualFL/visualfl/client/cluster_worker.py @@ -0,0 +1,78 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + + +import asyncio + +import click + + +@click.command(name="start-worker") +@click.option("--name", type=str, required=True, help="worker name") +@click.option("--worker-ip", type=str, required=True, help="worker ip") +@click.option("--max-tasks", type=int, required=True, help="max tasks") +@click.option("--port-start", type=int, required=True, help="port start") +@click.option("--port-end", type=int, required=True, help="port end") +@click.option( + "--manager-address", type=str, required=True, help="cluster manager address" +) +@click.option( + "--data-base-dir", + type=click.Path(exists=True, file_okay=False, dir_okay=True), + required=False, + help="data base dir", +) +def start_worker( + name, + worker_ip, + max_tasks, + manager_address, + port_start, + port_end, + data_base_dir, +): + """ + start worker + """ + from visualfl.utils import logger + + logger.set_logger(f"worker-{worker_ip}") + from visualfl.worker import ClusterWorker + + loop = asyncio.get_event_loop() + worker = ClusterWorker( + worker_id=name, + worker_ip=worker_ip, + max_tasks=max_tasks, + manager_address=manager_address, + port_start=port_start, + port_end=port_end, + data_dir=data_base_dir, + ) + try: + loop.run_until_complete(worker.start()) + click.echo(f"worker {name} start") + loop.run_until_complete(worker.wait_for_termination()) + except KeyboardInterrupt: + click.echo("keyboard interrupted") + finally: + loop.run_until_complete(worker.stop()) + loop.run_until_complete(asyncio.sleep(1)) + loop.close() + click.echo(f"worker {name} stop") + + +if __name__ == "__main__": + start_worker() diff --git a/VisualFL/visualfl/client/infer.py b/VisualFL/visualfl/client/infer.py new file mode 100644 index 000000000..aedac2c4c --- /dev/null +++ b/VisualFL/visualfl/client/infer.py @@ -0,0 +1,96 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +import urllib.parse +from pathlib import Path + +import click +import aiohttp +import asyncio +import json + +import yaml + +from visualfl import extensions + + +@click.group() +def cli(): + ... + + +def post(endpoint, path, json_data): + async def post_co(): + url = urllib.parse.urljoin(f"http://{endpoint}", path) + async with aiohttp.ClientSession() as session: + async with session.post( + url, json=json_data + ) as resp: + print(resp.status) + print(json.dumps(await resp.json(), indent=2)) + resp.raise_for_status() + + loop = asyncio.get_event_loop() + loop.run_until_complete(post_co()) + + + +@cli.command() +@click.option( + "--config", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--endpoint", + type=str, + required=True, +) +def infer(endpoint, config): + base = Path(config) + with base.open("r") as f: + config_json = yaml.load(f, yaml.Loader) + job_id = config_json.get("job_id") + task_id = config_json.get("task_id") + job_type = config_json.get("job_type") + role = config_json.get("role") + member_id = config_json.get("member_id") + callback_url = config_json.get("callback_url") + env = config_json.get("env") + data_set = config_json.get("data_set") + algorithm_config = config_json.get("algorithm_config") + + extensions.get_job_schema_validator(job_type).validate(env) + + post( + endpoint, + "infer", + dict( + job_id=job_id, + task_id=task_id, + 
job_type=job_type, + role=role, + member_id=member_id, + env=env, + data_set=data_set, + algorithm_config=algorithm_config, + callback_url=callback_url + ) + ) + + +if __name__ == "__main__": + cli() diff --git a/VisualFL/visualfl/client/master.py b/VisualFL/visualfl/client/master.py new file mode 100644 index 000000000..62de06636 --- /dev/null +++ b/VisualFL/visualfl/client/master.py @@ -0,0 +1,60 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + + +import asyncio + +import click + + +@click.command() +@click.option("--member-id", type=str, required=True, help="member id") +@click.option("--submitter-port", type=int, required=True, help="submitter port") +@click.option( + "--cluster-address", type=str, required=True, help="cluster manager address" +) +@click.option( + "--local", type=bool, required=False, help="is local template" +) +def start_master(member_id, submitter_port, cluster_address,local=False): + """ + start master + """ + from visualfl.utils import logger + + logger.set_logger(f"master-{member_id}") + from visualfl.master import Master + + loop = asyncio.get_event_loop() + master = Master( + member_id=member_id, + cluster_address=cluster_address, + rest_port=submitter_port, + local=local + ) + try: + loop.run_until_complete(master.start()) + click.echo(f"master started") + loop.run_forever() + except KeyboardInterrupt: + click.echo("keyboard interrupted") + finally: + loop.run_until_complete(master.stop()) + click.echo(f"master stop") + loop.close() + + +if __name__ == "__main__": + start_master() diff --git a/VisualFL/visualfl/client/submitter.py b/VisualFL/visualfl/client/submitter.py new file mode 100644 index 000000000..52d053dcc --- /dev/null +++ b/VisualFL/visualfl/client/submitter.py @@ -0,0 +1,100 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + + +import urllib.parse +from pathlib import Path + +import click +import aiohttp +import asyncio +import json + +import yaml + +from visualfl import extensions + + +@click.group() +def cli(): + ... + + +def post(endpoint, path, json_data): + async def post_co(): + url = urllib.parse.urljoin(f"http://{endpoint}", path) + async with aiohttp.ClientSession() as session: + async with session.post( + url, json=json_data + ) as resp: + print(resp.status) + print(json.dumps(await resp.json(), indent=2)) + resp.raise_for_status() + + loop = asyncio.get_event_loop() + loop.run_until_complete(post_co()) + + +@cli.command() +@click.option( + "--config", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--endpoint", + type=str, + required=True, +) +def submit(endpoint, config): + + base = Path(config) + with base.open("r") as f: + config_json = yaml.load(f, yaml.Loader) + job_id = config_json.get("job_id") + task_id = config_json.get("task_id") + job_type = config_json.get("job_type") + role = config_json.get("role") + member_id = config_json.get("member_id") + env = config_json.get("env") + data_set = config_json.get("data_set") + algorithm_config = config_json.get("algorithm_config") + + # algorithm_config_path = base.parent.joinpath( + # config_json.get("algorithm_config") + # ).absolute() + # with algorithm_config_path.open("r") as f: + # algorithm_config_string = f.read() + + # extensions.get_job_schema_validator(job_type).validate(env) + post( + endpoint, + "submit", + dict( + job_id=job_id, + task_id=task_id, + job_type=job_type, + role=role, + member_id=member_id, + env=env, + data_set=data_set, + algorithm_config=algorithm_config, + ), + ) + + + +if __name__ == "__main__": + cli() diff --git a/VisualFL/visualfl/db/db_models.py b/VisualFL/visualfl/db/db_models.py new file mode 100644 index 000000000..1622f8908 --- /dev/null +++ b/VisualFL/visualfl/db/db_models.py @@ -0,0 +1,245 @@ +# Copyright 2021 Tianmian Tech. 
# Copyright 2021 Tianmian Tech. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.
"""peewee models and the pooled MySQL connection for VisualFL persistence."""

import inspect
import operator
import sys

from peewee import *
from peewee import CharField
from playhouse.pool import PooledMySQLDatabase

from visualfl.utils import consts
from visualfl.utils.conf_utils import get_comm_config, get_env_config
from visualfl.utils.conf_utils import str2bool

# Database connectivity read from the common configuration file.
host = get_comm_config(consts.COMM_CONF_KEY_MYSQL_HOST)
password = get_comm_config(consts.COMM_CONF_KEY_MYSQL_PASSWORD)
port = int(get_comm_config(consts.COMM_CONF_KEY_MYSQL_PORT))
user = get_comm_config(consts.COMM_CONF_KEY_MYSQL_USERNAME)
database = get_comm_config(consts.COMM_CONF_KEY_MYSQL_DATABASE)
is_local = str2bool(get_comm_config(consts.COMM_CONF_IS_LOCAL))

# Environment-variable overrides; when set they take precedence below.
env_host = get_env_config(consts.COMM_CONF_KEY_MYSQL_HOST)
env_password = get_env_config(consts.COMM_CONF_KEY_MYSQL_PASSWORD)
env_port = get_env_config(consts.COMM_CONF_KEY_MYSQL_PORT)
if env_port:
    env_port = int(env_port)
env_user = get_env_config(consts.COMM_CONF_KEY_MYSQL_USERNAME)
env_database = get_env_config(consts.COMM_CONF_KEY_MYSQL_DATABASE)

settings = {
    'host': env_host or host,
    'password': env_password or password,
    'port': env_port or port,
    'user': env_user or user,
    'max_connections': 100,
}

DB = PooledMySQLDatabase(env_database or database, **settings)


def _query_by_params(model_cls, **kwargs):
    """Select rows of ``model_cls`` matching the given column/value pairs.

    Column names not present on the model are silently ignored; when no
    usable filter remains, no query is issued and ``[]`` is returned.
    Shared implementation behind ``Job.getByParam`` and ``Task.getByParam``
    (previously duplicated in both classes).
    """
    with DB.connection_context():
        filters = [
            operator.attrgetter(name)(model_cls) == value
            for name, value in kwargs.items()
            if hasattr(model_cls, name)
        ]
        if not filters:
            return []
        return list(model_cls.select().where(*filters))


class ModelBase(Model):
    # Audit columns shared by every table.
    id = CharField(primary_key=True)
    created_by = CharField(null=True)
    created_time = DateTimeField()
    updated_by = CharField(null=True)
    updated_time = DateTimeField()

    class Meta:
        database = DB

    def to_json(self):
        # peewee keeps the column values in the instance's __data__ mapping.
        return self.__dict__['__data__']


class GlobalSetting(object):
    """
    Get global setting

    Due to the adjustment of the GlobalSetting table,
    it is now necessary to obtain the relevant configuration from GlobalConfig
    """

    @staticmethod
    def get_member_id():
        from common.python.db.global_config_dao import GlobalConfigDao
        return GlobalConfigDao.getMemberInfo().member_id

    @staticmethod
    def get_member_name():
        from common.python.db.global_config_dao import GlobalConfigDao
        return GlobalConfigDao.getMemberInfo().member_name


class Job(ModelBase):
    """One federated-learning job (one row per job_id/my_role pair)."""
    created_by = CharField(null=True)
    created_time = DateTimeField(constraints=[SQL("DEFAULT CURRENT_TIMESTAMP")])
    federated_learning_type = CharField()
    finish_time = DateTimeField(null=True)
    flow_id = CharField()
    graph = TextField(null=True)
    has_modeling_result = IntegerField(constraints=[SQL("DEFAULT 0")])
    id = CharField(primary_key=True)
    job_id = CharField()
    message = TextField(null=True)
    my_role = CharField()
    name = CharField()
    progress = IntegerField(constraints=[SQL("DEFAULT 0")])
    progress_updated_time = DateTimeField(null=True)
    project_id = CharField()
    remark = TextField(null=True)
    star = IntegerField(constraints=[SQL("DEFAULT 0")])
    start_time = DateTimeField(null=True)
    status = CharField(constraints=[SQL("DEFAULT 'created'")])
    status_updated_time = DateTimeField(null=True)
    updated_by = CharField(null=True)
    updated_time = DateTimeField(null=True)
    job_middle_data_is_clear = IntegerField(constraints=[SQL("DEFAULT 0")])

    class Meta:
        db_table = 'job'
        indexes = (
            (('job_id', 'my_role'), True),
        )

    @staticmethod
    def getByParam(**kwargs):
        """Return jobs matching the given column/value pairs (possibly [])."""
        return _query_by_params(Job, **kwargs)


class Task(ModelBase):
    """One task inside a job's flow graph."""
    created_by = CharField(null=True)
    created_time = DateTimeField(constraints=[SQL("DEFAULT CURRENT_TIMESTAMP")])
    deep = IntegerField(null=True)
    dependence_list = CharField(null=True)
    error_cause = TextField(null=True)
    finish_time = DateTimeField(null=True)
    flow_id = CharField()
    flow_node_id = CharField()
    id = CharField(primary_key=True)
    job_id = CharField()
    message = CharField(null=True)
    name = CharField()
    parent_task_id_list = CharField(null=True)
    pid = IntegerField(null=True)
    position = IntegerField(null=True)
    project_id = CharField(null=True)
    role = CharField(null=True)
    spend = IntegerField(null=True)
    start_time = DateTimeField(null=True)
    status = CharField()
    task_conf = TextField()
    task_id = CharField()
    task_type = CharField()
    updated_by = CharField(null=True)
    updated_time = DateTimeField(null=True)

    class Meta:
        table_name = 'task'

    @staticmethod
    def getByParam(**kwargs):
        """Return tasks matching the given column/value pairs (possibly [])."""
        return _query_by_params(Task, **kwargs)


class TaskResult(ModelBase):
    """
    Component result save
    """
    component_type = CharField()
    created_by = CharField(null=True)
    created_time = DateTimeField(constraints=[SQL("DEFAULT CURRENT_TIMESTAMP")], index=True)
    flow_id = CharField()
    flow_node_id = CharField()
    id = CharField(primary_key=True)
    job_id = CharField()
    name = CharField()
    project_id = CharField(null=True)
    result = TextField()
    role = CharField()
    serving_model = IntegerField(constraints=[SQL("DEFAULT 0")])
    task_id = CharField()
    type = CharField()
    updated_by = CharField(null=True)
    updated_time = DateTimeField(null=True)

    class Meta:
        table_name = 'task_result'
        indexes = (
            (('task_id', 'type', 'role'), True),
        )


class TaskProgress(ModelBase):
    """
    Task progress
    """
    created_by = CharField(null=True)
    created_time = DateTimeField()
    expect_end_time = DateTimeField(null=True)
    expect_work_amount = IntegerField(null=True)
    flow_id = CharField()
    flow_node_id = CharField()
    id = CharField(primary_key=True)
    job_id = CharField()
    progress = IntegerField(null=True)
    progress_rate = DecimalField(null=True)
    project_id = CharField(null=True)
    really_work_amount = IntegerField(null=True)
    role = CharField()
    spend = IntegerField(null=True)
    task_id = CharField()
    task_type = CharField()
    updated_by = CharField(null=True)
    updated_time = DateTimeField(null=True)
    pid_success = IntegerField(null=True)

    class Meta:
        table_name = 'task_progress'
        indexes = (
            (('task_id', 'role'), True),
        )


if __name__ == '__main__':
    pass
from visualfl.db.db_models import DB, Task, TaskResult, TaskProgress, is_local
from visualfl.utils.core_utils import current_datetime, get_commit_id
import datetime
import json
from visualfl.utils.logger import Logger
from visualfl.utils.consts import TaskStatus
import logging
import numpy


class Encoder(json.JSONEncoder):
    """JSON encoder that also serializes numpy scalar and array types."""

    def default(self, obj):
        if isinstance(obj, numpy.integer):
            return int(obj)
        elif isinstance(obj, numpy.floating):
            return float(obj)
        elif isinstance(obj, numpy.ndarray):
            return obj.tolist()
        else:
            return super(Encoder, self).default(obj)


class TaskDao(Logger):
    """Persistence helper for one task's status, results, and progress.

    Every method is a no-op when ``is_local`` is set (local test mode,
    no business database available).  Write errors are logged, never raised.
    """

    def __init__(self, task_id):
        self._task_id = task_id

    def start_task(self):
        """Mark the task as RUNNING and stamp its start time."""
        try:
            if is_local:
                return
            with DB.connection_context():
                task = Task.select().where(
                    Task.task_id == self._task_id
                ).get()

                task.start_time = current_datetime()
                task.updated_time = current_datetime()
                task.status = TaskStatus.RUNNING
                task.save()
        except Exception as e:
            self.exception(e)
            self.error(f"save start task {self._task_id} error as {e} ")

    def update_task_status(self, status, message=None):
        """Set the task's final status/message and stamp its finish time."""
        try:
            if is_local:
                return
            with DB.connection_context():
                task = Task.select().where(
                    Task.task_id == self._task_id
                ).get()

                task.status = status
                task.message = message
                task.updated_time = current_datetime()
                task.finish_time = current_datetime()
                task.save()
        except Exception as e:
            self.exception(e)
            self.error(f"update task {self._task_id} status to {status} error as {e} ")

    def get_task_result(self, result_type):
        """Return the TaskResult row of ``result_type`` for this task, or None."""
        if is_local:
            return

        with DB.connection_context():
            where_condition = [
                TaskResult.task_id == self._task_id,
                TaskResult.type == result_type,
            ]
            models = TaskResult.select().where(*tuple(where_condition))
            if models:
                return models[0]
            else:
                return None

    def save_task_result(self, task_result: dict, component_name: str, type: str):
        """Insert or update the task's result row of the given ``type``.

        Parameters
        ----------
        task_result : dict
            Result payload; JSON-serialized with the numpy-aware Encoder.
        component_name : str
            Component name, special case can be specified separately.
        type : str
            Result type key (parameter name kept for caller compatibility
            although it shadows the builtin).

        Returns
        -------
        The saved TaskResult model, or None when skipped or on error.
        """
        try:
            if is_local:
                return

            with DB.connection_context():
                models = TaskResult.select().where(
                    TaskResult.task_id == self._task_id,
                    TaskResult.type == type,
                )
                tasks = Task.select().where(
                    Task.task_id == self._task_id,
                )

                # Compatible with local test without task information.
                if len(tasks) != 0:
                    task = tasks[0]
                else:
                    return

                is_insert = True
                if models:
                    model = models[0]
                    is_insert = False
                else:
                    model = TaskResult()
                    model.created_time = datetime.datetime.now()

                model.job_id = task.job_id
                model.name = component_name
                model.task_id = self._task_id
                model.role = task.role
                model.type = type
                model.updated_time = datetime.datetime.now()
                model.result = json.dumps(task_result, cls=Encoder)
                model.component_type = component_name
                model.flow_id = task.flow_id
                model.flow_node_id = task.flow_node_id
                model.project_id = task.project_id

                if is_insert:
                    model.id = get_commit_id()
                    model.save(force_insert=True)
                else:
                    model.save()
                return model
        except Exception as e:
            self.exception(e)
            self.error(f"save task {self._task_id} result error as {e} ")

    def update_serving_model(self, type: str):
        """Flag the result row of ``type`` as a serving model."""
        try:
            if is_local:
                return

            with DB.connection_context():
                models = TaskResult.select().where(
                    TaskResult.task_id == self._task_id,
                    TaskResult.type == type,
                )
                if models:
                    model = models[0]
                else:
                    return

                model.serving_model = 1
                model.save()
                return model
        except Exception as e:
            self.exception(e)
            self.error(f"update serving model error as {e},with task id : {self._task_id} ")

    def calc_progress(self, model: TaskProgress) -> TaskProgress:
        """Recompute percentage / spend / expected end time on ``model``.

        If there is actual work amount (really_work_amount), the percentage
        is based on it (i.e. the task is finished); otherwise it is based
        on the estimated amount (expect_work_amount).
        """
        if is_local:
            return

        if model.progress is None:
            model.progress = 0
        if model.progress > model.expect_work_amount:
            model.progress = model.expect_work_amount

        work_amount = model.really_work_amount or model.expect_work_amount
        # Guard: with a zero/unset work amount the division below raised
        # ZeroDivisionError and the whole progress update was discarded
        # by the callers' try/except blocks.
        if not work_amount:
            model.progress_rate = 0
            return model

        model.progress_rate = round(model.progress / work_amount * 100, 2)
        if model.progress_rate > 100:
            model.progress_rate = 100

        if model.updated_time is not None and model.progress_rate > 0:
            model.spend = int((model.updated_time - model.created_time).total_seconds() * 1000)
            need_time = int(model.spend * 100 / model.progress_rate - model.spend)
            model.expect_end_time = model.updated_time + datetime.timedelta(milliseconds=need_time)

        return model

    def init_task_progress(self, work_amount: int):
        """Initialize (or reset) the task's progress row with a total work amount.

        eg. Logistic regression algorithm parameters need to run 300
        iterations, then work_amount can be set to 300; after each
        iteration completes, the current work amount needs to be +1.
        """
        try:
            if is_local:
                return

            with DB.connection_context():
                model = TaskProgress.get_or_none(
                    TaskProgress.task_id == self._task_id,
                )

                is_insert = True
                if model:
                    is_insert = False
                    # Reset an existing row back to zero progress.
                    model.progress = 0
                    model.really_work_amount = None
                    model.created_time = datetime.datetime.now()
                    model.updated_time = None
                    model.expect_end_time = None
                    model.spend = None
                else:
                    model = TaskProgress()
                    model.id = get_commit_id()
                    model.progress = 0
                    model.created_time = datetime.datetime.now()

                    # Copy identifying columns from the task row; without
                    # a task row there is nothing to attach progress to.
                    task_info = Task.get_or_none(
                        Task.task_id == self._task_id,
                    )
                    if task_info:
                        model.flow_id = task_info.flow_id
                        model.flow_node_id = task_info.flow_node_id
                        model.project_id = task_info.project_id
                        model.job_id = task_info.job_id
                        model.task_id = self._task_id
                        model.role = task_info.role
                        model.task_type = task_info.task_type
                    else:
                        return

                model.expect_work_amount = work_amount
                self.calc_progress(model)
                model.save(force_insert=is_insert)
        except Exception as e:
            self.exception(e)
            self.error(f"init task {self._task_id} progress error as {e} ")

    def set_task_progress(self, work_amount: int):
        """Set the currently completed work amount and recompute the rate."""
        try:
            if is_local:
                return

            if work_amount >= 0:
                with DB.connection_context():
                    model = TaskProgress.select().where(
                        TaskProgress.task_id == self._task_id,
                    ).get()

                    model.progress = work_amount
                    model.updated_time = datetime.datetime.now()
                    self.calc_progress(model)
                    model.save()
        except Exception as e:
            self.exception(e)
            self.error(f"set task {self._task_id} progress error as {e} ")

    def add_task_progress(self, step: int = 1):
        """Increase the completed work amount by ``step``."""
        try:
            if is_local:
                return

            with DB.connection_context():
                model = TaskProgress.select().where(
                    TaskProgress.task_id == self._task_id,
                ).get()
                if model.progress is not None:
                    work_amount = model.progress + step
                else:
                    work_amount = step

                # Reserve one unit for use when finish_task_progress is called.
                if work_amount > model.expect_work_amount - 1:
                    work_amount = model.expect_work_amount - 1

            self.set_task_progress(work_amount)
        except Exception as e:
            self.exception(e)
            self.error(f"add task {self._task_id} progress error as {e} ")

    def get_task_progress(self):
        """Return the currently recorded work amount, or None."""
        if is_local:
            return

        with DB.connection_context():
            model = TaskProgress.select().where(
                TaskProgress.task_id == self._task_id,
            ).get()
            if model.progress is not None:
                return model.progress
            else:
                return None

    def finish_task_progress(self):
        """Consume the reserved unit, fix the real work amount, mark success."""
        try:
            if is_local:
                return

            with DB.connection_context():
                model = TaskProgress.get_or_none(
                    TaskProgress.task_id == self._task_id,
                )
                if model:
                    # progress may still be None when no step was recorded.
                    model.progress = (model.progress or 0) + 1
                    model.really_work_amount = model.progress

                    if model.really_work_amount > model.expect_work_amount:
                        model.really_work_amount = model.expect_work_amount

                    model.updated_time = datetime.datetime.now()
                    self.calc_progress(model)
                    model.pid_success = 1
                    model.save()
        except Exception as e:
            self.exception(e)
            self.error(f"finish task {self._task_id} progress error as {e} ")
000000000..6ba9c8233 --- /dev/null +++ b/VisualFL/visualfl/extensions/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from visualfl.extensions._extensions import ( + get_task_class, + get_job_class, + get_job_schema_validator, +) + +__all__ = ["get_job_class", "get_task_class", "get_job_schema_validator"] diff --git a/VisualFL/visualfl/extensions/_extensions.py b/VisualFL/visualfl/extensions/_extensions.py new file mode 100644 index 000000000..f83c1c9ba --- /dev/null +++ b/VisualFL/visualfl/extensions/_extensions.py @@ -0,0 +1,165 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import importlib
import json
from pathlib import Path
from typing import Optional, MutableMapping, Type

import jsonschema
import yaml

from visualfl.paddle_fl.job import PaddleFLJob as Job
from visualfl.paddle_fl.abs.task import Task
from visualfl.utils.exception import VisualFLExtensionException
from visualfl.utils.logger import Logger


class _ExtensionLoader(Logger):
    """Lazily loads job/task classes and job schemas from conf/extensions.yaml."""

    _job_classes: Optional[MutableMapping[str, Type[Job]]] = None
    _job_schema_validator: Optional[MutableMapping] = None
    _task_classes: Optional[MutableMapping[str, Type[Task]]] = None

    @classmethod
    def _read_extensions_conf(cls):
        """Parse conf/extensions.yaml and return (conf_path, mapping)."""
        path = Path(__file__).parent.joinpath("conf/extensions.yaml")
        cls.trace(f"load extension configuration from {path}")
        with open(path) as f:
            try:
                extensions = yaml.safe_load(f)
            except yaml.YAMLError as e:
                raise VisualFLExtensionException("load extension failed") from e
            finally:
                cls.trace_lazy(
                    "extension configuration:\n{yaml_config}",
                    yaml_config=lambda: yaml.safe_dump(extensions, indent=2),
                )
        return path, extensions

    @classmethod
    def _build_schema_validator(cls, extension_job, conf_path):
        """Build the Draft-7 validator for one job entry.

        Entries without a ``schema`` file fall back to the permissive
        empty schema.  Previously duplicated in _load and _load_schema.
        """
        if "schema" in extension_job:
            with conf_path.parent.joinpath(extension_job["schema"]).open() as g:
                schema = json.load(g)
        else:
            schema = {}
        return jsonschema.Draft7Validator(schema)

    @classmethod
    def _load(cls):
        # Already fully loaded?
        if cls._job_classes is not None:
            return cls

        cls._job_classes = {}
        cls._task_classes = {}
        cls._job_schema_validator = {}
        path, extensions = cls._read_extensions_conf()

        for extension_name, configs in extensions.items():
            cls.trace(f"loading extension: {extension_name}")

            cls.trace(f"loading job classes from extension {extension_name}")
            for extension_job in configs.get("jobs", []):
                job_module_name, job_cls_name = extension_job["loader"].split(":")
                module = importlib.import_module(job_module_name)
                job_cls = getattr(module, job_cls_name)
                if not issubclass(job_cls, Job):
                    raise VisualFLExtensionException(
                        f"JobLoader expected, {job_cls} found"
                    )
                cls._job_classes[extension_job["name"]] = job_cls
                # NOTE: a stray debug print(schema) was removed here.
                cls._job_schema_validator[
                    extension_job["name"]
                ] = cls._build_schema_validator(extension_job, path)

            cls.trace(f"loading task classes from extension {extension_name}")
            for extension_task in configs.get("tasks", []):
                loader_module, loader_cls = extension_task["loader"].split(":")
                module = importlib.import_module(loader_module)
                loader = getattr(module, loader_cls)
                if not issubclass(loader, Task):
                    raise VisualFLExtensionException(
                        f"JobLoader expected, {loader} found"
                    )
                cls._task_classes[extension_task["name"]] = loader

        cls.trace_lazy(
            "loading extensions done. job classes: {job_classes}, task classes: {task_classes}",
            job_classes=lambda: cls._job_classes,
            task_classes=lambda: cls._task_classes,
        )
        return cls

    @classmethod
    def _load_schema(cls):
        # Schema-only path: avoids importing job/task modules.
        if cls._job_schema_validator is not None:
            return cls

        cls._job_schema_validator = {}
        path, extensions = cls._read_extensions_conf()

        for extension_name, configs in extensions.items():
            cls.trace(f"loading extension: {extension_name}")
            for extension_job in configs.get("jobs", []):
                cls._job_schema_validator[
                    extension_job["name"]
                ] = cls._build_schema_validator(extension_job, path)
        return cls

    @classmethod
    def get_job_class(cls, name):
        return cls._load()._job_classes.get(name)

    @classmethod
    def get_task_class(cls, name):
        return cls._load()._task_classes.get(name)

    @classmethod
    def get_job_schema_validator(cls, name):
        return cls._load_schema()._job_schema_validator.get(name)


def get_job_class(name) -> Type[Job]:
    """Return the registered job class for ``name`` (or None)."""
    return _ExtensionLoader.get_job_class(name)


def get_task_class(name) -> Type[Task]:
    """Return the registered task class for ``name`` (or None)."""
    return _ExtensionLoader.get_task_class(name)


def get_job_schema_validator(name):
    """Return the jsonschema validator for job type ``name`` (or None)."""
    return _ExtensionLoader.get_job_schema_validator(name)


if __name__ == '__main__':
    loader = get_job_class("paddle_fl")
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 The FedVision Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import asyncio
import time
from typing import Optional, MutableMapping, AsyncGenerator, Tuple

import grpc

from visualfl.protobuf import cluster_pb2, cluster_pb2_grpc, job_pb2
from visualfl.utils.logger import Logger, pretty_pb


class ClusterManager(Logger, cluster_pb2_grpc.ClusterManagerServicer):
    """gRPC service that tracks alive workers and dispatches tasks to them."""

    def __init__(
        self,
        port: int,
        host: str = None,
    ):
        """
        init cluster manager instance

        Args:
            port: listening port for the gRPC server
            host: bind address; defaults to all interfaces ("[::]")
        """
        self._host = "[::]" if host is None else host
        self._port = port
        self._alive_workers: MutableMapping[str, _WorkerDescription] = {}
        self._tasks_status = {}
        # seconds without a heartbeat before a worker is declared dead
        self._max_heartbeat_delay = 5

        self._server: Optional[grpc.aio.Server] = None

    def has_worker(self, worker_id) -> bool:
        """
        check worker `worker_id` alive(enrolled)
        """
        return worker_id in self._alive_workers

    def add_worker(self, worker_id, worker_ip, max_tasks, port_start, port_end):
        """
        add worker to manager and start a heartbeat watcher for it

        Args:
            worker_id:
            worker_ip:
            max_tasks:
            port_start:
            port_end:
        """
        worker = _WorkerDescription(
            worker_id=worker_id,
            worker_ip=worker_ip,
            max_tasks=max_tasks,
            max_delay=self._max_heartbeat_delay,
            port_start=port_start,
            port_end=port_end,
        )
        self._alive_workers[worker_id] = worker

        async def _healthy_watcher():
            # Periodically checks the heartbeat; removes the worker on loss.
            try:
                while True:
                    await asyncio.sleep(self._max_heartbeat_delay)
                    if worker_id not in self._alive_workers:
                        self.error(f"worker:{worker_id} not found")
                        break

                    if worker.is_asystole():
                        self.error(f"heartbeat from worker:{worker_id} loss")
                        break
            finally:
                self.remove_worker(worker_id)

        asyncio.create_task(_healthy_watcher())
        return worker

    def remove_worker(self, worker_id):
        """
        remove worker from manager (no-op when already gone)
        """
        if worker_id not in self._alive_workers:
            return
        del self._alive_workers[worker_id]

    async def Enroll(
        self,
        request: cluster_pb2.Enroll.REQ,
        context: grpc.aio.ServicerContext,
    ) -> AsyncGenerator[cluster_pb2.Enroll.REP, None]:
        """
        rpc server impl: process tasker enroll request.

        Streams ENROLL_SUCCESS then one TASK_READY message per dispatched
        task while the worker stays alive.
        """
        self.debug(f"cluster worker enroll request: {pretty_pb(request)}")
        if self.has_worker(request.worker_id):
            yield cluster_pb2.Enroll.REP(status=cluster_pb2.Enroll.ALREADY_ENROLL)
            return

        worker = self.add_worker(
            request.worker_id,
            request.worker_ip,
            request.max_tasks,
            request.port_start,
            request.port_end,
        )
        self.debug(f"cluster worker enroll success: worker: {request.worker_id}")
        yield cluster_pb2.Enroll.REP(status=cluster_pb2.Enroll.ENROLL_SUCCESS)

        while self.has_worker(request.worker_id):
            try:
                # Short timeout so the liveness condition is re-checked.
                task = await worker.wait_next_task(timeout=5)
            except asyncio.TimeoutError:
                continue

            self.debug(
                f"task ready: job_id={task.job_id}, task_id={task.task_id}, task_type={task.task_type}"
            )
            rep = cluster_pb2.Enroll.REP(
                status=cluster_pb2.Enroll.TASK_READY, task=task
            )
            self.debug(
                f"response task({task.task_id}, {task.task_type}) to worker {request.worker_id}"
            )
            yield rep

        self.remove_worker(request.worker_id)

    async def UpdateTaskStatus(
        self, request: cluster_pb2.UpdateStatus.REQ, context: grpc.aio.ServicerContext
    ) -> cluster_pb2.UpdateStatus.REP:
        """
        process task status update request (doubles as worker heartbeat)
        """
        if request.worker_id not in self._alive_workers:
            return cluster_pb2.UpdateStatus.REP(status=cluster_pb2.UpdateStatus.FAILED)
        await self._alive_workers[request.worker_id].update_heartbeat()

        # Empty task_id means this is a pure heartbeat.
        if not request.task_id:
            return cluster_pb2.UpdateStatus.REP(status=cluster_pb2.UpdateStatus.SUCCESS)

        # NOTE(review): the original condition was the double negative
        # `not request.task_id not in self._tasks_status`, which parses as
        # `request.task_id in self._tasks_status` — i.e. a status may be
        # recorded only once per task.  Behavior preserved; verify intent.
        if request.task_id in self._tasks_status:
            return cluster_pb2.UpdateStatus.REP(status=cluster_pb2.UpdateStatus.FAILED)

        self.debug(f"update task status: {request.task_id} to {request.task_status}")
        self._tasks_status[request.task_id] = request.task_status
        return cluster_pb2.UpdateStatus.REP(status=cluster_pb2.UpdateStatus.SUCCESS)

    async def TaskSubmit(
        self, request: cluster_pb2.TaskSubmit.REQ, context: grpc.aio.ServicerContext
    ) -> cluster_pb2.TaskSubmit.REP:
        """
        process task submit request: route to the assignee or auto-dispatch
        """
        try:
            task = request.task
            if not task.assignee:
                worker, _ = await self.dispatch()
                await worker.put_task(task=task)
            else:
                await self._alive_workers[task.assignee].put_task(task=task)
            return cluster_pb2.TaskSubmit.REP(status=cluster_pb2.TaskSubmit.SUCCESS)
        except Exception as e:
            self.exception(f"handle task submit failed: {e}")
            return cluster_pb2.TaskSubmit.REP(status=cluster_pb2.TaskSubmit.FAILED)

    async def TaskResourceRequire(self, request, context):
        """
        process task resource acquire request: reserve endpoints on a worker
        """
        worker, endpoints = await self.dispatch(
            resource={"endpoints": request.num_endpoints}
        )
        if worker is None:
            return cluster_pb2.TaskResourceRequire.REP(
                status=cluster_pb2.TaskResourceRequire.FAILED
            )

        response = cluster_pb2.TaskResourceRequire.REP(
            status=cluster_pb2.TaskResourceRequire.SUCCESS, worker_id=worker.worker_id
        )
        for endpoint in endpoints:
            response.endpoints.append(endpoint)
        return response

    async def start(self):
        """
        start cluster manager service
        """
        self.info(f"starting cluster manager at port: {self._port}")
        self._server = grpc.aio.server(
            options=[
                ("grpc.max_send_message_length", 512 * 1024 * 1024),
                ("grpc.max_receive_message_length", 512 * 1024 * 1024),
            ],
        )
        cluster_pb2_grpc.add_ClusterManagerServicer_to_server(self, self._server)
        self._server.add_insecure_port(f"{self._host}:{self._port}")
        await self._server.start()
        self.info(f"cluster manager started at port: {self._port}")

    async def stop(self):
        """
        stop cluster manager service
        """
        await self._server.stop(1)

    async def dispatch(
        self, resource: dict = None
    ) -> Tuple[Optional["_WorkerDescription"], list]:
        """
        pick a worker with free capacity (and enough endpoints when asked)

        Returns:
            (worker, endpoints) — (None, []) when nothing fits
        """
        if resource is None:
            resource = {}
        if not resource:
            for k, v in self._alive_workers.items():
                if v.has_task_capacity():
                    v.task_task_capacity()
                    return v, []
        elif "endpoints" in resource:
            num_endpoints = resource["endpoints"]
            for k, v in self._alive_workers.items():
                if v.has_num_valid_endpoints(num_endpoints) and v.has_task_capacity():
                    v.task_task_capacity()
                    endpoints = v.take_endpoints(num_endpoints)
                    return v, endpoints
        return None, []


class _WorkerDescription(object):
    """Bookkeeping for one enrolled worker: capacity, ports, heartbeat, queue."""

    def __init__(
        self, worker_id, worker_ip, max_tasks, max_delay, port_start, port_end
    ):
        self.worker_id = worker_id
        self.worker_ip = worker_ip
        self._port_start = port_start
        self._port_end = port_end
        self._max_tasks = max_tasks
        self._max_delay = max_delay
        self._last_heartbeat = time.time()
        self._task_queue: asyncio.Queue[job_pb2.Task] = asyncio.Queue()
        self._port_used = [False] * (
            self._port_end - self._port_start
        )  # todo: use memory-friendly data structure
        self.num_port_remind = self._port_end - self._port_start
        self.num_task_remind = self._max_tasks
        self._current_pos = 0

    def has_num_valid_endpoints(self, num):
        return self.num_task_remind >= num

    def has_task_capacity(self):
        return self.num_task_remind > 0

    def task_task_capacity(self):
        # (name kept for compatibility; "take_task_capacity" was meant)
        self.num_task_remind -= 1

    def take_endpoints(self, num):
        """Hand out ``num`` endpoints round-robin over the port range."""
        endpoints = []
        for i in range(num):
            endpoints.append(self._next_valid_endpoint())
        return endpoints

    def _next_valid_endpoint(self):
        # NOTE(review): _port_used is never set to True, so this is a plain
        # round-robin over the range and ports may be re-issued — confirm.
        for i in range(self._port_end - self._port_start):
            index = (i + self._current_pos) % (self._port_end - self._port_start)
            if not self._port_used[index]:
                self._current_pos = index + 1
                return f"{self.worker_ip}:{self._port_start + index}"
        raise Exception("no endpoint left")

    async def put_task(self, task: job_pb2.Task):
        return await self._task_queue.put(task)

    async def update_heartbeat(self):
        t = time.time()
        self._last_heartbeat = t

    def is_asystole(self):
        # True when the last heartbeat is older than the allowed delay.
        t = time.time()
        return t - self._last_heartbeat > self._max_delay

    async def wait_next_task(self, timeout):
        return await asyncio.wait_for(self._task_queue.get(), timeout=timeout)
class _JobStatus(enum.Enum):
    """
    Lifecycle states of a job as reported by the /query REST endpoint.

    The string values are the wire format returned to callers; do not rename
    members or change values — callers compare against these strings.
    """

    NOTFOUND = "not_found"   # job id unknown to this master
    APPLYING = "applying"    # queued on the apply queue, acquiring resources
    INFERING = "infering"    # (sic, kept for wire compat) queued on the infer queue
    WAITING = "waiting"      # accepted via /submit, waiting to be scheduled
    PROPOSAL = "proposal"
    RUNNING = "running"      # tasks dispatched to the cluster
    FAILED = "failed"
    SUCCESS = "success"
+ """ + + member_id = attr.ib(type=str) + job_types = attr.ib(type=List[str], default=["paddle_fl"]) + + def __attrs_post_init__(self): + self.job_status: MutableMapping[str, _JobStatus] = {} + self.cluster_task_queue: asyncio.Queue[job_pb2.Task] = asyncio.Queue() + self.job_queue: asyncio.Queue[Job] = asyncio.Queue() + self.apply_queue: asyncio.Queue[Job] = asyncio.Queue() + self.infer_queue: asyncio.Queue[Job] = asyncio.Queue() + self.job_counter = 0 + + def generate_job_id(self): + """ + generate unique job id + """ + self.job_counter += 1 + return f"{self.member_id}-{datetime.now().strftime('%Y%m%d%H%M%S')}-{self.job_counter}" + + +@attr.s +class ProposalAcceptRule(Logger): + shared_status = attr.ib(type=_SharedStatus) + + async def accept(self, job_type): + # accept all + return job_type in self.shared_status.job_types + + +class RESTService(Logger): + """ + service accept restful request from users + """ + + def __init__(self, shared_status: _SharedStatus, port: int, host: str = None): + """ + init rest services instance + Args: + shared_status: + port: + host: + """ + self.shared_status = shared_status + self.port = port + self.host = host + + self._site: Optional[web.TCPSite] = None + + async def start_rest_site(self): + """ + start web service non-blocked + """ + self.info( + f" ing restful services at {':' if self.host is None else self.host}:{self.port}" + ) + app = web.Application() + app.add_routes(self._register_routes()) + runner = web.AppRunner(app, access_log=self.get_logger()) + await runner.setup() + self._site = web.TCPSite(runner=runner, host=self.host, port=self.port) + await self._site.start() + self.info( + f"restful services started at {':' if self.host is None else self.host}:{self.port}" + ) + + async def stop_rest_site(self): + """ + stop web service + """ + if self._site is not None: + await self._site.stop() + + def _register_routes( + self, route_table: Optional[web.RouteTableDef] = None + ) -> web.RouteTableDef: + """ + register 
routes: + + 1. submitter + 2. query + Args: + route_table: optional provide a `RouteTableDef` instance. + + Returns: + + """ + if route_table is None: + route_table = web.RouteTableDef() + route_table.post("/apply")(self._restful_apply) + route_table.post("/submit")(self._restful_submit) + route_table.post("/query")(self._restful_query) + route_table.post("/infer")(self._restful_infer) + route_table.post("/stop")(self._restful_stop) + route_table.get("/serving_model/download")(self._restful_download) + return route_table + + def web_response(self,code,message,job_id=None): + return web.json_response(data=dict(code=code,message=message,job_id=job_id), status=code) + + async def _restful_download(self,request: web.Request) -> web.Response: + + def query_parse(req): + obj = req.query_string + queryitem = [] + if obj: + query = req.query.items() + for item in query: + queryitem.append(item) + return dict(queryitem) + else: + return None + + async def export_serving_mode(job_id,task_id,config_json,algorithm_config_json): + step = TaskDao(task_id).get_task_progress() + serving_model_path = Path(__logs_dir__).joinpath(f"jobs/{job_id}/serving_model") + local_trainer_indexs = config_json.get("local_trainer_indexs") + weights = Path(__logs_dir__).joinpath(f"jobs/{job_id}/trainer_{local_trainer_indexs[0]}/checkpoint/{step}") + program = algorithm_config_json.get("program") + architecture = algorithm_config_json.get("architecture") + if program == "paddle_detection": + program_full_path = os.path.join(__basedir__, 'algorithm', 'paddle_detection') + config_name = f'{architecture}.yml' + algorithm_config_path = os.path.join(program_full_path, "configs", architecture.split('_')[0],config_name) + + executor = ProcessExecutor(serving_model_path) + executable = sys.executable + cmd = " ".join( + [ + f"{executable} -m visualfl.algorithm.{program}.export_serving_model", + f"--task_id {task_id}", + f"-o weights={weights}", + f"--output_dir {serving_model_path}", + f"-c 
{algorithm_config_path}", + f">{executor.stdout} 2>{executor.stderr}", + ] + ) + returncode, pid = await executor.execute(cmd) + if returncode != 0: + raise Exception("export serving model error") + + query = query_parse(request) + self.debug(f"export serving model request data: {query}") + + job_id = query.get("job_id") + task_id = query.get("task_id") + serving_model_path = Path(__logs_dir__).joinpath(f"jobs/{job_id}/serving_model") + algorithm_config_path = Path(__logs_dir__).joinpath(f"jobs/{job_id}/master/algorithm_config.json") + config_path = Path(__logs_dir__).joinpath(f"jobs/{job_id}/master/config.json") + with open(config_path) as f: + config_json = json.load(f) + with open(algorithm_config_path) as f: + algorithm_config_json = json.load(f) + + await export_serving_mode(job_id,task_id,config_json,algorithm_config_json) + + cfg_name = algorithm_config_json.get("architecture") + zip_file = os.path.join(serving_model_path, f"{cfg_name}.zip") + data_loader.make_zip(os.path.join(serving_model_path, cfg_name),zip_file) + + if os.path.exists(zip_file): + async with aiofiles.open(zip_file, 'rb') as f: + content = await f.read() + if content: + response = web.Response( + content_type='application/octet-stream', + headers={'Content-Disposition': 'attachment;filename={}'.format(zip_file)}, + body=content + ) + return response + + else: + return self.web_response(400, f"read file :{zip_file} error",job_id) + else: + return self.web_response(400, f"file path :{zip_file} not exists",job_id) + + async def _restful_submit(self, request: web.Request) -> web.Response: + """ + handle submit request + Args: + request: + + Returns: + + """ + try: + data = await request.json() + self.debug(f"restful submit request data: {data}") + except json.JSONDecodeError as e: + return self.web_response(400,str(e)) + + try: + job_id = data["job_id"] + task_id = data["task_id"] + role = data["role"] + member_id = data["member_id"] + job_type = data["job_type"] + config = data["env"] + 
data_set = data["data_set"] + download_url = data_set["download_url"] + data_name = data_set["name"] + algorithm_config = data.get("algorithm_config") + program = algorithm_config["program"] + config["max_iter"] = algorithm_config["max_iter"] + algorithm_config["download_url"] = download_url + algorithm_config["data_name"] = data_name + + except Exception: + return self.web_response(400, traceback.format_exc(),job_id) + + + # noinspection PyBroadException + try: + loader = extensions.get_job_class(job_type) + validator = extensions.get_job_schema_validator(job_type) + if loader is None: + raise VisualFLExtensionException(f"job type {job_type} not supported") + # validator.validate(config) + job = loader.load( + job_id=job_id, task_id=task_id,role=role,member_id=member_id,config=config, algorithm_config=algorithm_config + ) + + except Exception: + # self.logger.exception("[submit]catch exception") + reason = traceback.format_exc() + return self.web_response(400, reason,job_id) + + + self.shared_status.job_status[job_id] = _JobStatus.WAITING + await self.shared_status.job_queue.put(job) + + return self.web_response(200, "success", job_id) + + async def _restful_stop(self, request: web.Request) -> web.Response: + """ + handle query request + + Args: + request: + + Returns: + + """ + try: + data = await request.json() + except json.JSONDecodeError as e: + return self.web_response(400, str(e)) + + job_id = data.get("job_id", None) + if job_id is None: + return self.web_response(400, "required `job_id`") + + try: + for line in os.popen(f'ps -ef | grep fl_ | grep {job_id} | grep -v grep').readlines(): + pid = line.split()[1] + subprocess.Popen(f"kill -9 {pid}", shell=True) + except Exception: + Logger.error(f"failed: can't stop job {job_id}") + return self.web_response(400, f"failed: can't stop job {job_id}") + + return self.web_response(200, "stop job success", job_id) + + async def _restful_query(self, request: web.Request) -> web.Response: + """ + handle query request 
+ + Args: + request: + + Returns: + + """ + try: + data = await request.json() + except json.JSONDecodeError as e: + return self.web_response(400, str(e)) + + job_id = data.get("job_id", None) + if job_id is None: + return self.web_response(400, "required `job_id`") + + if job_id not in self.shared_status.job_status: + return self.web_response(404, "job_id not found",job_id) + + return web.json_response( + data=dict(code=200,job_id=job_id,status=str(self.shared_status.job_status[job_id]),message="success"), + ) + + + async def _restful_apply(self, request: web.Request) -> web.Response: + """ + handle apply request + + Args: + request: + + Returns: + + """ + try: + data = await request.json() + self.debug(f"restful apply request data: {data}") + except json.JSONDecodeError as e: + return self.web_response(400, str(e)) + + try: + job_id = data["job_id"] + task_id = data["task_id"] + role = data["role"] + member_id = data["member_id"] + job_type = data["job_type"] + config = data["env"] + callback_url = data["callback_url"] + data_set = data["data_set"] + download_url = data_set["download_url"] + data_name = data_set["name"] + algorithm_config = data.get("algorithm_config") + config["max_iter"] = algorithm_config["max_iter"] + algorithm_config["download_url"] = download_url + algorithm_config["data_name"] = data_name + + except Exception: + return self.web_response(400, traceback.format_exc(),job_id) + + try: + loader = extensions.get_job_class(job_type) + validator = extensions.get_job_schema_validator(job_type) + if loader is None: + raise VisualFLExtensionException(f"job type {job_type} not supported") + # validator.validate(env_config) + job = loader.load( + job_id=job_id, task_id=task_id,role=role, member_id=member_id, config=config, algorithm_config=algorithm_config,callback_url=callback_url + ) + + except Exception: + # self.logger.exception("[submit]catch exception") + return self.web_response(400, traceback.format_exc(),job_id) + + 
self.shared_status.job_status[job_id] = _JobStatus.APPLYING + await self.shared_status.apply_queue.put(job) + + return self.web_response(200, "success", job_id) + + async def _restful_infer(self, request: web.Request) -> web.Response: + """ + handle infer request + + Args: + request: + + Returns: + + """ + try: + data = await request.json() + self.debug(f"restful infer request data: {data}") + except json.JSONDecodeError as e: + return self.web_response(400, str(e)) + + try: + job_id = data["job_id"] + task_id = data["task_id"] + role = data["role"] + member_id = data["member_id"] + job_type = data["job_type"] + config = data["env"] + data_set = data["data_set"] + download_url = data_set["download_url"] + data_name = data_set["name"] + config["download_url"] = download_url + config["data_name"] = data_name + algorithm_config = data.get("algorithm_config") + cur_step = TaskDao(task_id).get_task_progress() + input_dir = os.path.join(__logs_dir__,f"jobs/{job_id}/infer/input") + infer_session_id = data_set.get("infer_session_id", '') + infer_dir = data_loader.job_download(download_url, infer_session_id, input_dir) + data_loader.extractImages(infer_dir) + output_dir = os.path.join(__logs_dir__, f"jobs/{job_id}/infer/output/{os.path.basename(infer_dir)}") + config["cur_step"] = cur_step + config["infer_dir"] = infer_dir + config["output_dir"] = output_dir + + task_result = {"infer_session_id": infer_session_id,"status": "wait_run"} + program = algorithm_config["program"] + componentName = ComponentName.DETECTION if program == "paddle_detection" else ComponentName.CLASSIFY + TaskDao(task_id).save_task_result(task_result, componentName,type=TaskResultType.INFER) + + except Exception as e: + self.exception(f"infer request download and process images error as {e} ") + return self.web_response(400, traceback.format_exc(),job_id) + + try: + loader = extensions.get_job_class(job_type) + if loader is None: + raise VisualFLExtensionException(f"job type {job_type} not supported") 
class ClusterManagerConnect(Logger):
    """
    gRPC client wrapper around the cluster manager service.
    """

    def __init__(self, address, shared_status: _SharedStatus):
        """
        Args:
            address: `host:port` of the cluster manager.
            shared_status: queues/status shared with the master.
        """
        self.address = address
        self.shared_status = shared_status
        self._channel: Optional[grpc.aio.Channel] = None
        self._stub: Optional[cluster_pb2_grpc.ClusterManagerStub] = None

    async def submit_tasks_to_cluster(self):
        """
        Infinite loop: drain the shared cluster task queue and forward each
        task to the cluster manager over gRPC.
        """
        while True:
            task = await self.shared_status.cluster_task_queue.get()
            self.debug(
                f"task sending: task_id={task.task_id} task_type={task.task_type} to cluster"
            )
            await self._stub.TaskSubmit(cluster_pb2.TaskSubmit.REQ(task=task))
            self.debug(
                f"task sent: task_id={task.task_id} task_type={task.task_type} to cluster"
            )

    async def task_resource_require(
        self, request: cluster_pb2.TaskResourceRequire.REQ
    ) -> cluster_pb2.TaskResourceRequire.REP:
        """
        Ask the cluster for task resources (endpoint ports).

        Args:
            request: the resource requirement.

        Returns:
            The cluster's reply (status, worker_id, endpoints).
        """
        response = await self._stub.TaskResourceRequire(request)
        return response

    async def start_cluster_channel(self):
        """
        Open the (insecure) gRPC channel to the cluster manager and build the stub.
        """
        self.info(f"start cluster channel to {self.address}")
        self._channel = grpc.aio.insecure_channel(
            self.address,
            options=[
                ("grpc.max_send_message_length", 512 * 1024 * 1024),
                ("grpc.max_receive_message_length", 512 * 1024 * 1024),
            ],
        )
        self._stub = cluster_pb2_grpc.ClusterManagerStub(self._channel)
        self.info(f"cluster channel started to {self.address}")

    async def cluster_channel_ready(self):
        """
        Await until the underlying channel reports ready.
        """
        return await self._channel.channel_ready()

    async def stop_cluster_channel(self, grace: Optional[float] = None):
        """
        Close the channel, allowing up to `grace` seconds for in-flight RPCs.

        Args:
            grace: optional grace period passed to channel.close().
        """
        self.info(f"stopping cluster channel")
        await self._channel.close(grace)
        # Fixed log text: the original logged "cluster channel started" after
        # closing the channel.
        self.info(f"cluster channel stopped to {self.address}")
+ """ + async def _co_handler(job: Job): + + try: + if job.resource_required is not None: + response = await self._cluster.task_resource_require( + job.resource_required + ) + if response.status != cluster_pb2.TaskResourceRequire.SUCCESS: + raise Exception( + "job failed due to no enough resource" + ) # todo: maybe wait some times and retry? + job.set_required_resource(response) + + await job.compile() + + self.shared_status.job_status[job.job_id] = _JobStatus.RUNNING + for task in job.generate_aggregator_tasks(): + self.debug( + f"send local task: {task.task_id} with task type: {task.task_type} to cluster" + ) + await self.shared_status.cluster_task_queue.put(task) + + self.callback(job) + + except Exception as e: + self.exception(f"run jobs failed: {e}") + TaskDao(job._web_task_id).update_task_status(TaskStatus.ERROR, str(e)) + + + while True: + apply_job = await self.shared_status.apply_queue.get() + asyncio.create_task(_co_handler(apply_job)) + + + async def _submitted_job_handler(self): + """ + handle submitted jobs. + """ + async def _co_handler(job: Job): + + # todo: generalize this process + # stick to paddle fl job now + + try: + + if self.local: + if job.resource_required is not None: + response = await self._cluster.task_resource_require( + job.resource_required + ) + if response.status != cluster_pb2.TaskResourceRequire.SUCCESS: + raise Exception( + "job failed due to no enough resource" + ) # todo: maybe wait some times and retry? 
+ job.set_required_resource(response) + + # compile job + await job.compile() + + self.shared_status.job_status[job.job_id] = _JobStatus.RUNNING + for task in job.generate_aggregator_tasks(): + self.debug( + f"send aggregator task: {task.task_id} with task type: {task.task_type} to cluster" + ) + await self.shared_status.cluster_task_queue.put(task) + + await asyncio.sleep(5) + else: + await job.compile() + + for task in job.generate_trainer_tasks(): + self.debug( + f"send trainer task: {task.task_id} with task type: {task.task_type} to cluster" + ) + await self.shared_status.cluster_task_queue.put(task) + + except Exception as e: + self.exception(f"run jobs failed: {e}") + TaskDao(job._web_task_id).update_task_status(status=TaskStatus.ERROR, message=str(e)) + + while True: + submitted_job = await self.shared_status.job_queue.get() + asyncio.create_task(_co_handler(submitted_job)) + + async def start(self): + """ + start master: + + 1. cluster manager to process tasks + 2. restful service to handler request from user + 3. 
    async def stop(self):
        """
        Stop the master: shut the REST site down first so no new jobs arrive,
        then close the gRPC channel to the cluster manager with a 1s grace
        period for in-flight RPCs.
        """
        # await self._coordinator.stop_coordinator_channel(grace=1)
        await self._rest_site.stop_rest_site()
        await self._cluster.stop_cluster_channel(grace=1)
class Executor(object):
    """
    Base class for command executors bound to a working directory.

    Note: `execute` is decorated abstract but the class does not inherit
    `abc.ABC`, so instantiation is not enforced-abstract (kept as-is).
    """

    def __init__(self, working_dir: Path):
        # Create the working directory eagerly so subclasses can redirect
        # stdout/stderr files into it right away.
        self._working_dir = working_dir
        self._working_dir.mkdir(parents=True, exist_ok=True)

    @abc.abstractmethod
    async def execute(self, cmd) -> int:
        """Run *cmd*; concrete subclasses decide how."""
        ...

    @property
    def working_dir(self):
        """The directory commands run in (created on construction)."""
        return self._working_dir

    @property
    def stdout(self):
        # Relative file name commands redirect standard output into.
        return "stdout"

    @property
    def stderr(self):
        # Relative file name commands redirect standard error into.
        return "stderr"
class Task(metaclass=abc.ABCMeta):
    """
    Abstract unit of work executed by a worker on behalf of a job.
    """

    # Concrete task flavour identifier; set by subclasses.
    task_type: str

    def __init__(self, job_id, task_id, web_task_id):
        """
        Args:
            job_id: id of the owning job.
            task_id: id of this task within the job.
            web_task_id: id of the corresponding board/web task.
        """
        self.job_id = job_id
        self.task_id = task_id
        self.web_task_id = web_task_id

    @abc.abstractmethod
    async def exec(self, executor: Executor) -> int:
        """Run this task with *executor*; return the exit status."""
        ...

    def __str__(self):
        """Render the class name plus all instance attributes, for logs."""
        return f"{self.__class__.__name__}[{self.__dict__}]"

    @classmethod
    @abc.abstractmethod
    def deserialize(cls, pb: job_pb2.Task) -> "Task":
        """Rebuild a concrete Task from its protobuf description."""
        ...
class ProcessExecutor(Executor, Logger):
    """
    Executor that runs shell commands as subprocesses inside the working
    directory, optionally exporting a data base directory to children.
    """

    def __init__(self, working_dir: Path, data_dir=None):
        """
        Args:
            working_dir: directory the command runs in (created if missing).
            data_dir: optional value exported as VISUALFL_DATA_BASE_ENV.
        """
        super().__init__(working_dir)
        self._data_dir = data_dir

    def _subprocess_env(self):
        # Inherit the parent environment, optionally pointing data loaders at
        # an alternative base directory.
        env = os.environ.copy()
        if self._data_dir is not None:
            env[VISUALFL_DATA_BASE_ENV] = self._data_dir
        return env

    async def execute(self, cmd) -> "Optional[tuple]":
        """
        Run *cmd* in a shell asynchronously.

        Returns:
            (returncode, pid) of the finished process, or (-1, -1) when the
            subprocess could not be started/awaited. Fixed: the original
            returned None on exception, which crashed every caller that
            unpacks `returncode, pid = await executor.execute(cmd)`; the
            sentinel is non-zero so `returncode != 0` checks treat it as a
            failure.
        """
        self.info(f"execute cmd {cmd} at {self.working_dir}")
        try:
            sub = await asyncio.subprocess.create_subprocess_shell(
                cmd, shell=True, cwd=self.working_dir, env=self._subprocess_env()
            )
            await sub.communicate()
            return sub.returncode, sub.pid
        except Exception as e:
            self.error(e)
            return -1, -1

    def syncexecute(self, cmd) -> "Optional[tuple]":
        """
        Blocking variant of `execute`; same (returncode, pid) contract,
        including the (-1, -1) failure sentinel.
        """
        self.info(f"execute cmd {cmd} at {self.working_dir}")
        try:
            p = subprocess.Popen(
                cmd, shell=True, cwd=self.working_dir, env=self._subprocess_env()
            )
            p.communicate()
            return p.returncode, p.pid
        except Exception as e:
            self.error(e)
            return -1, -1
import json
import os.path
import sys
from pathlib import Path
from typing import List

from visualfl.paddle_fl.abs.job import Job
from visualfl.paddle_fl.executor import ProcessExecutor
from visualfl.protobuf import job_pb2, cluster_pb2
from visualfl.utils.exception import VisualFLJobCompileException
from visualfl.protobuf import fl_job_pb2
from visualfl import __basedir__, __logs_dir__

JOB_TYPE = "paddle_fl"


class PaddleFLJob(Job):
    """A PaddleFL federated job.

    Carries everything needed to compile the fluid programs on the master,
    spawn the aggregator, and generate per-trainer task protobufs.  A job is
    built either for training (``_init_fl_job``) or inference
    (``_init_infer_job``) via :meth:`load`.
    """

    job_type = JOB_TYPE

    @classmethod
    def load(cls, job_id, task_id, role, member_id, config, algorithm_config,
             callback_url=None, is_infer=False) -> "PaddleFLJob":
        """Build a job from raw config dicts.

        Args:
            job_id / task_id: identifiers assigned by the board.
            role / member_id: federation identity of this member.
            config: runtime config (worker counts, endpoints, ...).
            algorithm_config: algorithm selection and hyper-parameters.
            callback_url: optional URL notified about job progress.
            is_infer: build an inference job instead of a training job.
        """
        job = PaddleFLJob(job_id=job_id, task_id=task_id, role=role, member_id=member_id)
        if is_infer:
            job._init_infer_job(config, algorithm_config, callback_url)
        else:
            job._init_fl_job(config, algorithm_config, callback_url)
        return job

    def __init__(self, job_id, task_id, role, member_id):
        super().__init__(job_id=job_id)
        self._web_task_id = task_id
        self._role = role
        self._member_id = member_id

    def _init_fl_job(self, config, algorithm_config, callback_url=None):
        """Populate training-job attributes from the config dicts."""
        self._worker_num = config["worker_num"]
        self._local_worker_num = config["local_worker_num"]
        self._local_trainer_indexs = config["local_trainer_indexs"]
        self._program = algorithm_config["program"]
        self._trainer_entrypoint = f"visualfl.algorithm.{self._program}.fl_trainer"
        self._config_string = json.dumps(config)
        self._algorithm_config = json.dumps(algorithm_config)
        # Endpoints may be absent here; they can be filled in later through
        # set_required_resource once the cluster allocates them.
        self._server_endpoint = config.get("server_endpoint", None)
        self._aggregator_endpoint = config.get("aggregator_endpoint", None)
        self._aggregator_assignee = config.get("aggregator_assignee", None)
        self._callback_url = callback_url

    def _init_infer_job(self, config, algorithm_config, callback_url=None):
        """Populate inference-job attributes from the config dicts."""
        self._local_trainer_indexs = config["local_trainer_indexs"]
        self._program = algorithm_config["program"]
        self._callback_url = callback_url
        self._use_gpu = config["device"].lower() == "gpu"
        self._output_dir = config["output_dir"]
        self._infer_dir = config["infer_dir"]
        cur_step = config["cur_step"]
        # Weights come from the checkpoint written by the first local trainer.
        self._weights = Path(__logs_dir__).joinpath(
            f"jobs/{self.job_id}/trainer_{self._local_trainer_indexs[0]}/checkpoint/{cur_step}")
        self._algorithm_config_path = Path(__logs_dir__).joinpath(
            f"jobs/{self.job_id}/master/algorithm_config.json")

        architecture = algorithm_config["architecture"]
        if self._program == "paddle_detection":
            # Detection models ship their own YAML configs keyed by the
            # architecture family (text before the first underscore).
            program_full_path = os.path.join(__basedir__, 'algorithm', 'paddle_detection')
            config_name = f'{architecture}.yml'
            self._algorithm_config_path = os.path.join(
                program_full_path, "configs", architecture.split('_')[0], config_name)

    @property
    def resource_required(self):
        """Two endpoints are needed: parameter server + aggregator."""
        return cluster_pb2.TaskResourceRequire.REQ(num_endpoints=2)

    # noinspection PyAttributeOutsideInit
    def set_required_resource(self, response):
        """Record the endpoints/worker granted by the cluster manager."""
        self._server_endpoint = response.endpoints[0]
        self._aggregator_endpoint = response.endpoints[1]
        self._aggregator_assignee = response.worker_id

    @property
    def compile_path(self):
        """Directory where fl_master writes the compiled programs."""
        return Path(__logs_dir__).joinpath(f"jobs/{self.job_id}/master")

    @property
    def infer_path(self):
        """Working directory for the inference process."""
        return Path(__logs_dir__).joinpath(f"jobs/{self.job_id}/infer")

    async def compile(self):
        """Run fl_master to compile server/trainer programs.

        Raises:
            VisualFLJobCompileException: when the master process exits non-zero.
        """
        executor = ProcessExecutor(self.compile_path)
        with self.compile_path.joinpath("algorithm_config.json").open("w") as f:
            f.write(self._algorithm_config)
        with self.compile_path.joinpath("config.json").open("w") as f:
            f.write(self._config_string)
        executable = sys.executable
        cmd = " ".join(
            [
                f"{executable} -m visualfl.algorithm.{self._program}.fl_master",
                f"--ps-endpoint {self._server_endpoint}",
                f"--algorithm-config algorithm_config.json",
                f"--config config.json",
                f">{executor.stdout} 2>{executor.stderr}",
            ]
        )
        returncode, pid = await executor.execute(cmd)
        if returncode != 0:
            raise VisualFLJobCompileException("compile error")

    async def infer(self):
        """Run the algorithm's inference entry point on saved weights.

        Raises:
            VisualFLJobCompileException: when the inference process exits non-zero.
        """
        executor = ProcessExecutor(self.infer_path)
        executable = sys.executable
        cmd = " ".join(
            [
                f"{executable} -m visualfl.algorithm.{self._program}.infer",
                f"--task_id {self._web_task_id}",
                f"--use_gpu {self._use_gpu}",
                f"--weights {self._weights}",
                f"--infer_dir {self._infer_dir}",
                f"--output_dir {self._output_dir}",
                f"-c {self._algorithm_config_path}",
                f">{executor.stdout} 2>{executor.stderr}",
            ]
        )
        returncode, pid = await executor.execute(cmd)
        if returncode != 0:
            raise VisualFLJobCompileException("infer error")

    def generate_trainer_tasks(self) -> List[job_pb2.Task]:
        """One task protobuf per locally hosted trainer index."""
        return [self._generate_trainer_task_pb(v) for v in self._local_trainer_indexs]

    def generate_aggregator_tasks(self) -> List[job_pb2.Task]:
        return [
            self._generate_aggregator_task_pb(),
        ]

    def _generate_trainer_task_pb(self, i):
        """Pack trainer ``i``'s compiled programs into a job_pb2.Task."""
        task_pb = job_pb2.Task(
            job_id=self.job_id,
            task_id=f"trainer_{i}",
            web_task_id=self._web_task_id,
            task_type="fl_trainer")

        trainer_pb = fl_job_pb2.PaddleFLWorkerTask(
            scheduler_ep=self._aggregator_endpoint,
            trainer_id=i,
            trainer_ep=f"trainer_{i}",
            entrypoint=self._trainer_entrypoint,
            main_program=_load_program_bytes(
                self.compile_path.joinpath(f"compile/trainer{i}/trainer.main.program")
            ),
            startup_program=_load_program_bytes(
                self.compile_path.joinpath(
                    f"compile/trainer{i}/trainer.startup.program"
                )
            ),
            send_program=_load_program_bytes(
                self.compile_path.joinpath(f"compile/trainer{i}/trainer.send.program")
            ),
            recv_program=_load_program_bytes(
                self.compile_path.joinpath(f"compile/trainer{i}/trainer.recv.program")
            ),
            feed_names=_load_program_bytes(
                self.compile_path.joinpath(f"compile/trainer{i}/feed_names")
            ),
            target_names=_load_program_bytes(
                self.compile_path.joinpath(f"compile/trainer{i}/target_names")
            ),
            strategy=_load_program_bytes(
                self.compile_path.joinpath(f"compile/trainer{i}/strategy.pkl")
            ),
            # BUG FIX: previously read from the literal "trainer{1}" directory,
            # so every trainer received trainer1's feeds.pkl.
            feeds=_load_program_bytes(
                self.compile_path.joinpath(f"compile/trainer{i}/feeds.pkl")
            ),
            config_string=self._config_string,
            algorithm_config_string=self._algorithm_config,
        )
        task_pb.task.Pack(trainer_pb)
        return task_pb

    def _generate_aggregator_task_pb(self):
        """Pack the server programs into the aggregator's job_pb2.Task."""
        scheduler_pb = fl_job_pb2.PaddleFLAggregatorTask(
            scheduler_ep=self._aggregator_endpoint,
        )
        scheduler_pb.main_program = _load_program_bytes(
            self.compile_path.joinpath(f"compile/server0/server.main.program")
        )
        scheduler_pb.startup_program = _load_program_bytes(
            self.compile_path.joinpath(f"compile/server0/server.startup.program")
        )
        scheduler_pb.config_string = self._config_string

        task_pb = job_pb2.Task(
            job_id=self.job_id,
            web_task_id=self._web_task_id,
            task_id=f"aggregator",
            task_type="fl_aggregator",
            assignee=self._aggregator_assignee,
        )
        task_pb.task.Pack(scheduler_pb)
        return task_pb


def _load_program_bytes(path: Path):
    """Read a compiled-program artifact as raw bytes."""
    with path.open("rb") as f:
        return f.read()
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/VisualFL/visualfl/paddle_fl/scheduler/fl_scheduler.py b/VisualFL/visualfl/paddle_fl/scheduler/fl_scheduler.py new file mode 100644 index 000000000..81e25cfa1 --- /dev/null +++ b/VisualFL/visualfl/paddle_fl/scheduler/fl_scheduler.py @@ -0,0 +1,276 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations +import os +import signal +import asyncio +import json +import random +import sys +from pathlib import Path +import logging + +import click +import grpc + +from visualfl.paddle_fl.executor import ProcessExecutor +from visualfl.utils.exception import VisualFLWorkerException +from visualfl.protobuf import scheduler_pb2_grpc, scheduler_pb2 + + +class Scheduler(scheduler_pb2_grpc.SchedulerServicer): + def __init__( + self, + job_id:str, + scheduler_ep: str, + worker_num: int, + sample_num: int, + max_iter: int, + startup_program, + main_program, + ): + """ + init scheduler + """ + self._scheduler_ep = scheduler_ep + self._worker_num = worker_num + self._sample_num = sample_num + self._max_iter = max_iter + + self._grpc_port = int(self._scheduler_ep.split(":")[-1]) + self._grpc_server = None + + self._inited_workers = {} + self._ready = asyncio.Event() + + self._current_step = 0 + self._candidate = set() + self._wait_next = asyncio.Event() + + self._stop_event = asyncio.Event() + + self._max_delay = 300 + + self._fl_server_watcher = FLServerWatched( + job_id=job_id,main_program=main_program, startup_program=startup_program + ) + + async def start(self): + logging.info(f"starting scheduler gRPC server") + self._grpc_server = grpc.aio.server() + scheduler_pb2_grpc.add_SchedulerServicer_to_server(self, self._grpc_server) + self._grpc_server.add_insecure_port(f"[::]:{self._grpc_port}") + await self._grpc_server.start() + logging.info(f"scheduler gRPC server started at port {self._grpc_port}") + + # start server + await self._fl_server_watcher.start() + + # async def _healthy_watcher(): + # while True: + # await asyncio.sleep(self._max_delay) + # if len(self._inited_workers) == 0: + # self._stop_event.set() + # logging.debug(f"no workers init") + # break + # + # asyncio.create_task(_healthy_watcher()) + + async def stop(self): + logging.info(f"stopping gRPC server gracefully") + await self._grpc_server.stop(1) + logging.info(f"gRPC 
server stopped") + + await self._fl_server_watcher.stop() + + async def wait_for_termination(self): + await self._stop_event.wait() + await asyncio.sleep(2) + + async def Init(self, request, context): + if self._ready.is_set() or request.name in self._inited_workers: + return scheduler_pb2.Init.REP(status=scheduler_pb2.Init.REJECT) + + self._inited_workers[request.name] = 0 + self._check_init_status() + logging.debug(f"init: {request.name}") + return scheduler_pb2.Init.REP(status=scheduler_pb2.Init.INIT) + + def _check_init_status(self): + if len(self._inited_workers) == self._worker_num: + logging.debug(f"init done") + self._select_candidate() + logging.debug(f"selected: {self._candidate}") + self._ready.set() + + def _check_finish_status(self): + if len(self._candidate) == 0: + logging.debug(f"all worker done, {self._current_step}/{self._max_iter}") + if self._max_iter == self._current_step: + self._stop_event.set() + return + + self._current_step += 1 + self._select_candidate() + self._wait_next.set() + self._wait_next = asyncio.Event() + + def _select_candidate(self): + self._candidate.clear() + logging.debug( + f"starting candidate selection from {self._inited_workers}, k={self._sample_num}" + ) + self._candidate.update( + random.sample(list(self._inited_workers.keys()), k=self._sample_num) + ) + logging.debug(f"candidate selected: {self._candidate}") + + async def WorkerJoin(self, request, context): + logging.debug(f"worker joining: {request.name}") + if request.name not in self._inited_workers: + return scheduler_pb2.WorkerJoin.REP(status=scheduler_pb2.WorkerJoin.REJECT) + await self._ready.wait() + + if request.step < self._current_step: + return scheduler_pb2.WorkerJoin.REP(status=scheduler_pb2.WorkerJoin.REJECT) + + if request.step == self._current_step: + if request.name not in self._candidate: + return scheduler_pb2.WorkerJoin.REP( + status=scheduler_pb2.WorkerJoin.NOT_SELECTED + ) + return 
scheduler_pb2.WorkerJoin.REP(status=scheduler_pb2.WorkerJoin.ACCEPT) + + if request.step == self._current_step + 1: + if self._max_iter == self._current_step: + return scheduler_pb2.WorkerJoin.REP( + status=scheduler_pb2.WorkerJoin.REJECT + ) + await self._wait_next.wait() + if request.name not in self._candidate: + return scheduler_pb2.WorkerJoin.REP( + status=scheduler_pb2.WorkerJoin.NOT_SELECTED + ) + return scheduler_pb2.WorkerJoin.REP(status=scheduler_pb2.WorkerJoin.ACCEPT) + + return scheduler_pb2.WorkerJoin.REP(status=scheduler_pb2.WorkerJoin.REJECT) + + async def WorkerFinish(self, request, context): + if request.name not in self._candidate: + return scheduler_pb2.WorkerFinish.REP( + status=scheduler_pb2.WorkerFinish.REJECT + ) + self._candidate.remove(request.name) + self._check_finish_status() + return scheduler_pb2.WorkerFinish.REP(status=scheduler_pb2.WorkerFinish.DONE) + + +class FLServerWatched(object): + """ + use scheduler to start and kill fl_server + """ + + def __init__(self, job_id,main_program, startup_program): + self.job_id = job_id + self._main_program = main_program + self._startup_program = startup_program + self.sub_pid = None + + async def start(self): + executor = ProcessExecutor(Path(".")) + python_executable = sys.executable + cmd = " ".join( + [ + f"{python_executable} -m visualfl.paddle_fl.scheduler.fl_server", + f"--job-id={self.job_id}", + f"--startup-program={self._startup_program}", + f"--main-program={self._main_program}", + f">{executor.stdout} 2>{executor.stderr} &", + ] + ) + returncode,pid = await executor.execute(cmd) + if returncode != 0: + raise VisualFLWorkerException( + f"execute task {cmd} failed, return code: {returncode}" + ) + self.sub_pid = pid + + async def stop(self): + if self.sub_pid is not None: + try: + os.kill(int(self.sub_pid)+1,signal.SIGKILL) + except ProcessLookupError as e: + logging.debug(f"kill {self.sub_pid} ProcessLookupError {e}") + + +@click.command() +@click.option("--job-id", type=str, 
required=True) +@click.option("--scheduler-ep", type=str, required=True) +@click.option( + "--main-program", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--startup-program", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +@click.option( + "--config", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + required=True, +) +def fl_scheduler( + job_id, + scheduler_ep, + startup_program, + main_program, + config, +): + logging.basicConfig( + filename="aggregator.log", + filemode="w", + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%d-%M-%Y %H:%M:%S", + level=logging.DEBUG, + ) + with open(config) as f: + config_dict = json.load(f) + max_iter = config_dict["max_iter"] + worker_num = config_dict["worker_num"] + + loop = asyncio.get_event_loop() + scheduler = Scheduler( + job_id=job_id, + scheduler_ep=scheduler_ep, + worker_num=worker_num, + sample_num=worker_num, + max_iter=max_iter, + startup_program=startup_program, + main_program=main_program, + ) + loop.run_until_complete(scheduler.start()) + + try: + loop.run_until_complete(scheduler.wait_for_termination()) + finally: + loop.run_until_complete(scheduler.stop()) + loop.close() + + +if __name__ == "__main__": + fl_scheduler() diff --git a/VisualFL/visualfl/paddle_fl/scheduler/fl_server.py b/VisualFL/visualfl/paddle_fl/scheduler/fl_server.py new file mode 100644 index 000000000..a8caaf0cf --- /dev/null +++ b/VisualFL/visualfl/paddle_fl/scheduler/fl_server.py @@ -0,0 +1,52 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
from pathlib import Path

import click
from paddle import fluid


@click.command()
@click.option("--job-id", type=str, required=True)
@click.option(
    "--main-program",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
    required=True,
)
@click.option(
    "--startup-program",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
    required=True,
)
def fl_server(
    job_id,
    startup_program,
    main_program,
):
    """Start a PaddleFL parameter server from serialized fluid programs."""

    def _parse_program(path):
        # Programs are the serialized protobufs written out by fl_master.
        return fluid.Program.parse_from_string(Path(path).read_bytes())

    server_startup = _parse_program(startup_program)
    server_main = _parse_program(main_program)

    executor = fluid.Executor(fluid.CPUPlace())
    # Run the startup program first (variable creation/initialization),
    # then the main program, which serves parameter updates.
    executor.run(server_startup)
    executor.run(server_main)


if __name__ == "__main__":
    fl_server()
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/VisualFL/visualfl/paddle_fl/tasks/aggregator.py b/VisualFL/visualfl/paddle_fl/tasks/aggregator.py new file mode 100644 index 000000000..83eb6dc04 --- /dev/null +++ b/VisualFL/visualfl/paddle_fl/tasks/aggregator.py @@ -0,0 +1,85 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import sys

from visualfl.paddle_fl.abs.executor import Executor
from visualfl.paddle_fl.abs.task import Task
from visualfl.protobuf import job_pb2
from visualfl.utils.exception import VisualFLWorkerException
from visualfl.protobuf import fl_job_pb2


class FLAggregator(Task):
    """Task that launches the FedAvg aggregator (fl_scheduler) process."""

    task_type = "fl_aggregator"

    def __init__(
        self,
        job_id,
        task_id,
        web_task_id,
        scheduler_ep,
        main_program,
        startup_program,
        config_string,
    ):
        super().__init__(job_id=job_id, task_id=task_id, web_task_id=web_task_id)
        self._scheduler_ep = scheduler_ep
        self._main_program = main_program
        self._startup_program = startup_program
        self._config_string = config_string

    @classmethod
    def deserialize(cls, pb: job_pb2.Task) -> "FLAggregator":
        """Unpack a job_pb2.Task into an FLAggregator.

        Raises:
            VisualFLWorkerException: when the task_type does not match.
        """
        if pb.task_type != cls.task_type:
            raise VisualFLWorkerException(
                f"try to deserialize task_type {pb.task_type} by {cls.task_type}"
            )
        task = fl_job_pb2.PaddleFLAggregatorTask()
        pb.task.Unpack(task)
        return cls(
            job_id=pb.job_id,
            task_id=pb.task_id,
            web_task_id=pb.web_task_id,
            scheduler_ep=task.scheduler_ep,
            startup_program=task.startup_program,
            main_program=task.main_program,
            config_string=task.config_string,
        )

    async def exec(self, executor: Executor):
        """Write the program/config artifacts, then run the scheduler module.

        Raises:
            VisualFLWorkerException: when the process exits non-zero.
        """
        # Materialize artifacts where the child process expects them
        # (relative to the executor's working directory).
        for filename, payload, mode in (
            ("main_program", self._main_program, "wb"),
            ("startup_program", self._startup_program, "wb"),
            ("config.json", self._config_string, "w"),
        ):
            with executor.working_dir.joinpath(filename).open(mode) as f:
                f.write(payload)

        cmd = " ".join(
            [
                f"{sys.executable} -m visualfl.paddle_fl.scheduler.fl_scheduler",
                f"--job-id={self.job_id}",
                f"--scheduler-ep={self._scheduler_ep}",
                f"--startup-program=startup_program",
                f"--main-program=main_program",
                f"--config=config.json",
                f">{executor.stdout} 2>{executor.stderr}",
            ]
        )
        returncode, pid = await executor.execute(cmd)
        if returncode != 0:
            raise VisualFLWorkerException(
                f"execute task: {self.task_id} failed, return code: {returncode}"
            )
import sys

from visualfl.paddle_fl.abs.executor import Executor
from visualfl.paddle_fl.abs.task import Task
from visualfl.protobuf import job_pb2
from visualfl.utils.exception import VisualFLWorkerException
from visualfl.protobuf import fl_job_pb2


class FLTrainer(Task):
    """Task that launches one PaddleFL trainer process.

    Carries the serialized fluid programs (main/startup/send/recv), feed and
    target metadata, and the strategy pickle, all written to the executor's
    working directory before the entrypoint process is spawned.
    """

    task_type = "fl_trainer"

    def __init__(
        self,
        job_id,
        task_id,
        web_task_id,
        scheduler_ep: str,
        trainer_id: int,
        trainer_ep: str,
        entrypoint,
        startup_program,
        main_program,
        send_program,
        recv_program,
        feed_names,
        target_names,
        strategy,
        feeds,
        config_string,
        algorithm_config_string,
    ):
        super().__init__(job_id=job_id, task_id=task_id, web_task_id=web_task_id)
        self._scheduler_ep = scheduler_ep
        self._trainer_id = trainer_id
        self._trainer_ep = trainer_ep
        self._entrypoint = entrypoint
        self._startup_program = startup_program
        self._main_program = main_program
        self._send_program = send_program
        self._recv_program = recv_program
        self._feed_names = feed_names
        self._target_names = target_names
        self._strategy = strategy
        self._feeds = feeds
        self._config_string = config_string
        self._algorithm_config_string = algorithm_config_string

    @classmethod
    def deserialize(cls, pb: job_pb2.Task) -> "FLTrainer":
        """Unpack a job_pb2.Task into an FLTrainer.

        Raises:
            VisualFLWorkerException: when the task_type does not match.
        """
        if pb.task_type != cls.task_type:
            raise VisualFLWorkerException(
                f"try to deserialize task_type {pb.task_type} by {cls.task_type}"
            )
        worker_task_pb = fl_job_pb2.PaddleFLWorkerTask()
        pb.task.Unpack(worker_task_pb)
        return FLTrainer(
            job_id=pb.job_id,
            task_id=pb.task_id,
            web_task_id=pb.web_task_id,
            scheduler_ep=worker_task_pb.scheduler_ep,
            trainer_id=worker_task_pb.trainer_id,
            trainer_ep=worker_task_pb.trainer_ep,
            entrypoint=worker_task_pb.entrypoint,
            startup_program=worker_task_pb.startup_program,
            main_program=worker_task_pb.main_program,
            send_program=worker_task_pb.send_program,
            recv_program=worker_task_pb.recv_program,
            feed_names=worker_task_pb.feed_names,
            target_names=worker_task_pb.target_names,
            strategy=worker_task_pb.strategy,
            feeds=worker_task_pb.feeds,
            config_string=worker_task_pb.config_string,
            algorithm_config_string=worker_task_pb.algorithm_config_string,
        )

    async def exec(self, executor: Executor):
        """Write the trainer's artifacts and run its entrypoint module.

        Raises:
            VisualFLWorkerException: when spawning failed (returncode None)
                or the process exits non-zero.
        """
        python_executable = sys.executable
        cmd = " ".join(
            [
                f"{python_executable} -m {self._entrypoint}",
                f"--job-id={self.job_id}",
                f"--task-id={self.web_task_id}",
                f"--scheduler-ep={self._scheduler_ep}",
                f"--trainer-id={self._trainer_id}",
                f"--trainer-ep={self._trainer_ep}",
                f"--startup-program=startup_program",
                f"--main-program=main_program",
                f"--send-program=send_program",
                f"--recv-program=recv_program",
                f"--feed-names=feed_names",
                f"--target-names=target_names",
                f"--feeds=feeds",
                f"--strategy=strategy",
                f"--config config.json",
                # BUG FIX: a missing comma here used to implicitly concatenate
                # this element with the redirection string below, gluing
                # ">stdout 2>stderr" onto "algorithm_config.json".
                f"--algorithm-config algorithm_config.json",
                f">{executor.stdout} 2>{executor.stderr}",
            ]
        )
        # Binary artifacts (serialized programs / pickles) vs. text configs.
        for filename, payload in (
            ("main_program", self._main_program),
            ("startup_program", self._startup_program),
            ("send_program", self._send_program),
            ("recv_program", self._recv_program),
            ("feed_names", self._feed_names),
            ("target_names", self._target_names),
            ("strategy", self._strategy),
            ("feeds", self._feeds),
        ):
            with executor.working_dir.joinpath(filename).open("wb") as f:
                f.write(payload)
        with executor.working_dir.joinpath("config.json").open("w") as f:
            f.write(self._config_string)
        with executor.working_dir.joinpath("algorithm_config.json").open("w") as f:
            f.write(self._algorithm_config_string)

        returncode, pid = await executor.execute(cmd)
        if returncode is None or returncode != 0:
            raise VisualFLWorkerException(
                f"execute task: {self.task_id} failed, return code: {returncode}"
            )
import logging
import pickle
from typing import Optional

import grpc
from paddle import fluid

from visualfl.utils.logger import Logger
from visualfl.protobuf import scheduler_pb2_grpc, scheduler_pb2
from paddle_fl.core.master.fl_job import FLJobBase


class TrainerSchedulerAgent(Logger):
    """gRPC client-side agent a trainer uses to talk to the FedAvg scheduler."""

    def __init__(self, worker_name, scheduler_ep):
        self._worker_name = worker_name
        self._scheduler_ep = scheduler_ep

        self._channel: Optional[grpc.Channel] = None
        self._stub: Optional[scheduler_pb2_grpc.SchedulerStub] = None

    def start_channel(self):
        """Open the channel and block until it is ready; returns self."""
        self._channel = grpc.insecure_channel(self._scheduler_ep)
        self._stub = scheduler_pb2_grpc.SchedulerStub(self._channel)

        self.debug(f"waiting channel ready")
        future = grpc.channel_ready_future(self._channel)
        future.result()
        self.debug(f"channel ready")
        return self

    def init_worker(self):
        """Register this worker with the scheduler (Init RPC)."""
        self.debug(f"start to init")
        self._stub.Init(scheduler_pb2.Init.REQ(name=self._worker_name))
        self.debug(f"init success")

    def join(self, step: int):
        """Ask to join round ``step``; True iff the scheduler accepted."""
        self.debug("start to join")
        response = self._stub.WorkerJoin(
            scheduler_pb2.WorkerJoin.REQ(name=self._worker_name, step=step)
        )
        self.debug(f"join success: {response.status}")
        return response.status == scheduler_pb2.WorkerJoin.ACCEPT

    def finish(self):
        """Report this round done; True iff the scheduler acknowledged."""
        self.debug("start to finish")
        response = self._stub.WorkerFinish(
            scheduler_pb2.WorkerFinish.REQ(name=self._worker_name)
        )
        self.debug(f"finish success: {response}")
        # BUG FIX: compare the response's status field (as join() does).
        # The old code compared the whole response message to the enum
        # value, which was always False.
        return response.status == scheduler_pb2.WorkerFinish.DONE

    def close(self):
        self._channel.close()


class FedAvgTrainer(FLJobBase):
    """Trainer-side FedAvg driver: loads compiled programs and runs rounds.

    Usage order matters: call :meth:`load_job` before :meth:`start`, since
    start() executes the startup program loaded by load_job().
    """

    def __init__(self, scheduler_ep, trainer_ep):
        self._logger = logging.getLogger("FLTrainer")
        super(FedAvgTrainer, self).__init__()
        self._scheduler_ep = scheduler_ep
        self._trainer_ep = trainer_ep

        self.scheduler_agent: Optional[TrainerSchedulerAgent] = None
        self.exe: Optional[fluid.Executor] = None
        self.cur_step = 0

    def start(self, place):
        """Connect to the scheduler and run the startup program on ``place``."""
        self.scheduler_agent = TrainerSchedulerAgent(
            scheduler_ep=self._scheduler_ep, worker_name=self._trainer_ep
        )
        self.scheduler_agent.start_channel()
        self.scheduler_agent.init_worker()

        self.exe = fluid.Executor(place)
        self.exe.run(self._startup_program)

    def load_job(
        self,
        startup_program: str,
        main_program: str,
        send_program: str,
        recv_program: str,
        feed_names: str,
        target_names: str,
        strategy: str,
    ):
        """Load all compiled program artifacts from their file paths."""
        self._startup_program = self._load_program(startup_program)
        self._main_program = self._load_program(main_program)
        self._send_program = self._load_program(send_program)
        self._recv_program = self._load_program(recv_program)

        # _inner_step: how many local steps run between send/recv rounds.
        self._step = self._load_strategy(strategy)._inner_step
        self._feed_names = self._load_str_list(feed_names)
        self._target_names = self._load_str_list(target_names)

    def load_feed_list(self, feeds_path):
        """Rebuild the fluid feed variables pickled by the master.

        The file layout is: a pickled count, followed by that many pickled
        kwargs dicts for fluid.data().
        """
        data = []
        with open(feeds_path, "rb") as f:
            num = pickle.load(f)
            for _ in range(num):
                data.append(fluid.data(**pickle.load(f)))
        return data

    @staticmethod
    def _load_strategy(input_file):
        # Use a context manager so the file handle is not leaked.
        with open(input_file, "rb") as f:
            return pickle.load(f)

    def reset(self):
        """Reset the local step counter (e.g. between epochs/runs)."""
        self.cur_step = 0

    def run_with_epoch(self, reader, feeder, fetch, num_epoch):
        """Run ``num_epoch`` epochs locally: recv once, train, send once."""
        self._logger.debug("begin to run recv program")
        self.exe.run(self._recv_program)
        self._logger.debug("recv done")
        for i in range(num_epoch):
            for data in reader():
                acc = self.exe.run(
                    self._main_program, feed=feeder.feed(data), fetch_list=fetch
                )
                print(f"acc: {acc}")
                self.cur_step += 1
        self._logger.debug("begin to run send program")
        self.exe.run(self._send_program)

    def run(self, feed, fetch):
        """Run one training step; sync with the server every _step steps."""
        self._logger.debug(
            f"begin to run FedAvgTrainer, cur_step={self.cur_step}, inner_step={self._step}"
        )
        if self.cur_step % self._step == 0:
            self._logger.debug("run recv program start")
            self.exe.run(self._recv_program)
            self._logger.debug("run recv program done")

        self._logger.debug("run main program start")
        loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
        self._logger.debug("run main program done")

        if self.cur_step % self._step == 0:
            self._logger.debug("run send program start")
            self.exe.run(self._send_program)
            self._logger.debug("run send program done")
        self.cur_step += 1
        return loss

    def save_model(self, model_path):
        """Export the current main program as an inference model."""
        fluid.io.save_inference_model(
            dirname=model_path,
            feeded_var_names=self._feed_names,
            target_vars=[
                self._main_program.global_block().var(fetch_var_name)
                for fetch_var_name in self._target_names
            ],
            executor=self.exe,
            main_program=self._main_program,
        )
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: cluster.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 +import visualfl.protobuf.job_pb2 as job__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='cluster.proto', + package='visualfl', + syntax='proto3', + serialized_options=None, + serialized_pb=b'\n\rcluster.proto\x12\x08visualfl\x1a\x19google/protobuf/any.proto\x1a\tjob.proto\"\x8b\x02\n\x06\x45nroll\x1a\x64\n\x03REQ\x12\x11\n\tworker_id\x18\x01 \x01(\t\x12\x11\n\tworker_ip\x18\x02 \x01(\t\x12\x11\n\tmax_tasks\x18\x03 \x01(\x05\x12\x12\n\nport_start\x18\x04 \x01(\x05\x12\x10\n\x08port_end\x18\x05 \x01(\x05\x1aL\n\x03REP\x12\'\n\x06status\x18\x01 \x01(\x0e\x32\x17.visualfl.Enroll.Status\x12\x1c\n\x04task\x18\x02 \x01(\x0b\x32\x0e.visualfl.Task\"M\n\x06Status\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x12\n\x0e\x45NROLL_SUCCESS\x10\x01\x12\x12\n\x0e\x41LREADY_ENROLL\x10\x02\x12\x0e\n\nTASK_READY\x10\x03\"\x92\x03\n\x0cUpdateStatus\x1a\xc5\x01\n\x03REQ\x12\x11\n\tworker_id\x18\x01 \x01(\t\x12\x0e\n\x06job_id\x18\x02 \x01(\t\x12\x0f\n\x07task_id\x18\x03 \x01(\t\x12\x36\n\x0btask_status\x18\x04 \x01(\x0e\x32!.visualfl.UpdateStatus.TaskStatus\x12\x14\n\x0c\x65xception_id\x18\x05 
\x01(\t\x12\x11\n\texception\x18\x06 \x01(\t\x12)\n\x0b\x65xec_result\x18\x07 \x01(\x0b\x32\x14.google.protobuf.Any\x1a\x34\n\x03REP\x12-\n\x06status\x18\x01 \x01(\x0e\x32\x1d.visualfl.UpdateStatus.Status\"T\n\nTaskStatus\x12\x10\n\x0cTASK_UNKNOWN\x10\x00\x12\x0f\n\x0bTASK_CANCEL\x10\x01\x12\x12\n\x0eTASK_EXCEPTION\x10\x02\x12\x0f\n\x0bTASK_FINISH\x10\x03\".\n\x06Status\x12\x0b\n\x07UNKNOWN\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07SUCCESS\x10\x02\"\x95\x01\n\nTaskSubmit\x1a#\n\x03REQ\x12\x1c\n\x04task\x18\x01 \x01(\x0b\x32\x0e.visualfl.Task\x1a\x32\n\x03REP\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.visualfl.TaskSubmit.Status\".\n\x06Status\x12\x0b\n\x07UNKNOWN\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07SUCCESS\x10\x02\"\xc6\x01\n\x13TaskResourceRequire\x1a\x1c\n\x03REQ\x12\x15\n\rnum_endpoints\x18\x01 \x01(\x05\x1a\x61\n\x03REP\x12\x34\n\x06status\x18\x01 \x01(\x0e\x32$.visualfl.TaskResourceRequire.Status\x12\x11\n\tworker_id\x18\x02 \x01(\t\x12\x11\n\tendpoints\x18\x03 \x03(\t\".\n\x06Status\x12\x0b\n\x07UNKNOWN\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07SUCCESS\x10\x02\x32\xbb\x02\n\x0e\x43lusterManager\x12\x38\n\x06\x45nroll\x12\x14.visualfl.Enroll.REQ\x1a\x14.visualfl.Enroll.REP\"\x00\x30\x01\x12L\n\x10UpdateTaskStatus\x12\x1a.visualfl.UpdateStatus.REQ\x1a\x1a.visualfl.UpdateStatus.REP\"\x00\x12\x42\n\nTaskSubmit\x12\x18.visualfl.TaskSubmit.REQ\x1a\x18.visualfl.TaskSubmit.REP\"\x00\x12]\n\x13TaskResourceRequire\x12!.visualfl.TaskResourceRequire.REQ\x1a!.visualfl.TaskResourceRequire.REP\"\x00\x62\x06proto3' + , + dependencies=[google_dot_protobuf_dot_any__pb2.DESCRIPTOR,job__pb2.DESCRIPTOR,]) + + + +_ENROLL_STATUS = _descriptor.EnumDescriptor( + name='Status', + full_name='visualfl.Enroll.Status', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='UNKNOWN', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='ENROLL_SUCCESS', index=1, 
number=1, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='ALREADY_ENROLL', index=2, number=2, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='TASK_READY', index=3, number=3, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=256, + serialized_end=333, +) +_sym_db.RegisterEnumDescriptor(_ENROLL_STATUS) + +_UPDATESTATUS_TASKSTATUS = _descriptor.EnumDescriptor( + name='TaskStatus', + full_name='visualfl.UpdateStatus.TaskStatus', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='TASK_UNKNOWN', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='TASK_CANCEL', index=1, number=1, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='TASK_EXCEPTION', index=2, number=2, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='TASK_FINISH', index=3, number=3, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=606, + serialized_end=690, +) +_sym_db.RegisterEnumDescriptor(_UPDATESTATUS_TASKSTATUS) + +_UPDATESTATUS_STATUS = _descriptor.EnumDescriptor( + name='Status', + full_name='visualfl.UpdateStatus.Status', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='UNKNOWN', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FAILED', index=1, number=1, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='SUCCESS', index=2, number=2, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=692, + serialized_end=738, +) +_sym_db.RegisterEnumDescriptor(_UPDATESTATUS_STATUS) + +_TASKSUBMIT_STATUS = _descriptor.EnumDescriptor( + name='Status', + 
full_name='visualfl.TaskSubmit.Status', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='UNKNOWN', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FAILED', index=1, number=1, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='SUCCESS', index=2, number=2, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=692, + serialized_end=738, +) +_sym_db.RegisterEnumDescriptor(_TASKSUBMIT_STATUS) + +_TASKRESOURCEREQUIRE_STATUS = _descriptor.EnumDescriptor( + name='Status', + full_name='visualfl.TaskResourceRequire.Status', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='UNKNOWN', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FAILED', index=1, number=1, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='SUCCESS', index=2, number=2, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=692, + serialized_end=738, +) +_sym_db.RegisterEnumDescriptor(_TASKRESOURCEREQUIRE_STATUS) + + +_ENROLL_REQ = _descriptor.Descriptor( + name='REQ', + full_name='visualfl.Enroll.REQ', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='worker_id', full_name='visualfl.Enroll.REQ.worker_id', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='worker_ip', full_name='visualfl.Enroll.REQ.worker_ip', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + 
message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='max_tasks', full_name='visualfl.Enroll.REQ.max_tasks', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='port_start', full_name='visualfl.Enroll.REQ.port_start', index=3, + number=4, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='port_end', full_name='visualfl.Enroll.REQ.port_end', index=4, + number=5, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=76, + serialized_end=176, +) + +_ENROLL_REP = _descriptor.Descriptor( + name='REP', + full_name='visualfl.Enroll.REP', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='visualfl.Enroll.REP.status', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='task', full_name='visualfl.Enroll.REP.task', index=1, + number=2, type=11, cpp_type=10, 
label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=178, + serialized_end=254, +) + +_ENROLL = _descriptor.Descriptor( + name='Enroll', + full_name='visualfl.Enroll', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + ], + extensions=[ + ], + nested_types=[_ENROLL_REQ, _ENROLL_REP, ], + enum_types=[ + _ENROLL_STATUS, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=66, + serialized_end=333, +) + + +_UPDATESTATUS_REQ = _descriptor.Descriptor( + name='REQ', + full_name='visualfl.UpdateStatus.REQ', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='worker_id', full_name='visualfl.UpdateStatus.REQ.worker_id', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='job_id', full_name='visualfl.UpdateStatus.REQ.job_id', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='task_id', full_name='visualfl.UpdateStatus.REQ.task_id', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, 
containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='task_status', full_name='visualfl.UpdateStatus.REQ.task_status', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='exception_id', full_name='visualfl.UpdateStatus.REQ.exception_id', index=4, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='exception', full_name='visualfl.UpdateStatus.REQ.exception', index=5, + number=6, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='exec_result', full_name='visualfl.UpdateStatus.REQ.exec_result', index=6, + number=7, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=353, + serialized_end=550, +) + +_UPDATESTATUS_REP = _descriptor.Descriptor( + name='REP', + full_name='visualfl.UpdateStatus.REP', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='status', 
full_name='visualfl.UpdateStatus.REP.status', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=552, + serialized_end=604, +) + +_UPDATESTATUS = _descriptor.Descriptor( + name='UpdateStatus', + full_name='visualfl.UpdateStatus', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + ], + extensions=[ + ], + nested_types=[_UPDATESTATUS_REQ, _UPDATESTATUS_REP, ], + enum_types=[ + _UPDATESTATUS_TASKSTATUS, + _UPDATESTATUS_STATUS, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=336, + serialized_end=738, +) + + +_TASKSUBMIT_REQ = _descriptor.Descriptor( + name='REQ', + full_name='visualfl.TaskSubmit.REQ', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='task', full_name='visualfl.TaskSubmit.REQ.task', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=755, + serialized_end=790, +) + +_TASKSUBMIT_REP = _descriptor.Descriptor( + name='REP', + full_name='visualfl.TaskSubmit.REP', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='visualfl.TaskSubmit.REP.status', index=0, + 
number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=792, + serialized_end=842, +) + +_TASKSUBMIT = _descriptor.Descriptor( + name='TaskSubmit', + full_name='visualfl.TaskSubmit', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + ], + extensions=[ + ], + nested_types=[_TASKSUBMIT_REQ, _TASKSUBMIT_REP, ], + enum_types=[ + _TASKSUBMIT_STATUS, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=741, + serialized_end=890, +) + + +_TASKRESOURCEREQUIRE_REQ = _descriptor.Descriptor( + name='REQ', + full_name='visualfl.TaskResourceRequire.REQ', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='num_endpoints', full_name='visualfl.TaskResourceRequire.REQ.num_endpoints', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=916, + serialized_end=944, +) + +_TASKRESOURCEREQUIRE_REP = _descriptor.Descriptor( + name='REP', + full_name='visualfl.TaskResourceRequire.REP', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='visualfl.TaskResourceRequire.REP.status', index=0, + number=1, type=14, cpp_type=8, 
label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='worker_id', full_name='visualfl.TaskResourceRequire.REP.worker_id', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='endpoints', full_name='visualfl.TaskResourceRequire.REP.endpoints', index=2, + number=3, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=946, + serialized_end=1043, +) + +_TASKRESOURCEREQUIRE = _descriptor.Descriptor( + name='TaskResourceRequire', + full_name='visualfl.TaskResourceRequire', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + ], + extensions=[ + ], + nested_types=[_TASKRESOURCEREQUIRE_REQ, _TASKRESOURCEREQUIRE_REP, ], + enum_types=[ + _TASKRESOURCEREQUIRE_STATUS, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=893, + serialized_end=1091, +) + +_ENROLL_REQ.containing_type = _ENROLL +_ENROLL_REP.fields_by_name['status'].enum_type = _ENROLL_STATUS +_ENROLL_REP.fields_by_name['task'].message_type = job__pb2._TASK +_ENROLL_REP.containing_type = _ENROLL +_ENROLL_STATUS.containing_type = _ENROLL +_UPDATESTATUS_REQ.fields_by_name['task_status'].enum_type = _UPDATESTATUS_TASKSTATUS 
+_UPDATESTATUS_REQ.fields_by_name['exec_result'].message_type = google_dot_protobuf_dot_any__pb2._ANY +_UPDATESTATUS_REQ.containing_type = _UPDATESTATUS +_UPDATESTATUS_REP.fields_by_name['status'].enum_type = _UPDATESTATUS_STATUS +_UPDATESTATUS_REP.containing_type = _UPDATESTATUS +_UPDATESTATUS_TASKSTATUS.containing_type = _UPDATESTATUS +_UPDATESTATUS_STATUS.containing_type = _UPDATESTATUS +_TASKSUBMIT_REQ.fields_by_name['task'].message_type = job__pb2._TASK +_TASKSUBMIT_REQ.containing_type = _TASKSUBMIT +_TASKSUBMIT_REP.fields_by_name['status'].enum_type = _TASKSUBMIT_STATUS +_TASKSUBMIT_REP.containing_type = _TASKSUBMIT +_TASKSUBMIT_STATUS.containing_type = _TASKSUBMIT +_TASKRESOURCEREQUIRE_REQ.containing_type = _TASKRESOURCEREQUIRE +_TASKRESOURCEREQUIRE_REP.fields_by_name['status'].enum_type = _TASKRESOURCEREQUIRE_STATUS +_TASKRESOURCEREQUIRE_REP.containing_type = _TASKRESOURCEREQUIRE +_TASKRESOURCEREQUIRE_STATUS.containing_type = _TASKRESOURCEREQUIRE +DESCRIPTOR.message_types_by_name['Enroll'] = _ENROLL +DESCRIPTOR.message_types_by_name['UpdateStatus'] = _UPDATESTATUS +DESCRIPTOR.message_types_by_name['TaskSubmit'] = _TASKSUBMIT +DESCRIPTOR.message_types_by_name['TaskResourceRequire'] = _TASKRESOURCEREQUIRE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Enroll = _reflection.GeneratedProtocolMessageType('Enroll', (_message.Message,), { + + 'REQ' : _reflection.GeneratedProtocolMessageType('REQ', (_message.Message,), { + 'DESCRIPTOR' : _ENROLL_REQ, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.Enroll.REQ) + }) + , + + 'REP' : _reflection.GeneratedProtocolMessageType('REP', (_message.Message,), { + 'DESCRIPTOR' : _ENROLL_REP, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.Enroll.REP) + }) + , + 'DESCRIPTOR' : _ENROLL, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.Enroll) + }) +_sym_db.RegisterMessage(Enroll) +_sym_db.RegisterMessage(Enroll.REQ) 
+_sym_db.RegisterMessage(Enroll.REP) + +UpdateStatus = _reflection.GeneratedProtocolMessageType('UpdateStatus', (_message.Message,), { + + 'REQ' : _reflection.GeneratedProtocolMessageType('REQ', (_message.Message,), { + 'DESCRIPTOR' : _UPDATESTATUS_REQ, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.UpdateStatus.REQ) + }) + , + + 'REP' : _reflection.GeneratedProtocolMessageType('REP', (_message.Message,), { + 'DESCRIPTOR' : _UPDATESTATUS_REP, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.UpdateStatus.REP) + }) + , + 'DESCRIPTOR' : _UPDATESTATUS, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.UpdateStatus) + }) +_sym_db.RegisterMessage(UpdateStatus) +_sym_db.RegisterMessage(UpdateStatus.REQ) +_sym_db.RegisterMessage(UpdateStatus.REP) + +TaskSubmit = _reflection.GeneratedProtocolMessageType('TaskSubmit', (_message.Message,), { + + 'REQ' : _reflection.GeneratedProtocolMessageType('REQ', (_message.Message,), { + 'DESCRIPTOR' : _TASKSUBMIT_REQ, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.TaskSubmit.REQ) + }) + , + + 'REP' : _reflection.GeneratedProtocolMessageType('REP', (_message.Message,), { + 'DESCRIPTOR' : _TASKSUBMIT_REP, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.TaskSubmit.REP) + }) + , + 'DESCRIPTOR' : _TASKSUBMIT, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.TaskSubmit) + }) +_sym_db.RegisterMessage(TaskSubmit) +_sym_db.RegisterMessage(TaskSubmit.REQ) +_sym_db.RegisterMessage(TaskSubmit.REP) + +TaskResourceRequire = _reflection.GeneratedProtocolMessageType('TaskResourceRequire', (_message.Message,), { + + 'REQ' : _reflection.GeneratedProtocolMessageType('REQ', (_message.Message,), { + 'DESCRIPTOR' : _TASKRESOURCEREQUIRE_REQ, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.TaskResourceRequire.REQ) + }) + , + + 'REP' : 
_reflection.GeneratedProtocolMessageType('REP', (_message.Message,), { + 'DESCRIPTOR' : _TASKRESOURCEREQUIRE_REP, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.TaskResourceRequire.REP) + }) + , + 'DESCRIPTOR' : _TASKRESOURCEREQUIRE, + '__module__' : 'cluster_pb2' + # @@protoc_insertion_point(class_scope:visualfl.TaskResourceRequire) + }) +_sym_db.RegisterMessage(TaskResourceRequire) +_sym_db.RegisterMessage(TaskResourceRequire.REQ) +_sym_db.RegisterMessage(TaskResourceRequire.REP) + + + +_CLUSTERMANAGER = _descriptor.ServiceDescriptor( + name='ClusterManager', + full_name='visualfl.ClusterManager', + file=DESCRIPTOR, + index=0, + serialized_options=None, + serialized_start=1094, + serialized_end=1409, + methods=[ + _descriptor.MethodDescriptor( + name='Enroll', + full_name='visualfl.ClusterManager.Enroll', + index=0, + containing_service=None, + input_type=_ENROLL_REQ, + output_type=_ENROLL_REP, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='UpdateTaskStatus', + full_name='visualfl.ClusterManager.UpdateTaskStatus', + index=1, + containing_service=None, + input_type=_UPDATESTATUS_REQ, + output_type=_UPDATESTATUS_REP, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='TaskSubmit', + full_name='visualfl.ClusterManager.TaskSubmit', + index=2, + containing_service=None, + input_type=_TASKSUBMIT_REQ, + output_type=_TASKSUBMIT_REP, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='TaskResourceRequire', + full_name='visualfl.ClusterManager.TaskResourceRequire', + index=3, + containing_service=None, + input_type=_TASKRESOURCEREQUIRE_REQ, + output_type=_TASKRESOURCEREQUIRE_REP, + serialized_options=None, + ), +]) +_sym_db.RegisterServiceDescriptor(_CLUSTERMANAGER) + +DESCRIPTOR.services_by_name['ClusterManager'] = _CLUSTERMANAGER + +# @@protoc_insertion_point(module_scope) diff --git a/VisualFL/visualfl/protobuf/cluster_pb2_grpc.py 
b/VisualFL/visualfl/protobuf/cluster_pb2_grpc.py new file mode 100644 index 000000000..ee58733b3 --- /dev/null +++ b/VisualFL/visualfl/protobuf/cluster_pb2_grpc.py @@ -0,0 +1,180 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + +import visualfl.protobuf.cluster_pb2 as cluster__pb2 + + +class ClusterManagerStub(object): + """service in cluster manager called by worker + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Enroll = channel.unary_stream( + '/visualfl.ClusterManager/Enroll', + request_serializer=cluster__pb2.Enroll.REQ.SerializeToString, + response_deserializer=cluster__pb2.Enroll.REP.FromString, + ) + self.UpdateTaskStatus = channel.unary_unary( + '/visualfl.ClusterManager/UpdateTaskStatus', + request_serializer=cluster__pb2.UpdateStatus.REQ.SerializeToString, + response_deserializer=cluster__pb2.UpdateStatus.REP.FromString, + ) + self.TaskSubmit = channel.unary_unary( + '/visualfl.ClusterManager/TaskSubmit', + request_serializer=cluster__pb2.TaskSubmit.REQ.SerializeToString, + response_deserializer=cluster__pb2.TaskSubmit.REP.FromString, + ) + self.TaskResourceRequire = channel.unary_unary( + '/visualfl.ClusterManager/TaskResourceRequire', + request_serializer=cluster__pb2.TaskResourceRequire.REQ.SerializeToString, + response_deserializer=cluster__pb2.TaskResourceRequire.REP.FromString, + ) + + +class ClusterManagerServicer(object): + """service in cluster manager called by worker + """ + + def Enroll(self, request, context): + """service for worker: enroll and fetch tasks + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def UpdateTaskStatus(self, request, context): + """service for worker: update status or heartbeat + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TaskSubmit(self, request, context): + """service for master: submit task to cluster + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TaskResourceRequire(self, request, context): + """Missing associated documentation comment in .proto file""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise 
NotImplementedError('Method not implemented!') + + +def add_ClusterManagerServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Enroll': grpc.unary_stream_rpc_method_handler( + servicer.Enroll, + request_deserializer=cluster__pb2.Enroll.REQ.FromString, + response_serializer=cluster__pb2.Enroll.REP.SerializeToString, + ), + 'UpdateTaskStatus': grpc.unary_unary_rpc_method_handler( + servicer.UpdateTaskStatus, + request_deserializer=cluster__pb2.UpdateStatus.REQ.FromString, + response_serializer=cluster__pb2.UpdateStatus.REP.SerializeToString, + ), + 'TaskSubmit': grpc.unary_unary_rpc_method_handler( + servicer.TaskSubmit, + request_deserializer=cluster__pb2.TaskSubmit.REQ.FromString, + response_serializer=cluster__pb2.TaskSubmit.REP.SerializeToString, + ), + 'TaskResourceRequire': grpc.unary_unary_rpc_method_handler( + servicer.TaskResourceRequire, + request_deserializer=cluster__pb2.TaskResourceRequire.REQ.FromString, + response_serializer=cluster__pb2.TaskResourceRequire.REP.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'visualfl.ClusterManager', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class ClusterManager(object): + """service in cluster manager called by worker + """ + + @staticmethod + def Enroll(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/visualfl.ClusterManager/Enroll', + cluster__pb2.Enroll.REQ.SerializeToString, + cluster__pb2.Enroll.REP.FromString, + options, channel_credentials, + call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdateTaskStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/visualfl.ClusterManager/UpdateTaskStatus', + cluster__pb2.UpdateStatus.REQ.SerializeToString, + cluster__pb2.UpdateStatus.REP.FromString, + options, channel_credentials, + call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TaskSubmit(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/visualfl.ClusterManager/TaskSubmit', + cluster__pb2.TaskSubmit.REQ.SerializeToString, + cluster__pb2.TaskSubmit.REP.FromString, + options, channel_credentials, + call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TaskResourceRequire(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/visualfl.ClusterManager/TaskResourceRequire', + cluster__pb2.TaskResourceRequire.REQ.SerializeToString, + cluster__pb2.TaskResourceRequire.REP.FromString, + options, channel_credentials, + 
call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/VisualFL/visualfl/protobuf/fl_job_pb2.py b/VisualFL/visualfl/protobuf/fl_job_pb2.py new file mode 100644 index 000000000..fdff048a2 --- /dev/null +++ b/VisualFL/visualfl/protobuf/fl_job_pb2.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: fl_job.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='fl_job.proto', + package='visualfl', + syntax='proto3', + serialized_options=None, + serialized_pb=b'\n\x0c\x66l_job.proto\x12\x08visualfl\"t\n\x16PaddleFLAggregatorTask\x12\x14\n\x0cscheduler_ep\x18\x01 \x01(\t\x12\x14\n\x0cmain_program\x18\x02 \x01(\x0c\x12\x17\n\x0fstartup_program\x18\x03 \x01(\x0c\x12\x15\n\rconfig_string\x18\x04 \x01(\t\"\xc4\x02\n\x12PaddleFLWorkerTask\x12\x14\n\x0cscheduler_ep\x18\x01 \x01(\t\x12\x12\n\ntrainer_id\x18\x02 \x01(\r\x12\x12\n\ntrainer_ep\x18\x03 \x01(\t\x12\x12\n\nentrypoint\x18\x04 \x01(\t\x12\x14\n\x0cmain_program\x18\x05 \x01(\x0c\x12\x17\n\x0fstartup_program\x18\x06 
\x01(\x0c\x12\x14\n\x0csend_program\x18\x07 \x01(\x0c\x12\x14\n\x0crecv_program\x18\x08 \x01(\x0c\x12\x12\n\nfeed_names\x18\t \x01(\x0c\x12\x14\n\x0ctarget_names\x18\n \x01(\x0c\x12\x10\n\x08strategy\x18\x0b \x01(\x0c\x12\r\n\x05\x66\x65\x65\x64s\x18\x0c \x01(\x0c\x12\x15\n\rconfig_string\x18\r \x01(\t\x12\x1f\n\x17\x61lgorithm_config_string\x18\x0e \x01(\tb\x06proto3' +) + + + + +_PADDLEFLAGGREGATORTASK = _descriptor.Descriptor( + name='PaddleFLAggregatorTask', + full_name='visualfl.PaddleFLAggregatorTask', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='scheduler_ep', full_name='visualfl.PaddleFLAggregatorTask.scheduler_ep', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='main_program', full_name='visualfl.PaddleFLAggregatorTask.main_program', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='startup_program', full_name='visualfl.PaddleFLAggregatorTask.startup_program', index=2, + number=3, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='config_string', full_name='visualfl.PaddleFLAggregatorTask.config_string', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=26, + serialized_end=142, +) + + +_PADDLEFLWORKERTASK = _descriptor.Descriptor( + name='PaddleFLWorkerTask', + full_name='visualfl.PaddleFLWorkerTask', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='scheduler_ep', full_name='visualfl.PaddleFLWorkerTask.scheduler_ep', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='trainer_id', full_name='visualfl.PaddleFLWorkerTask.trainer_id', index=1, + number=2, type=13, cpp_type=3, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='trainer_ep', full_name='visualfl.PaddleFLWorkerTask.trainer_ep', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='entrypoint', full_name='visualfl.PaddleFLWorkerTask.entrypoint', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='main_program', 
full_name='visualfl.PaddleFLWorkerTask.main_program', index=4, + number=5, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='startup_program', full_name='visualfl.PaddleFLWorkerTask.startup_program', index=5, + number=6, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='send_program', full_name='visualfl.PaddleFLWorkerTask.send_program', index=6, + number=7, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='recv_program', full_name='visualfl.PaddleFLWorkerTask.recv_program', index=7, + number=8, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='feed_names', full_name='visualfl.PaddleFLWorkerTask.feed_names', index=8, + number=9, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='target_names', full_name='visualfl.PaddleFLWorkerTask.target_names', index=9, + number=10, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='strategy', full_name='visualfl.PaddleFLWorkerTask.strategy', index=10, + number=11, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='feeds', full_name='visualfl.PaddleFLWorkerTask.feeds', index=11, + number=12, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='config_string', full_name='visualfl.PaddleFLWorkerTask.config_string', index=12, + number=13, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='algorithm_config_string', full_name='visualfl.PaddleFLWorkerTask.algorithm_config_string', index=13, + number=14, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=145, + serialized_end=469, +) + +DESCRIPTOR.message_types_by_name['PaddleFLAggregatorTask'] = _PADDLEFLAGGREGATORTASK +DESCRIPTOR.message_types_by_name['PaddleFLWorkerTask'] = _PADDLEFLWORKERTASK +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +PaddleFLAggregatorTask = 
_reflection.GeneratedProtocolMessageType('PaddleFLAggregatorTask', (_message.Message,), { + 'DESCRIPTOR' : _PADDLEFLAGGREGATORTASK, + '__module__' : 'fl_job_pb2' + # @@protoc_insertion_point(class_scope:visualfl.PaddleFLAggregatorTask) + }) +_sym_db.RegisterMessage(PaddleFLAggregatorTask) + +PaddleFLWorkerTask = _reflection.GeneratedProtocolMessageType('PaddleFLWorkerTask', (_message.Message,), { + 'DESCRIPTOR' : _PADDLEFLWORKERTASK, + '__module__' : 'fl_job_pb2' + # @@protoc_insertion_point(class_scope:visualfl.PaddleFLWorkerTask) + }) +_sym_db.RegisterMessage(PaddleFLWorkerTask) + + +# @@protoc_insertion_point(module_scope) diff --git a/VisualFL/visualfl/protobuf/job_pb2.py b/VisualFL/visualfl/protobuf/job_pb2.py new file mode 100644 index 000000000..1a6683092 --- /dev/null +++ b/VisualFL/visualfl/protobuf/job_pb2.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: job.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='job.proto', + package='visualfl', + syntax='proto3', + serialized_options=None, + serialized_pb=b'\n\tjob.proto\x12\x08visualfl\x1a\x19google/protobuf/any.proto\"\x85\x01\n\x04Task\x12\x0e\n\x06job_id\x18\x01 \x01(\t\x12\x0f\n\x07task_id\x18\x02 \x01(\t\x12\x13\n\x0bweb_task_id\x18\x03 \x01(\t\x12\x11\n\ttask_type\x18\x04 \x01(\t\x12\"\n\x04task\x18\x05 \x01(\x0b\x32\x14.google.protobuf.Any\x12\x10\n\x08\x61ssignee\x18\x06 \x01(\tb\x06proto3' + , + dependencies=[google_dot_protobuf_dot_any__pb2.DESCRIPTOR,]) + + + + +_TASK = _descriptor.Descriptor( + name='Task', + full_name='visualfl.Task', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='job_id', full_name='visualfl.Task.job_id', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='task_id', full_name='visualfl.Task.task_id', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='web_task_id', full_name='visualfl.Task.web_task_id', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, 
default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='task_type', full_name='visualfl.Task.task_type', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='task', full_name='visualfl.Task.task', index=4, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='assignee', full_name='visualfl.Task.assignee', index=5, + number=6, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=51, + serialized_end=184, +) + +_TASK.fields_by_name['task'].message_type = google_dot_protobuf_dot_any__pb2._ANY +DESCRIPTOR.message_types_by_name['Task'] = _TASK +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Task = _reflection.GeneratedProtocolMessageType('Task', (_message.Message,), { + 'DESCRIPTOR' : _TASK, + '__module__' : 'job_pb2' + # @@protoc_insertion_point(class_scope:visualfl.Task) + }) +_sym_db.RegisterMessage(Task) + + +# @@protoc_insertion_point(module_scope) diff --git a/VisualFL/visualfl/protobuf/proto/cluster.proto b/VisualFL/visualfl/protobuf/proto/cluster.proto new file mode 
100755 index 000000000..073ab89aa --- /dev/null +++ b/VisualFL/visualfl/protobuf/proto/cluster.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; +import "google/protobuf/any.proto"; +import "job.proto"; + +package visualfl; + + +//service in cluster manager called by worker +service ClusterManager{ + // service for worker: enroll and fetch tasks + rpc Enroll(Enroll.REQ) returns (stream Enroll.REP) {} + // service for worker: update status or heartbeat + rpc UpdateTaskStatus(UpdateStatus.REQ) returns (UpdateStatus.REP) {} + // service for master: submit task to cluster + rpc TaskSubmit(TaskSubmit.REQ) returns (TaskSubmit.REP) {} + rpc TaskResourceRequire(TaskResourceRequire.REQ) returns (TaskResourceRequire.REP) {} +} + +message Enroll { + enum Status { + UNKNOWN = 0; + ENROLL_SUCCESS = 1; + ALREADY_ENROLL = 2; + TASK_READY = 3; + } + message REQ { + string worker_id = 1; + string worker_ip = 2; + int32 max_tasks = 3; + int32 port_start = 4; + int32 port_end = 5; + } + message REP { + Status status = 1; + visualfl.Task task = 2; + } +} + +message UpdateStatus { + enum TaskStatus { + TASK_UNKNOWN = 0; + TASK_CANCEL = 1; + TASK_EXCEPTION = 2; + TASK_FINISH = 3; + } + enum Status { + UNKNOWN = 0; + FAILED = 1; + SUCCESS = 2; + } + message REQ { + string worker_id = 1; + string job_id = 2; + string task_id = 3; + TaskStatus task_status = 4; + string exception_id = 5; + string exception = 6; + google.protobuf.Any exec_result = 7; + } + message REP { + Status status = 1; + } +} + +message TaskSubmit { + enum Status { + UNKNOWN = 0; + FAILED = 1; + SUCCESS = 2; + } + message REQ { + visualfl.Task task = 1; + } + message REP { + Status status = 1; + } +} + + +message TaskResourceRequire { + enum Status { + UNKNOWN = 0; + FAILED = 1; + SUCCESS = 2; + } + message REQ { + int32 num_endpoints = 1; + } + message REP { + Status status = 1; + string worker_id = 2; + repeated string endpoints = 3; + } +} diff --git a/VisualFL/visualfl/protobuf/proto/fl_job.proto 
b/VisualFL/visualfl/protobuf/proto/fl_job.proto new file mode 100755 index 000000000..3020a50a6 --- /dev/null +++ b/VisualFL/visualfl/protobuf/proto/fl_job.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; +package visualfl; + +message PaddleFLAggregatorTask { + string scheduler_ep = 1; + + bytes main_program = 2; + bytes startup_program = 3; + + string config_string = 4; +} + + +message PaddleFLWorkerTask { + string scheduler_ep = 1; + uint32 trainer_id = 2; + string trainer_ep = 3; + string entrypoint = 4; + + bytes main_program = 5; + bytes startup_program = 6; + bytes send_program = 7; + bytes recv_program = 8; + bytes feed_names = 9; + bytes target_names = 10; + bytes strategy = 11; + bytes feeds = 12; + + string config_string = 13; + string algorithm_config_string = 14; +} + diff --git a/VisualFL/visualfl/protobuf/proto/job.proto b/VisualFL/visualfl/protobuf/proto/job.proto new file mode 100755 index 000000000..8cd646945 --- /dev/null +++ b/VisualFL/visualfl/protobuf/proto/job.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; +import "google/protobuf/any.proto"; + +package visualfl; + +message Task { + string job_id = 1; + string task_id = 2; + string web_task_id = 3; + string task_type = 4; + google.protobuf.Any task = 5; + string assignee = 6; +} + diff --git a/VisualFL/visualfl/protobuf/proto/scheduler.proto b/VisualFL/visualfl/protobuf/proto/scheduler.proto new file mode 100755 index 000000000..635ca47ca --- /dev/null +++ b/VisualFL/visualfl/protobuf/proto/scheduler.proto @@ -0,0 +1,53 @@ +syntax = "proto3"; + +package visualfl; + +service Scheduler { + rpc Init(Init.REQ) returns (Init.REP) {} + rpc WorkerJoin(WorkerJoin.REQ) returns (WorkerJoin.REP) {} + rpc WorkerFinish(WorkerFinish.REQ) returns (WorkerFinish.REP) {} +} + +message Init { + enum Status { + REJECT = 0; + INIT = 1; + } + message REQ { + string name = 1; + } + + message REP { + Status status = 1; + } +} + +message WorkerJoin { + enum Status { + REJECT = 0; + NOT_SELECTED = 1; + ACCEPT = 2; + } + message 
REQ { + string name = 1; + uint32 step = 2; + } + + message REP { + Status status = 1; + } +} + +message WorkerFinish { + enum Status { + REJECT = 0; + DONE = 1; + } + message REQ { + string name = 1; + } + + message REP { + Status status = 1; + } +} diff --git a/VisualFL/visualfl/protobuf/proto_generate.sh b/VisualFL/visualfl/protobuf/proto_generate.sh new file mode 100755 index 000000000..ddf41909f --- /dev/null +++ b/VisualFL/visualfl/protobuf/proto_generate.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + + +BASEDIR=$(dirname "$0") +cd "$BASEDIR" || exit + +PROTO_DIR="proto" +TARGER_DIR="." + +generate() { + python3 -m grpc_tools.protoc -I./$PROTO_DIR --python_out=./$TARGER_DIR --grpc_python_out=./$TARGER_DIR "$1" +} + +generate_all() { + for proto in "$PROTO_DIR"/*.proto; do + echo "protoc: $proto" + generate "$proto" + done +} + +if [ $# -gt 0 ]; then + generate "$1" +else + generate_all +fi diff --git a/VisualFL/visualfl/protobuf/scheduler_pb2.py b/VisualFL/visualfl/protobuf/scheduler_pb2.py new file mode 100644 index 000000000..587eb412f --- /dev/null +++ b/VisualFL/visualfl/protobuf/scheduler_pb2.py @@ -0,0 +1,502 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: scheduler.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='scheduler.proto', + package='visualfl', + syntax='proto3', + serialized_options=None, + serialized_pb=b'\n\x0fscheduler.proto\x12\x08visualfl\"i\n\x04Init\x1a\x13\n\x03REQ\x12\x0c\n\x04name\x18\x01 \x01(\t\x1a,\n\x03REP\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.visualfl.Init.Status\"\x1e\n\x06Status\x12\n\n\x06REJECT\x10\x00\x12\x08\n\x04INIT\x10\x01\"\x97\x01\n\nWorkerJoin\x1a!\n\x03REQ\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04step\x18\x02 \x01(\r\x1a\x32\n\x03REP\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.visualfl.WorkerJoin.Status\"2\n\x06Status\x12\n\n\x06REJECT\x10\x00\x12\x10\n\x0cNOT_SELECTED\x10\x01\x12\n\n\x06\x41\x43\x43\x45PT\x10\x02\"y\n\x0cWorkerFinish\x1a\x13\n\x03REQ\x12\x0c\n\x04name\x18\x01 \x01(\t\x1a\x34\n\x03REP\x12-\n\x06status\x18\x01 \x01(\x0e\x32\x1d.visualfl.WorkerFinish.Status\"\x1e\n\x06Status\x12\n\n\x06REJECT\x10\x00\x12\x08\n\x04\x44ONE\x10\x01\x32\xcb\x01\n\tScheduler\x12\x30\n\x04Init\x12\x12.visualfl.Init.REQ\x1a\x12.visualfl.Init.REP\"\x00\x12\x42\n\nWorkerJoin\x12\x18.visualfl.WorkerJoin.REQ\x1a\x18.visualfl.WorkerJoin.REP\"\x00\x12H\n\x0cWorkerFinish\x12\x1a.visualfl.WorkerFinish.REQ\x1a\x1a.visualfl.WorkerFinish.REP\"\x00\x62\x06proto3' +) + + + +_INIT_STATUS = _descriptor.EnumDescriptor( + name='Status', + full_name='visualfl.Init.Status', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='REJECT', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='INIT', index=1, number=1, + serialized_options=None, + type=None), + ], + containing_type=None, + 
serialized_options=None, + serialized_start=104, + serialized_end=134, +) +_sym_db.RegisterEnumDescriptor(_INIT_STATUS) + +_WORKERJOIN_STATUS = _descriptor.EnumDescriptor( + name='Status', + full_name='visualfl.WorkerJoin.Status', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='REJECT', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='NOT_SELECTED', index=1, number=1, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='ACCEPT', index=2, number=2, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=238, + serialized_end=288, +) +_sym_db.RegisterEnumDescriptor(_WORKERJOIN_STATUS) + +_WORKERFINISH_STATUS = _descriptor.EnumDescriptor( + name='Status', + full_name='visualfl.WorkerFinish.Status', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='REJECT', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DONE', index=1, number=1, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=381, + serialized_end=411, +) +_sym_db.RegisterEnumDescriptor(_WORKERFINISH_STATUS) + + +_INIT_REQ = _descriptor.Descriptor( + name='REQ', + full_name='visualfl.Init.REQ', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='visualfl.Init.REQ.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + 
serialized_start=37, + serialized_end=56, +) + +_INIT_REP = _descriptor.Descriptor( + name='REP', + full_name='visualfl.Init.REP', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='visualfl.Init.REP.status', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=58, + serialized_end=102, +) + +_INIT = _descriptor.Descriptor( + name='Init', + full_name='visualfl.Init', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + ], + extensions=[ + ], + nested_types=[_INIT_REQ, _INIT_REP, ], + enum_types=[ + _INIT_STATUS, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=29, + serialized_end=134, +) + + +_WORKERJOIN_REQ = _descriptor.Descriptor( + name='REQ', + full_name='visualfl.WorkerJoin.REQ', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='visualfl.WorkerJoin.REQ.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='step', full_name='visualfl.WorkerJoin.REQ.step', index=1, + number=2, type=13, cpp_type=3, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, 
file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=151, + serialized_end=184, +) + +_WORKERJOIN_REP = _descriptor.Descriptor( + name='REP', + full_name='visualfl.WorkerJoin.REP', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='visualfl.WorkerJoin.REP.status', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=186, + serialized_end=236, +) + +_WORKERJOIN = _descriptor.Descriptor( + name='WorkerJoin', + full_name='visualfl.WorkerJoin', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + ], + extensions=[ + ], + nested_types=[_WORKERJOIN_REQ, _WORKERJOIN_REP, ], + enum_types=[ + _WORKERJOIN_STATUS, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=137, + serialized_end=288, +) + + +_WORKERFINISH_REQ = _descriptor.Descriptor( + name='REQ', + full_name='visualfl.WorkerFinish.REQ', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='visualfl.WorkerFinish.REQ.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + 
serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=37, + serialized_end=56, +) + +_WORKERFINISH_REP = _descriptor.Descriptor( + name='REP', + full_name='visualfl.WorkerFinish.REP', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='visualfl.WorkerFinish.REP.status', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=327, + serialized_end=379, +) + +_WORKERFINISH = _descriptor.Descriptor( + name='WorkerFinish', + full_name='visualfl.WorkerFinish', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + ], + extensions=[ + ], + nested_types=[_WORKERFINISH_REQ, _WORKERFINISH_REP, ], + enum_types=[ + _WORKERFINISH_STATUS, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=290, + serialized_end=411, +) + +_INIT_REQ.containing_type = _INIT +_INIT_REP.fields_by_name['status'].enum_type = _INIT_STATUS +_INIT_REP.containing_type = _INIT +_INIT_STATUS.containing_type = _INIT +_WORKERJOIN_REQ.containing_type = _WORKERJOIN +_WORKERJOIN_REP.fields_by_name['status'].enum_type = _WORKERJOIN_STATUS +_WORKERJOIN_REP.containing_type = _WORKERJOIN +_WORKERJOIN_STATUS.containing_type = _WORKERJOIN +_WORKERFINISH_REQ.containing_type = _WORKERFINISH +_WORKERFINISH_REP.fields_by_name['status'].enum_type = _WORKERFINISH_STATUS +_WORKERFINISH_REP.containing_type = _WORKERFINISH +_WORKERFINISH_STATUS.containing_type = _WORKERFINISH 
+DESCRIPTOR.message_types_by_name['Init'] = _INIT +DESCRIPTOR.message_types_by_name['WorkerJoin'] = _WORKERJOIN +DESCRIPTOR.message_types_by_name['WorkerFinish'] = _WORKERFINISH +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Init = _reflection.GeneratedProtocolMessageType('Init', (_message.Message,), { + + 'REQ' : _reflection.GeneratedProtocolMessageType('REQ', (_message.Message,), { + 'DESCRIPTOR' : _INIT_REQ, + '__module__' : 'scheduler_pb2' + # @@protoc_insertion_point(class_scope:visualfl.Init.REQ) + }) + , + + 'REP' : _reflection.GeneratedProtocolMessageType('REP', (_message.Message,), { + 'DESCRIPTOR' : _INIT_REP, + '__module__' : 'scheduler_pb2' + # @@protoc_insertion_point(class_scope:visualfl.Init.REP) + }) + , + 'DESCRIPTOR' : _INIT, + '__module__' : 'scheduler_pb2' + # @@protoc_insertion_point(class_scope:visualfl.Init) + }) +_sym_db.RegisterMessage(Init) +_sym_db.RegisterMessage(Init.REQ) +_sym_db.RegisterMessage(Init.REP) + +WorkerJoin = _reflection.GeneratedProtocolMessageType('WorkerJoin', (_message.Message,), { + + 'REQ' : _reflection.GeneratedProtocolMessageType('REQ', (_message.Message,), { + 'DESCRIPTOR' : _WORKERJOIN_REQ, + '__module__' : 'scheduler_pb2' + # @@protoc_insertion_point(class_scope:visualfl.WorkerJoin.REQ) + }) + , + + 'REP' : _reflection.GeneratedProtocolMessageType('REP', (_message.Message,), { + 'DESCRIPTOR' : _WORKERJOIN_REP, + '__module__' : 'scheduler_pb2' + # @@protoc_insertion_point(class_scope:visualfl.WorkerJoin.REP) + }) + , + 'DESCRIPTOR' : _WORKERJOIN, + '__module__' : 'scheduler_pb2' + # @@protoc_insertion_point(class_scope:visualfl.WorkerJoin) + }) +_sym_db.RegisterMessage(WorkerJoin) +_sym_db.RegisterMessage(WorkerJoin.REQ) +_sym_db.RegisterMessage(WorkerJoin.REP) + +WorkerFinish = _reflection.GeneratedProtocolMessageType('WorkerFinish', (_message.Message,), { + + 'REQ' : _reflection.GeneratedProtocolMessageType('REQ', (_message.Message,), { + 'DESCRIPTOR' : _WORKERFINISH_REQ, + '__module__' : 'scheduler_pb2' + # 
@@protoc_insertion_point(class_scope:visualfl.WorkerFinish.REQ) + }) + , + + 'REP' : _reflection.GeneratedProtocolMessageType('REP', (_message.Message,), { + 'DESCRIPTOR' : _WORKERFINISH_REP, + '__module__' : 'scheduler_pb2' + # @@protoc_insertion_point(class_scope:visualfl.WorkerFinish.REP) + }) + , + 'DESCRIPTOR' : _WORKERFINISH, + '__module__' : 'scheduler_pb2' + # @@protoc_insertion_point(class_scope:visualfl.WorkerFinish) + }) +_sym_db.RegisterMessage(WorkerFinish) +_sym_db.RegisterMessage(WorkerFinish.REQ) +_sym_db.RegisterMessage(WorkerFinish.REP) + + + +_SCHEDULER = _descriptor.ServiceDescriptor( + name='Scheduler', + full_name='visualfl.Scheduler', + file=DESCRIPTOR, + index=0, + serialized_options=None, + serialized_start=414, + serialized_end=617, + methods=[ + _descriptor.MethodDescriptor( + name='Init', + full_name='visualfl.Scheduler.Init', + index=0, + containing_service=None, + input_type=_INIT_REQ, + output_type=_INIT_REP, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='WorkerJoin', + full_name='visualfl.Scheduler.WorkerJoin', + index=1, + containing_service=None, + input_type=_WORKERJOIN_REQ, + output_type=_WORKERJOIN_REP, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='WorkerFinish', + full_name='visualfl.Scheduler.WorkerFinish', + index=2, + containing_service=None, + input_type=_WORKERFINISH_REQ, + output_type=_WORKERFINISH_REP, + serialized_options=None, + ), +]) +_sym_db.RegisterServiceDescriptor(_SCHEDULER) + +DESCRIPTOR.services_by_name['Scheduler'] = _SCHEDULER + +# @@protoc_insertion_point(module_scope) diff --git a/VisualFL/visualfl/protobuf/scheduler_pb2_grpc.py b/VisualFL/visualfl/protobuf/scheduler_pb2_grpc.py new file mode 100644 index 000000000..5bf948019 --- /dev/null +++ b/VisualFL/visualfl/protobuf/scheduler_pb2_grpc.py @@ -0,0 +1,142 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
# NOTE(review): this file is machine-generated from scheduler.proto.
# Regenerate with grpcio-tools instead of editing by hand.
import grpc

import visualfl.protobuf.scheduler_pb2 as scheduler__pb2


class SchedulerStub(object):
    """Missing associated documentation comment in .proto file"""
    # Client-side stub: one unary-unary callable per RPC declared on the
    # visualfl.Scheduler service (Init, WorkerJoin, WorkerFinish).

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.Init = channel.unary_unary(
                '/visualfl.Scheduler/Init',
                request_serializer=scheduler__pb2.Init.REQ.SerializeToString,
                response_deserializer=scheduler__pb2.Init.REP.FromString,
                )
        self.WorkerJoin = channel.unary_unary(
                '/visualfl.Scheduler/WorkerJoin',
                request_serializer=scheduler__pb2.WorkerJoin.REQ.SerializeToString,
                response_deserializer=scheduler__pb2.WorkerJoin.REP.FromString,
                )
        self.WorkerFinish = channel.unary_unary(
                '/visualfl.Scheduler/WorkerFinish',
                request_serializer=scheduler__pb2.WorkerFinish.REQ.SerializeToString,
                response_deserializer=scheduler__pb2.WorkerFinish.REP.FromString,
                )


class SchedulerServicer(object):
    """Missing associated documentation comment in .proto file"""
    # Server-side base class: subclass and override the three RPC methods;
    # each default implementation answers UNIMPLEMENTED.

    def Init(self, request, context):
        """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def WorkerJoin(self, request, context):
        """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def WorkerFinish(self, request, context):
        """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


# Registers a SchedulerServicer implementation on a grpc.Server under the
# fully-qualified service name 'visualfl.Scheduler'.
def add_SchedulerServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'Init': grpc.unary_unary_rpc_method_handler(
                    servicer.Init,
                    request_deserializer=scheduler__pb2.Init.REQ.FromString,
                    response_serializer=scheduler__pb2.Init.REP.SerializeToString,
            ),
            'WorkerJoin': grpc.unary_unary_rpc_method_handler(
                    servicer.WorkerJoin,
                    request_deserializer=scheduler__pb2.WorkerJoin.REQ.FromString,
                    response_serializer=scheduler__pb2.WorkerJoin.REP.SerializeToString,
            ),
            'WorkerFinish': grpc.unary_unary_rpc_method_handler(
                    servicer.WorkerFinish,
                    request_deserializer=scheduler__pb2.WorkerFinish.REQ.FromString,
                    response_serializer=scheduler__pb2.WorkerFinish.REP.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'visualfl.Scheduler', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))


 # This class is part of an EXPERIMENTAL API.
class Scheduler(object):
    """Missing associated documentation comment in .proto file"""
    # Stateless convenience wrappers over grpc.experimental.unary_unary:
    # each call opens/uses a channel to `target` for a single RPC.

    @staticmethod
    def Init(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/visualfl.Scheduler/Init',
            scheduler__pb2.Init.REQ.SerializeToString,
            scheduler__pb2.Init.REP.FromString,
            options, channel_credentials,
            call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def WorkerJoin(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/visualfl.Scheduler/WorkerJoin',
            scheduler__pb2.WorkerJoin.REQ.SerializeToString,
            scheduler__pb2.WorkerJoin.REP.FromString,
            options, channel_credentials,
            call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def WorkerFinish(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/visualfl.Scheduler/WorkerFinish',
            scheduler__pb2.WorkerFinish.REQ.SerializeToString,
            scheduler__pb2.WorkerFinish.REP.FromString,
            options, channel_credentials,
            call_credentials, compression, wait_for_ready, timeout, metadata)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/VisualFL/visualfl/utils/conf_utils.py b/VisualFL/visualfl/utils/conf_utils.py new file mode 100644 index 000000000..d50d3bb64 --- /dev/null +++ b/VisualFL/visualfl/utils/conf_utils.py @@ -0,0 +1,83 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import functools
import os


@functools.lru_cache(maxsize=64)
def get_comm_config(key, default=None):
    """
    Read a value from the config.properties file at ``__config_path__``.

    The result is cached per (key, default) pair, so edits to the file made
    after the first lookup are not seen (same semantics as the previous
    cachetools-based cache, now using only the standard library).

    Parameters
    ----------
    key : str
        Property name, e.g. "db.mysql.host".
    default
        Value returned when the file or the key is missing.

    Returns
    -------
    str or the default value.
    """
    # Imported lazily so merely importing this module does not require the
    # visualfl package (and to avoid import cycles with it).
    from visualfl import __config_path__

    if os.path.exists(__config_path__):
        with open(__config_path__, encoding="utf8") as fp:
            for raw in fp:
                line = raw.strip()
                # Skip blank lines and '#' comments.
                if not line or line.startswith("#"):
                    continue
                # BUG FIX: split on the FIRST '=' only.  Values such as the
                # JDBC url in config.properties legally contain '='
                # (...?serverTimezone=GMT%2B8) and were silently truncated by
                # the old split('=')[1].  partition() also avoids the old
                # IndexError on lines that contain no '=' at all.
                name, sep, value = line.partition("=")
                if sep and name.strip() == key:
                    return value.strip()
    return default


@functools.lru_cache(maxsize=64)
def get_env_config(key, default=None):
    """
    Read a configuration value from the process environment.

    Parameters
    ----------
    key : str
        Environment variable name.
    default
        Returned when the variable is unset or empty.
    """
    val = os.environ.get(key)
    return val if val else default


def set_env(key, value):
    """Set the environment variable *key* to *value* for this process."""
    os.environ[key] = value


def str2bool(v):
    """Interpret "true"/"t"/"1" (case-insensitive) as True, anything else as False."""
    return v.lower() in ("true", "t", "1")


# Job / task type identifiers used by the PaddleFL pipeline.
JOB_TYPE_PADDLE = "paddle_fl"
TASK_TYPE_AGG = "fl_aggregator"
TASK_TYPE_TRAINER = "fl_trainer"

# Every role a federation member may play.
ROLES = ["arbiter", "promoter", "provider"]

# ---- keys looked up in the global config.properties ----
COMM_CONF_KEY_MYSQL_HOST = "db.mysql.host"
COMM_CONF_KEY_MYSQL_PORT = "db.mysql.port"
COMM_CONF_KEY_MYSQL_DATABASE = "db.mysql.database"
COMM_CONF_KEY_MYSQL_USERNAME = "db.mysql.username"
COMM_CONF_KEY_MYSQL_PASSWORD = "db.mysql.password"
COMM_CONF_IS_LOCAL = "is_local"

# ---- SQLite configuration key ----
COMM_CONF_DB_SQLITE_DATABASE = "db.sqlite.database"

COMM_CONF_WEFE_JOB_WORK_MODE = "wefe.job.work_mode"


class JobStatus(object):
    """String constants describing the lifecycle of a job."""

    WAIT_RUN = 'wait_run'
    WAIT_STOP = 'wait_stop'
    RUNNING = 'running'
    STOP_ON_RUNNING = 'stop_on_running'
    ERROR_ON_RUNNING = 'error_on_running'
    SUCCESS = 'success'
    WAIT_SUCCESS = 'wait_success'
    TIMEOUT = 'timeout'

    @staticmethod
    def is_finished(status):
        """Return True when *status* denotes a terminal (stopped) job state."""
        return status in (
            JobStatus.STOP_ON_RUNNING,
            JobStatus.ERROR_ON_RUNNING,
            JobStatus.SUCCESS,
        )


class TaskStatus(object):
    """String constants describing the lifecycle of a single task."""

    WAITRUN = 'wait_run'
    RUNNING = 'running'
    SUCCESS = 'success'
    ERROR = 'error'
    TIMEOUT = 'timeout'
    STOP = 'stop'


class MemberRole(object):
    """Role a federation member plays in a job."""

    PROVIDER = "provider"
    PROMOTER = "promoter"
    ARBITER = "arbiter"


class ComponentName(object):
    """Names of the visual components."""

    CLASSIFY = "PaddleClassify"
    DETECTION = "PaddleDetection"


class TaskResultType(object):
    """Categories under which task results are stored."""

    LOSS = "loss"
    ACCURACY = "accuracy"
    MAP = "mAP"
    INFER = "infer"
    LABEL = "label"


if __name__ == '__main__':
    pass
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import hashlib
import json
import os
import pickle
import socket
import time
import uuid

import numpy as np


def fate_uuid():
    """Return a 32-char hex id from uuid1 (time + host based, not random)."""
    return uuid.uuid1().hex


def get_commit_id():
    """Return an id usable as a model commit id.

    The model may be larger, so a content hash (SHA1) is not used; a fresh
    uuid is returned instead.
    """
    return fate_uuid()


def string_to_bytes(string):
    """Encode *string* as utf-8 bytes; bytes input is passed through."""
    return string if isinstance(string, bytes) else string.encode(encoding="utf-8")


def bytes_to_string(byte):
    """Decode utf-8 *byte* into a str."""
    return byte.decode(encoding="utf-8")


def json_dumps(src, byte=False):
    """Serialize *src* to JSON; as utf-8 bytes when *byte* is True, else str."""
    if byte:
        return string_to_bytes(json.dumps(src))
    else:
        return json.dumps(src)


def json_loads(src):
    """Parse JSON from either a str or utf-8 bytes."""
    if isinstance(src, bytes):
        return json.loads(bytes_to_string(src))
    else:
        return json.loads(src)


def current_timestamp():
    """Return the current time as integer milliseconds since the epoch."""
    return int(time.time() * 1000)


def current_datetime():
    """Return the current local time as a time.struct_time."""
    return time.localtime(time.time())


def get_delta_seconds(a, b):
    """
    Return the absolute difference between two datetimes, in whole seconds.

    Parameters
    ----------
    a, b : datetime.datetime
        Order does not matter; the earlier one is subtracted from the later.
    """
    second = 0
    pre = a
    now = b

    if b < a:
        pre = b
        now = a

    delta = now - pre

    day = delta.days
    if day > 0:
        second = second + day * 24 * 60 * 60

    return second + delta.seconds


def timestamp_to_date(timestamp=None, format_string="%Y-%m-%d %H:%M:%S"):
    """
    Format a millisecond timestamp as a local-time date string.

    Parameters
    ----------
    timestamp : int, optional
        Milliseconds since the epoch; defaults to the current time.
    format_string : str
        strftime format.
    """
    # BUG FIX: the default used to be `timestamp=current_timestamp()`, which
    # Python evaluates ONCE at import time, so every no-argument call returned
    # the module-load time.  A None sentinel defers the lookup to call time.
    if timestamp is None:
        timestamp = current_timestamp()
    timestamp = int(timestamp) / 1000
    time_array = time.localtime(timestamp)
    str_date = time.strftime(format_string, time_array)
    return str_date


def base64_encode(src):
    """Base64-encode the utf-8 bytes of *src* (str) and return a str."""
    return bytes_to_string(base64.b64encode(src.encode("utf-8")))


def base64_decode(src):
    """Base64-decode *src* and return the decoded utf-8 str."""
    return bytes_to_string(base64.b64decode(src))


def serialize_b64(src, to_str=False):
    """Pickle *src* and base64-encode it; as str when *to_str* is True."""
    dest = base64.b64encode(pickle.dumps(src))
    if not to_str:
        return dest
    else:
        return bytes_to_string(dest)


def deserialize_b64(src):
    """Inverse of serialize_b64; accepts str or bytes.

    SECURITY NOTE: pickle.loads executes arbitrary code -- only feed this
    trusted, internally-produced data.
    """
    return pickle.loads(base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src))


def get_lan_ip():
    """Best-effort LAN IP discovery.

    Tries the host name first; when that resolves to loopback (127.*), probes
    a list of common interface names via the SIOCGIFADDR ioctl (0x8915) --
    Linux-specific; on Windows (os.name == "nt") only the hostname lookup is
    used.  Returns '' when nothing is found.
    """
    if os.name != "nt":
        import fcntl
        import struct

        def get_interface_ip(ifname):
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            return socket.inet_ntoa(
                fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24])

    ip = socket.gethostbyname(socket.getfqdn())
    if ip.startswith("127.") and os.name != "nt":
        interfaces = [
            "bond1",
            "eth0",
            "eth1",
            "eth2",
            "wlan0",
            "wlan1",
            "wifi0",
            "ath0",
            "ath1",
            "ppp0",
        ]
        for if_name in interfaces:
            try:
                ip = get_interface_ip(if_name)
                break
            except IOError:
                pass
    return ip or ''


def serialize(src):
    """
    The default serialization method.

    In non-special cases, this method is used, protocol=2
    (kept at 2 for cross-version compatibility of stored blobs).

    Parameters
    ----------
    src: object
        data to serialize

    Returns
    -------
    bytes
    """
    return pickle.dumps(src, protocol=2)


def deserialize(src):
    """
    The default deserialization method; pickle auto-detects the protocol.

    SECURITY NOTE: never call this on untrusted data.

    Parameters
    ----------
    src : bytes

    Returns
    -------
    object
    """
    return pickle.loads(src)


def md5(src):
    """Return the hex md5 digest of the utf-8 bytes of *src* (str)."""
    return hashlib.md5(src.encode('utf-8')).hexdigest()


def hash_code(s):
    """
    Calculate the hashcode value, same as Java's String.hashCode().

    np.int32 is used deliberately so arithmetic wraps at 32 bits like Java
    (overflow may emit a numpy RuntimeWarning; the wrapped value is correct).

    Parameters
    ----------
    s : str

    Returns
    -------
    numpy.int32
    """
    h = 0
    if len(s) > 0:
        for item in s:
            h = np.int32(31) * h + np.int32(ord(item))
        return np.int32(h)
    else:
        return 0
Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +from __future__ import print_function + +import requests +import shutil +import sys +import tarfile +import zipfile +import six.moves.cPickle as pickle +import functools +from visualfl import get_data_dir +from paddle.dataset.image import * +from paddle.reader import * +from paddle import compat as cpt +import os +from multiprocessing import cpu_count +import six +from six.moves import cPickle as pickle +import logging + +__all__ = ['train', 'test', 'valid'] + +DATA_URL = 'xxx' +DATA_MD5='XXX' + +DATA_DIR = os.path.join(get_data_dir(),"flowers") +IMAGE_FILE_NAME='image.tgz' +TRAIN_LIST_FILE='train_list.txt' +TEST_LIST_FILE='test_list.txt' +VALID_LIST_FILE='val_list.txt' + +TRAIN_FLAG = 'trnid' +TEST_FLAG = 'tstid' +VALID_FLAG = 'valid' + + +def default_mapper(is_train, sample): + ''' + map image bytes data to type needed by model input layer + ''' + img, label = sample + img = load_image_bytes(img) + img = simple_transform( + img, 256, 224, is_train, mean=[103.94, 116.78, 123.68]) + return img.flatten().astype('float32'), label + + +train_mapper = functools.partial(default_mapper, True) +test_mapper = functools.partial(default_mapper, False) + + +def reader_creator(data_file, + img2label_file, + dataset_name, + mapper, + buffered_size=1024, + use_xmap=True, + cycle=False): + ''' + 1. 
read images from tar file and + merge images into batch files in 102flowers.tgz_batch/ + 2. get a reader to read sample from batch file + + :param data_file: downloaded data file + :type data_file: string + :param img2label_file: downloaded label file + :type label_file: string + :param dataset_name: data set name (trnid|tstid|valid) + :type dataset_name: string + :param mapper: a function to map image bytes data to type + needed by model input layer + :type mapper: callable + :param buffered_size: the size of buffer used to process images + :type buffered_size: int + :param cycle: whether to cycle through the dataset + :type cycle: bool + :return: data reader + :rtype: callable + ''' + img2label = {} + for line in open(img2label_file): + line = line.strip('\n') + lines = line.split(' ') + img = lines[0] + img2label[img] = int(lines[1]) + file_list = batch_images_from_tar(data_file, dataset_name, img2label) + + def reader(): + while True: + with open(file_list, 'r') as f_list: + for file in f_list: + file = file.strip() + batch = None + with open(file, 'rb') as f: + if six.PY2: + batch = pickle.load(f) + else: + batch = pickle.load(f, encoding='bytes') + + if six.PY3: + batch = cpt.to_text(batch) + data_batch = batch['data'] + labels_batch = batch['label'] + for sample, label in six.moves.zip(data_batch, + labels_batch): + yield sample, int(label) + if not cycle: + break + + if use_xmap: + return xmap_readers(mapper, reader, min(4, cpu_count()), buffered_size) + else: + return map_readers(mapper, reader) + + +def train(data_dir=DATA_DIR,mapper=train_mapper, buffered_size=1024, use_xmap=True, cycle=False): + ''' + Create flowers training set reader. + It returns a reader, each sample in the reader is + image pixels in [0, 1] and label in [1, 102] + translated from original color image by steps: + 1. resize to 256*256 + 2. random crop to 224*224 + 3. flatten + :param mapper: a function to map sample. 
+ :type mapper: callable + :param buffered_size: the size of buffer used to process images + :type buffered_size: int + :param cycle: whether to cycle through the dataset + :type cycle: bool + :return: train data reader + :rtype: callable + ''' + return reader_creator( + os.path.join(data_dir, IMAGE_FILE_NAME), + os.path.join(data_dir, TRAIN_LIST_FILE), + TRAIN_FLAG, + mapper, + buffered_size, + use_xmap, + cycle=cycle) + + +def test(data_dir=DATA_DIR,mapper=test_mapper, buffered_size=1024, use_xmap=True, cycle=False): + ''' + Create flowers test set reader. + It returns a reader, each sample in the reader is + image pixels in [0, 1] and label in [1, 102] + translated from original color image by steps: + 1. resize to 256*256 + 2. random crop to 224*224 + 3. flatten + :param mapper: a function to map sample. + :type mapper: callable + :param buffered_size: the size of buffer used to process images + :type buffered_size: int + :param cycle: whether to cycle through the dataset + :type cycle: bool + :return: test data reader + :rtype: callable + ''' + return reader_creator( + os.path.join(data_dir, IMAGE_FILE_NAME), + os.path.join(data_dir, TEST_LIST_FILE), + TEST_FLAG, + mapper, + buffered_size, + use_xmap, + cycle=cycle) + + +def valid(data_dir=DATA_DIR,mapper=test_mapper, buffered_size=1024, use_xmap=True): + ''' + Create flowers validation set reader. + It returns a reader, each sample in the reader is + image pixels in [0, 1] and label in [1, 102] + translated from original color image by steps: + 1. resize to 256*256 + 2. random crop to 224*224 + 3. flatten + :param mapper: a function to map sample. 
+ :type mapper: callable + :param buffered_size: the size of buffer used to process images + :type buffered_size: int + :return: test data reader + :rtype: callable + ''' + return reader_creator( + os.path.join(data_dir,IMAGE_FILE_NAME), + os.path.join(data_dir, VALID_LIST_FILE), + VALID_FLAG, mapper, + buffered_size, use_xmap) + + +def download(url, dirname, save_name=None): + if not os.path.exists(dirname): + os.makedirs(dirname) + + filename = os.path.join(dirname, + url.split('/')[-1] + if save_name is None else save_name) + + if os.path.exists(filename): + return filename + + retry = 0 + retry_limit = 3 + while not os.path.exists(filename): + if retry < retry_limit: + retry += 1 + else: + raise RuntimeError("Cannot download {0} within retry limit {1}". + format(url, retry_limit)) + sys.stderr.write("Cache file %s not found, downloading %s \n" % + (filename, url)) + sys.stderr.write("Begin to download\n") + r = requests.get(url, stream=True) + total_length = r.headers.get('content-length') + + if total_length is None: + with open(filename, 'wb') as f: + shutil.copyfileobj(r.raw, f) + else: + with open(filename, 'wb') as f: + chunk_size = 4096 + total_length = int(total_length) + total_iter = total_length / chunk_size + 1 + log_interval = total_iter / 20 if total_iter > 20 else 1 + log_index = 0 + for data in r.iter_content(chunk_size=chunk_size): + if six.PY2: + data = six.b(data) + f.write(data) + log_index += 1 + if log_index % log_interval == 0: + sys.stderr.write("../algorithm/paddle_clas") + sys.stdout.flush() + sys.stderr.write("\nDownload finished\n") + sys.stdout.flush() + + return filename + +def extract(tar_file, target_path): + try: + tar = tarfile.open(tar_file, "r:gz") + file_names = tar.getnames() + for file_name in file_names: + tar.extract(file_name, target_path) + tar.close() + except Exception as e: + print(e) + + +def un_zip(file_name,target_path): + """unzip zip file""" + try: + zip_file = zipfile.ZipFile(file_name) + if 
os.path.isdir(target_path): + pass + else: + os.mkdir(target_path) + names = zip_file.namelist() + for name in names: + zip_file.extract(name,target_path) + zip_file.close() + return os.path.dirname(names[0]) + except Exception as e: + print(e) + raise Exception(f"unzip file {file_name} error as {e}") + +def make_zip(source_dir, zip_file): + zipf = zipfile.ZipFile(zip_file, 'w') + pre_len = len(os.path.dirname(source_dir)) + for parent, dirnames, filenames in os.walk(source_dir): + for filename in filenames: + pathfile = os.path.join(parent, filename) + arcname = pathfile[pre_len:].strip(os.path.sep) + zipf.write(pathfile, arcname) + zipf.close() + +def job_download(url, job_id,base_dir): + try: + data_file = download(url, base_dir, f"{job_id}.zip") + dir_name = un_zip(data_file, base_dir) + target_dir = os.path.join(base_dir,dir_name) + except Exception as e: + logging.error(f"job download with {job_id} error as {e} ") + + return target_dir + + +def getImageList(dir, filelist): + newDir = dir + if os.path.isfile(dir): + if dir.endswith(".jpg") or dir.endswith(".JPG") or dir.endswith(".png") or dir.endswith(".PNG")\ + or dir.endswith(".jpeg") or dir.endswith(".webp") or dir.endswith(".bmp") or dir.endswith(".tif")\ + or dir.endswith(".gif"): + filelist.append(dir) + + elif os.path.isdir(dir): + for s in os.listdir(dir): + newDir = os.path.join(dir, s) + getImageList(newDir, filelist) + return filelist + +def extractImages(src_dir): + imageList = getImageList(src_dir,[]) + target_path = f"{os.path.dirname(src_dir)}_tmp" + if os.path.isdir(target_path): + pass + else: + os.mkdir(target_path) + for item in imageList: + tmp = os.path.basename(item) + shutil.copy(item, target_path + '/' + tmp) + shutil.rmtree(src_dir) + os.rename(target_path,src_dir) \ No newline at end of file diff --git a/VisualFL/visualfl/utils/exception.py b/VisualFL/visualfl/utils/exception.py new file mode 100644 index 000000000..5c83dda3d --- /dev/null +++ b/VisualFL/visualfl/utils/exception.py 
# Copyright 2021 Tianmian Tech. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Exception hierarchy for the VisualFL package.


class VisualFLBaseException(BaseException):
    # Root of the hierarchy.
    # NOTE(review): subclassing BaseException (rather than Exception) is
    # unusual -- a bare `except Exception:` does NOT catch this class itself,
    # only the VisualFLException branch below; confirm that is intended.
    ...


class VisualFLException(VisualFLBaseException, Exception):
    # General VisualFL error; also derives from Exception so ordinary
    # `except Exception:` handlers catch it.
    ...


class VisualFLExtensionException(VisualFLException):
    # Errors related to extensions (see visualfl.extensions usage elsewhere).
    ...


class VisualFLWorkerException(VisualFLException):
    # Errors raised by cluster workers.
    ...


class VisualFLJobCompileException(VisualFLException):
    # Errors raised while compiling/validating a job definition.
    ...


class VisualFLDataNotFoundException(VisualFLException):
    # Raised when expected data cannot be located.
    ...


from pathlib import Path
from typing import Any, Union

from google.protobuf import text_format
from loguru import logger

from visualfl import __logs_dir__

# Module-level handle to the configured loguru logger (set by set_logger).
__BASE_LOGGER = None


def set_logger(filename="unnamed"):
    """Route all loguru output to <__logs_dir__>/<filename>.log (DEBUG level).

    Parameters
    ----------
    filename : str
        Log file stem; ".log" is appended.
    """
    log_dir = Path(__logs_dir__)
    if not log_dir.exists():
        # BUG FIX / robustness: create missing parent directories as well --
        # a plain mkdir() raises FileNotFoundError when the parent of
        # __logs_dir__ does not exist yet.
        log_dir.mkdir(parents=True, exist_ok=True)

    log_format = (
        "[{extra[base]}]"
        "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
        "{level: <8} | "
        "{name}:{function}:{line}:{message}"
    )
    config = {
        "handlers": [
            # dict(sink=sys.stdout, format=log_format, level="DEBUG"),
            dict(
                sink=f"{log_dir.joinpath(filename)}.log",
                format=log_format,
                level="DEBUG",
            ),
        ],
        "extra": {"base": "unknown"},
    }
    logger.configure(**config)
    global __BASE_LOGGER
    __BASE_LOGGER = logger


# Configure a default sink at import time so logging works immediately.
set_logger()


class Logger(object):
    """Mixin giving subclasses classmethod logging helpers.

    Each record is tagged with the subclass name via loguru's bind();
    depth=1 makes loguru report the caller's location, not this wrapper's.
    """

    # Per-class cached bound logger (note: stored on whichever subclass
    # calls get_logger first via normal attribute lookup).
    _logger = None

    @classmethod
    def get_logger(cls, lazy=False):
        """Return the class-bound loguru logger; lazy=True defers arg eval."""
        if cls._logger is None:
            cls._logger = logger.bind(base=cls.__name__).opt(depth=1)
        if lazy:
            return cls._logger.opt(lazy=True, depth=1)
        return cls._logger

    @classmethod
    def log(
        cls,
        __level: Union[int, str],
        __message: str,
        *args: Any,
        lazy=False,
        **kwargs: Any,
    ):
        """Log at an arbitrary level."""
        cls.get_logger(lazy=lazy).log(__level, __message, *args, **kwargs)

    @classmethod
    def trace(cls, __message: str, *args: Any, **kwargs: Any):
        cls.get_logger(lazy=False).trace(__message, *args, **kwargs)

    @classmethod
    def trace_lazy(cls, __message: str, *args: Any, **kwargs: Any):
        cls.get_logger(lazy=True).trace(__message, *args, **kwargs)

    @classmethod
    def debug(cls, __message: str, *args: Any, **kwargs: Any):
        cls.get_logger().debug(__message, *args, **kwargs)

    @classmethod
    def debug_lazy(cls, __message: str, *args: Any, **kwargs: Any):
        cls.get_logger(lazy=True).debug(__message, *args, **kwargs)

    @classmethod
    def info(cls, __message: str, *args: Any, **kwargs: Any):
        cls.get_logger().info(__message, *args, **kwargs)

    @classmethod
    def info_lazy(cls, __message: str, *args: Any, **kwargs: Any):
        cls.get_logger(lazy=True).info(__message, *args, **kwargs)

    @classmethod
    def warning(cls, __message: str, *args: Any, **kwargs: Any):
        cls.get_logger().warning(__message, *args, **kwargs)

    @classmethod
    def error(cls, __message: str, *args: Any, **kwargs: Any):
        cls.get_logger().error(__message, *args, **kwargs)

    @classmethod
    def critical(cls, __message: str, *args: Any, **kwargs: Any):
        cls.get_logger().critical(__message, *args, **kwargs)

    @classmethod
    def exception(cls, __message: str, *args: Any, **kwargs: Any):
        # Use inside an `except` block: records the active traceback.
        cls.get_logger().exception(__message, *args, **kwargs)


def pretty_pb(pb):
    """Render a protobuf message on a single line for logging."""
    return text_format.MessageToString(pb, as_one_line=True)
+ @classmethod + @overload + def info(cls, __message: str, *args: Any, **kwargs: Any): ... + @classmethod + @overload + def info(cls, __message: Any): ... + @classmethod + @overload + def info_lazy(cls, __message: str, *args: Any, **kwargs: Any): ... + @classmethod + @overload + def info_lazy(cls, __message: Any): ... + @classmethod + @overload + def warning(cls, __message: str, *args: Any, **kwargs: Any) -> None: ... + @classmethod + @overload + def warning(cls, __message: Any) -> None: ... + @classmethod + @overload + def error(cls, __message: str, *args: Any, **kwargs: Any) -> None: ... + @classmethod + @overload + def error(cls, __message: Any) -> None: ... + @classmethod + @overload + def critical(cls, __message: str, *args: Any, **kwargs: Any) -> None: ... + @classmethod + @overload + def critical(cls, __message: Any) -> None: ... + @classmethod + @overload + def exception(cls, __message: str, *args: Any, **kwargs: Any) -> None: ... + @classmethod + @overload + def exception(cls, __message: Any) -> None: ... + +def pretty_pb(pb) -> str: ... +def set_logger(filename="visual_fl.log"): ... diff --git a/VisualFL/visualfl/utils/tools.py b/VisualFL/visualfl/utils/tools.py new file mode 100644 index 000000000..c41eed1d5 --- /dev/null +++ b/VisualFL/visualfl/utils/tools.py @@ -0,0 +1,51 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import aiohttp +import asyncio +import json +from visualfl.db.task_dao import TaskDao +import time +import logging + +def post(url, json_data): + async def post_co(): + async with aiohttp.ClientSession() as session: + async with session.post( + url, json=json_data + ) as resp: + print(resp.status) + print(json.dumps(await resp.json(), indent=2)) + resp.raise_for_status() + + loop = asyncio.get_event_loop() + loop.run_until_complete(post_co()) + +def save_data_to_db(task_id,tag,value,step,component_name): + try: + result,data = {},{} + current_milli_time = int(round(time.time() * 1000)) + tag = "accuracy" if "accuracy" in tag else "loss" + + dao = TaskDao(task_id) + model = dao.get_task_result(tag) + if model: + result = json.loads(model.result) + data = result.get("data") + + data[int(step)] = dict(value=float(value), timestamp=current_milli_time) + result.update(data=data) + dao.save_task_result(task_result=result,component_name=component_name,type=tag) + except Exception as e: + logging.error(f"task {task_id} save data to db error {e}") diff --git a/VisualFL/visualfl/worker.py b/VisualFL/visualfl/worker.py new file mode 100644 index 000000000..1f3572f2c --- /dev/null +++ b/VisualFL/visualfl/worker.py @@ -0,0 +1,303 @@ +# Copyright 2021 Tianmian Tech. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2020 The FedVision Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Optional, List, AsyncIterable + +import grpc + +from visualfl import __logs_dir__ +from visualfl import extensions +from visualfl.paddle_fl.abs.task import Task +from visualfl.paddle_fl.executor import ProcessExecutor +from visualfl.protobuf import cluster_pb2_grpc, cluster_pb2 +from visualfl.utils.exception import ( + VisualFLWorkerException, + VisualFLExtensionException, + VisualFLException, +) +from visualfl.utils.logger import Logger, pretty_pb +from visualfl.db.task_dao import TaskDao +from visualfl.utils.consts import TaskStatus + + +class ClusterWorker(Logger): + def __init__( + self, + worker_id: str, + worker_ip: str, + max_tasks: int, + port_start: int, + port_end: int, + manager_address: str, + data_dir: str = None, + ): + """ + init cluster worker instance + Args: + worker_id: + worker_ip: + max_tasks: + port_start: + port_end: + manager_address: + data_dir: + """ + self._task_queue: asyncio.Queue = asyncio.Queue() + self._semaphore = asyncio.Semaphore(max_tasks) + + self._worker_id = worker_id + self._worker_ip = worker_ip + self._manager_address = manager_address + self._max_tasks = max_tasks + self._port_start = port_start + self._port_end = port_end + self._heartbeat_interval = 1 + self._data_dir = data_dir + + self._channel: Optional[grpc.Channel] = None + self._stub: 
Optional[cluster_pb2_grpc.ClusterManagerStub] = None + + self._tasks: List[asyncio.Future] = [] + self._task_status: asyncio.Queue = asyncio.Queue() + self._stop_event = asyncio.Event() + + self._asyncio_task_collection: Optional[List[asyncio.Task]] = None + + async def start(self): + """ + start worker + + 1. enroll to manager + 2. start heartbeat loop + 3. start task exec loop + 4. process tasks + """ + self.info(f"starting worker {self._worker_id}") + + self.info(f"staring grpc channel to cluster manager") + self._channel = grpc.aio.insecure_channel( + self._manager_address, + options=[ + ("grpc.max_send_message_length", 512 * 1024 * 1024), + ("grpc.max_receive_message_length", 512 * 1024 * 1024), + ], + ) + self._stub = cluster_pb2_grpc.ClusterManagerStub(self._channel) + + self.info(f"sending enroll request to cluster manager") + response_stream: AsyncIterable[cluster_pb2.Enroll.REP] = self._stub.Enroll( + cluster_pb2.Enroll.REQ( + worker_id=self._worker_id, + worker_ip=self._worker_ip, + max_tasks=self._max_tasks, + port_start=self._port_start, + port_end=self._port_end, + ) + ) + first_response = True + + try: + async for response in response_stream: + + if first_response: + if response.status == cluster_pb2.Enroll.ALREADY_ENROLL: + raise VisualFLWorkerException( + f"worker<{self._worker_id}> already enrolled, use new name or remove it from manager" + ) + + if response.status != cluster_pb2.Enroll.ENROLL_SUCCESS: + raise VisualFLWorkerException( + f"worker<{self._worker_id}>enroll failed with unknown status: {response.status}" + ) + self.info( + f"worker<{self._worker_id}>success enrolled to cluster manager" + ) + + async def _co_update_status(): + while True: + try: + request = await asyncio.wait_for( + self._task_status.get(), self._heartbeat_interval + ) + except asyncio.TimeoutError: + self.trace( + "wait task status timeout. 
sending heartbeat request" + ) + request = cluster_pb2.UpdateStatus.REQ( + worker_id=self._worker_id + ) + + try: + update_response = await self._stub.UpdateTaskStatus( + request + ) + except grpc.aio.AioRpcError as _e: + self.error(f"can't send heartbeat to manager, {_e}") + self._stop_event.set() + return + if ( + update_response.status + != cluster_pb2.UpdateStatus.SUCCESS + ): + self.error( + f"update status failed, please check manager status" + ) + + self.info("starting heartbeat loop") + self._asyncio_task_collection = [ + asyncio.create_task(_co_update_status()), + ] + self.info("heartbeat loop started") + + self.info(f"starting task execute loop") + self._asyncio_task_collection.append( + asyncio.create_task(self._co_task_execute_loop()) + ) + self.info(f"task execute loop started") + first_response = False + continue + + # fetch tasks + if response.status != cluster_pb2.Enroll.TASK_READY: + raise VisualFLWorkerException( + f"expect status {cluster_pb2.Enroll.TASK_READY}, got {response.status}" + ) + + self.trace_lazy( + f"response <{{response}}> got", response=lambda: pretty_pb(response) + ) + try: + task_id = response.task.task_id + task_type = response.task.task_type + task_class = extensions.get_task_class(task_type) + if task_class is None: + self.error(f"task type {task_type} not found") + raise VisualFLExtensionException( + f"task type {task_type} not found" + ) + task = task_class.deserialize(response.task) + await self._task_queue.put(task) + TaskDao(task.web_task_id).update_task_status(TaskStatus.WAITRUN) + self.trace(f"put task in queue: task_id={task_id}") + except VisualFLException as e: + self.error(f"preprocess fetched task failed: {e}") + except Exception as e: + self.exception(e) + except grpc.aio.AioRpcError as e: + self.error(f"gRPC error: can't connect with cluster manager, {e}") + self._stop_event.set() + + async def wait_for_termination(self): + """ + block until stop event was set + """ + await self._stop_event.wait() + 
self.info(f"stop event set, stopping worker {self._worker_id}") + + async def stop(self): + """ + stop worker + """ + if self._channel is not None: + await self._channel.close() + self._channel = None + + self.info(f"canceling unfinished asyncio tasks") + if self._asyncio_task_collection is not None: + for task in self._asyncio_task_collection: + if not task.done(): + task.cancel() + self.trace(f"canceled task {task}") + self.info(f"all unfinished asyncio tasks canceled") + + async def _task_exec_coroutine(self, _task: Task): + try: + self.info( + f"start to exec task, job_id={_task.job_id}, task_id={_task.task_id}, task_type={_task.task_type}" + ) + executor = ProcessExecutor( + Path(__logs_dir__).joinpath(f"jobs/{_task.job_id}/{_task.task_id}"), + data_dir=self._data_dir, + ) + response = await _task.exec(executor) + self.info( + f"finish exec task, job_id={_task.job_id}, task_id={_task.task_id}" + ) + + self.trace(f"update task status") + await self._task_status.put( + cluster_pb2.UpdateStatus.REQ( + worker_id=self._worker_id, + job_id=_task.job_id, + task_id=_task.task_id, + task_status=cluster_pb2.UpdateStatus.TASK_FINISH, + exec_result=response, + ) + ) + self.info( + f"task status success updated to {cluster_pb2.UpdateStatus.TASK_FINISH}. 
" + f"job_id={_task.job_id}, task_id={_task.task_id}" + ) + + except Exception as e: + self.exception(e) + await self._task_status.put( + cluster_pb2.UpdateStatus.REQ( + worker_id=self._worker_id, + job_id=_task.job_id, + task_id=_task.task_id, + task_status=cluster_pb2.UpdateStatus.TASK_EXCEPTION, + exception=str(e), + ) + ) + TaskDao(_task.web_task_id).update_task_status(TaskStatus.ERROR,str(e)) + finally: + self._semaphore.release() + self.trace_lazy( + f"semaphore released, current: {{current}}", + current=lambda: self._semaphore, + ) + + async def _co_task_execute_loop(self): + + # noinspection PyUnusedLocal + + while True: + self.trace(f"acquiring semaphore") + await self._semaphore.acquire() + self.trace(f"acquired semaphore") + self.trace(f"get from task queue") + ready_task = await self._task_queue.get() + self.trace_lazy(f"got {{task}} from task queue", task=lambda: ready_task) + asyncio.create_task(self._task_exec_coroutine(ready_task)) + self.trace(f"asyncio task created to exec task") diff --git a/Wefe/assembly/scripts/install.sh b/Wefe/assembly/scripts/install.sh new file mode 100644 index 000000000..47af099c2 --- /dev/null +++ b/Wefe/assembly/scripts/install.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 检索当前要编译的项目,再去调具体子项目的编译脚本 +case "$application" in + wefe-board-service) + /bin/bash ./board/board-service/assembly/scripts/install.sh $application + ;; + wefe-board-website) + /bin/bash -x ./board/board-website/assembly/scripts/install.sh $application + ;; + wefe-union-service) + /bin/bash ./union/union-service/assembly/scripts/install.sh $application + ;; + wefe-union-website) + /bin/bash ./union/union-website/assembly/scripts/install.sh $application + ;; + wefe-gateway) + /bin/bash 
./gateway/assembly/scripts/install.sh $application + ;; + wefe-serving-service) + /bin/bash ./serving/serving-service/assembly/scripts/install.sh $application + ;; + wefe-serving-website) + /bin/bash ./serving/serving-website/assembly/scripts/install.sh $application + ;; + wefe-blockchain-data-sync) + /bin/bash ./blockchain/wefe-blockchain-data-sync/assembly/scripts/install.sh $application + ;; + wefe-data-fusion-service) + /bin/bash ./fusion/fusion-service/assembly/scripts/install.sh $application + ;; + wefe-data-fusion-website) + /bin/bash ./fusion/fusion-website/assembly/scripts/install.sh $application + ;; +esac + +exit 0 diff --git a/Wefe/board/board-service/assembly/scripts/install.sh b/Wefe/board/board-service/assembly/scripts/install.sh new file mode 100644 index 000000000..6526fec1c --- /dev/null +++ b/Wefe/board/board-service/assembly/scripts/install.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +workdir=$(pwd) + +cd "$workdir" + +mvn clean install -Dmaven.test.skip=true -am -pl board/board-service +echo "打包完毕" + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/"$application" <<-EOF +{ + "targetPath": "${workdir}/board/board-service/target" +} +EOF + +exit 0 \ No newline at end of file diff --git a/Wefe/board/board-website/assembly/scripts/install.sh b/Wefe/board/board-website/assembly/scripts/install.sh new file mode 100644 index 000000000..464dc88f3 --- /dev/null +++ b/Wefe/board/board-website/assembly/scripts/install.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 
的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +# workdir=$(dirname $0)/../../ ; cd $workdir +workdir=$(pwd)/$(dirname $0)/../../ ; cd $workdir + +## 子项目编译命令,需要根据实际项目更改 +## CI_ 打头的为和运维约定好的变量,CI_DEPLOY_ENV 代表编译环境 +[ -e $HOME/.nvm/nvm.sh ] && source $HOME/.nvm/nvm.sh + +rm -rf node_modules +nvm use 16.13.0 || : +nrm use npm +npm install +npm run build -- $CI_DEPLOY_ENV=$CI_SERVICE_NAME tail=2 + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/$application <<-EOF +{ + "targetPath": "$workdir/dist" +} +EOF + +exit 0 diff --git a/Wefe/fusion/fusion-service/assembly/scripts/install.sh b/Wefe/fusion/fusion-service/assembly/scripts/install.sh new file mode 100644 index 000000000..0f5da0f09 --- /dev/null +++ b/Wefe/fusion/fusion-service/assembly/scripts/install.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +workdir=$(pwd) + +cd $workdir + +## 子项目编译命令,需要根据实际项目更改 +mvn clean install -Dmaven.test.skip=true -am -pl fusion/fusion-service + +echo "将加密后的包重命名为 wefe-data-fusion-service.jar" +mv fusion/fusion-service/target/fusion-service.jar fusion/fusion-service/target/wefe-data-fusion-service.jar + + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/$application <<-EOF +{ + "targetPath": "${workdir}/fusion/fusion-service/target" +} +EOF + +exit 0 \ No newline at end of file diff --git a/Wefe/fusion/fusion-website/assembly/scripts/install.sh b/Wefe/fusion/fusion-website/assembly/scripts/install.sh new file mode 100644 index 000000000..df725c60c --- /dev/null +++ b/Wefe/fusion/fusion-website/assembly/scripts/install.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit 
the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +# workdir=$(dirname $0)/../../ ; cd $workdir +workdir=$(pwd)/$(dirname $0)/../../ ; cd $workdir + +## 子项目编译命令,需要根据实际项目更改 +## CI_ 打头的为和运维约定好的变量,CI_DEPLOY_ENV 代表编译环境 +[ -e $HOME/.nvm/nvm.sh ] && source $HOME/.nvm/nvm.sh + +rm -rf node_modules +nvm use 16.13.0 || : +nrm use npm +npm install +npm run build -- $CI_DEPLOY_ENV=$CI_SERVICE_NAME + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/$application <<-EOF +{ + "targetPath": "$workdir/dist" +} +EOF + +exit 0 diff --git a/Wefe/gateway/assembly/scripts/install.sh b/Wefe/gateway/assembly/scripts/install.sh new file mode 100644 index 000000000..5e24ee305 --- /dev/null +++ b/Wefe/gateway/assembly/scripts/install.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +workdir=$(pwd) + +cd $workdir + +## 子项目编译命令,需要根据实际项目更改 +mvn clean install -Dmaven.test.skip=true -am -pl gateway +echo "打包完毕" + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/$application <<-EOF +{ + "targetPath": "${workdir}/gateway/target" +} +EOF + +exit 0 \ No newline at end of file diff --git a/Wefe/manager/manager-service/assembly/scripts/install.sh b/Wefe/manager/manager-service/assembly/scripts/install.sh new file mode 100644 index 000000000..f3234f908 --- /dev/null +++ b/Wefe/manager/manager-service/assembly/scripts/install.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ 
x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +workdir=$(pwd) + +cd $workdir + +## 子项目编译命令,需要根据实际项目更改 +mvn clean install -Dmaven.test.skip=true -am -pl manager/manager-service + +echo "将加密后的包重命名为 wefe-manager-service.jar" +mv manager/manager-service/target/manager-service.jar manager/manager-service/target/wefe-manager-service.jar + + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/$application <<-EOF +{ + "targetPath": "${workdir}/manager/manager-service/target" +} +EOF + +exit 0 diff --git a/Wefe/manager/manager-website/assembly/scripts/install.sh b/Wefe/manager/manager-website/assembly/scripts/install.sh new file mode 100644 index 000000000..df725c60c --- /dev/null +++ b/Wefe/manager/manager-website/assembly/scripts/install.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +# workdir=$(dirname $0)/../../ ; cd $workdir +workdir=$(pwd)/$(dirname $0)/../../ ; cd $workdir + +## 子项目编译命令,需要根据实际项目更改 +## CI_ 打头的为和运维约定好的变量,CI_DEPLOY_ENV 代表编译环境 +[ -e $HOME/.nvm/nvm.sh ] && source $HOME/.nvm/nvm.sh + +rm -rf node_modules +nvm use 16.13.0 || : +nrm use npm +npm install +npm run build -- $CI_DEPLOY_ENV=$CI_SERVICE_NAME + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/$application <<-EOF +{ + "targetPath": "$workdir/dist" +} +EOF + +exit 0 diff --git a/Wefe/serving/serving-service/assembly/scripts/install.sh b/Wefe/serving/serving-service/assembly/scripts/install.sh new file mode 100644 index 000000000..d7b666544 --- /dev/null +++ b/Wefe/serving/serving-service/assembly/scripts/install.sh @@ 
-0,0 +1,32 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +workdir=$(pwd) + +cd $workdir + +## 子项目编译命令,需要根据实际项目更改 +mvn clean install -Dmaven.test.skip=true -am -pl serving/serving-service + +echo "将加密后的包重命名为 wefe-data-fusion-service.jar" +mv serving/serving-service/target/serving-service.jar serving/serving-service/target/wefe-serving-service.jar + + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/$application <<-EOF +{ + "targetPath": "${workdir}/serving/serving-service/target" +} +EOF + +exit 0 \ No newline at end of file diff --git a/Wefe/serving/serving-website/assembly/scripts/install.sh b/Wefe/serving/serving-website/assembly/scripts/install.sh new file mode 100644 index 000000000..6021572e0 --- /dev/null +++ b/Wefe/serving/serving-website/assembly/scripts/install.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +# workdir=$(dirname $0)/../../ ; cd $workdir +workdir=$(pwd)/$(dirname $0)/../../ ; cd $workdir + +## 子项目编译命令,需要根据实际项目更改 +## CI_ 打头的为和运维约定好的变量,CI_DEPLOY_ENV 代表编译环境 +[ -e $HOME/.nvm/nvm.sh ] && source $HOME/.nvm/nvm.sh + +rm -rf node_modules +nvm use 16.13.0 || : +npm install +npm run build -- $CI_DEPLOY_ENV=$CI_SERVICE_NAME tail=2 + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/$application <<-EOF +{ + "targetPath": "$workdir/dist" +} +EOF + +exit 0 diff --git a/Wefe/union/blockchain-data-sync/assembly/scripts/install.sh 
b/Wefe/union/blockchain-data-sync/assembly/scripts/install.sh new file mode 100644 index 000000000..c5893faee --- /dev/null +++ b/Wefe/union/blockchain-data-sync/assembly/scripts/install.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +workdir=$(pwd) + +cd $workdir + +## 子项目编译命令,需要根据实际项目更改 +mvn clean install -Dmaven.test.skip=true -am -pl union/blockchain-data-sync + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/$application <<-EOF +{ + "targetPath": "${workdir}/blockchain/wefe-blockchain-data-sync/target" +} +EOF + +exit 0 \ No newline at end of file diff --git a/Wefe/union/union-service/assembly/scripts/install.sh b/Wefe/union/union-service/assembly/scripts/install.sh new file mode 100644 index 000000000..e317eb6df --- /dev/null +++ b/Wefe/union/union-service/assembly/scripts/install.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +workdir=$(pwd) + +cd $workdir + +## 子项目编译命令,需要根据实际项目更改 +mvn clean install -Dmaven.test.skip=true -am -pl union/union-service + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/$application <<-EOF +{ + "targetPath": "${workdir}/union/union-service/target" +} +EOF + +exit 0 \ No newline at end of file diff --git a/assembly/scripts/install.sh b/assembly/scripts/install.sh new file mode 100644 index 000000000..d8c0a2b97 --- /dev/null +++ b/assembly/scripts/install.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash + +## -x: Debug 
mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 检索当前要编译的项目,再去调具体子项目的编译脚本 +case "$application" in + wefe-board-service) + /bin/bash ./board/board-service/assembly/scripts/install.sh $application + ;; + wefe-board-website) + /bin/bash -x ./board/board-website/assembly/scripts/install.sh $application + ;; + wefe-union-service) + /bin/bash ./union/union-service/assembly/scripts/install.sh $application + ;; + wefe-union-website) + /bin/bash ./union/union-website/assembly/scripts/install.sh $application + ;; + wefe-gateway) + /bin/bash ./gateway/assembly/scripts/install.sh $application + ;; + wefe-serving-service) + /bin/bash ./serving/serving-service/assembly/scripts/install.sh $application + ;; + wefe-serving-website) + /bin/bash ./serving/serving-website/assembly/scripts/install.sh $application + ;; + wefe-blockchain-data-sync) + /bin/bash ./union/blockchain-data-sync/assembly/scripts/install.sh $application + ;; + wefe-data-fusion-service) + /bin/bash ./fusion/fusion-service/assembly/scripts/install.sh $application + ;; + wefe-data-fusion-website) + /bin/bash ./fusion/fusion-website/assembly/scripts/install.sh $application + ;; + wefe-manager-website) + /bin/bash ./manager/manager-website/assembly/scripts/install.sh $application + ;; + wefe-manager-service) + /bin/bash ./manager/manager-service/assembly/scripts/install.sh $application + ;; +esac + +exit 0 diff --git a/board/board-service/assembly/scripts/install.sh b/board/board-service/assembly/scripts/install.sh new file mode 100644 index 000000000..1fc15282a --- /dev/null +++ b/board/board-service/assembly/scripts/install.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +## -x: Debug mode +## -e: exit the script if any statement returns a non-true return value +[ x"${DEBUG}" == x"true" ] && set -ex || set -e + +## 
application 值为 Jenkins 编译时传入,对应 Jenkins 的 JOB_BASE_NAME,即 APP_NAME +application=$1 + +## --- 该分割线以上代码不用动 --- + +## 切换到具体的子项目顶层目录 +workdir=$(pwd) + +cd "$workdir" + +mvn clean install -Dmaven.test.skip=true -am -pl board/board-service +echo "打包完毕" + +## 生成 JSON 配置文件,此文件作用告知运维怎么拿到实际要部署的代码、配置文件(以目录形式存放) +## JSON 中的 key 值,事先和运维约定好 +cat > /tmp/"$application" <<-EOF +{ + "targetPath": "${workdir}/board/board-service/target" +} +EOF + +exit 0 diff --git a/board/board-service/pom.xml b/board/board-service/pom.xml index cbc353c51..13dbebfac 100644 --- a/board/board-service/pom.xml +++ b/board/board-service/pom.xml @@ -31,6 +31,12 @@ common-data-storage ${project.parent.version} + + + com.welab.wefe + fusion-core + ${project.parent.version} + org.hibernate.javax.persistence hibernate-jpa-2.1-api @@ -105,7 +111,6 @@ guava 30.0-jre - org.springframework.boot spring-boot-starter-mail @@ -120,7 +125,6 @@ - org.apache.hadoop hadoop-client @@ -142,6 +146,30 @@ + + com.welab.wefe + fusion-service + 1.0.0 + compile + + + + + com.welab.wefe + common-verification-code + ${project.parent.version} + + + + com.aliyun + dysmsapi20170525 + 2.0.5 + + + org.apache.httpcomponents + httpclient + + diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/BoardService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/BoardService.java index 716e0583c..428bd7826 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/BoardService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/BoardService.java @@ -1,11 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,7 +20,7 @@ import com.welab.wefe.board.service.base.OnlineDemoApi; import com.welab.wefe.board.service.constant.Config; import com.welab.wefe.board.service.exception.FlowNodeException; -import com.welab.wefe.board.service.operation.OperationLogAfterApiExecute; +import com.welab.wefe.board.service.operation.BoardApiLogger; import com.welab.wefe.board.service.service.CacheObjects; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.storage.StorageManager; @@ -31,6 +32,7 @@ import com.welab.wefe.common.web.config.ApiBeanNameGenerator; import com.welab.wefe.common.web.dto.ApiResult; import com.welab.wefe.common.web.dto.SignedApiInput; +import com.welab.wefe.common.wefe.checkpoint.CheckpointManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.BeansException; @@ -52,7 +54,12 @@ @ComponentScan( lazyInit = true, nameGenerator = ApiBeanNameGenerator.class, - basePackageClasses = {BoardService.class, Launcher.class, StorageManager.class} + basePackageClasses = { + BoardService.class, + Launcher.class, + StorageManager.class, + CheckpointManager.class + } ) public class BoardService implements ApplicationContextAware { @@ -61,6 +68,7 @@ public class BoardService implements ApplicationContextAware { public static void main(String[] args) { Launcher .instance() + .apiLogger(new BoardApiLogger()) .apiPackageClass(BoardService.class) // Login status check method .checkSessionTokenFunction((api, annotation, token) -> CurrentAccount.get() != null) @@ -89,7 +97,7 @@ public static void main(String[] args) { // 在线体验版专用 api 权限检查 OnlineDemoApi onlineDemoApi = 
api.getClass().getAnnotation(OnlineDemoApi.class); if (onlineDemoApi != null) { - Config config = Launcher.CONTEXT.getBean(Config.class); + Config config = Launcher.getBean(Config.class); if (!config.isOnlineDemo()) { throw new StatusCodeWithException("The current environment does not allow this API to be called", StatusCode.SYSTEM_ERROR); } @@ -100,10 +108,6 @@ public static void main(String[] args) { } }) .launch(BoardService.class, args); - - Launcher - .instance() - .afterApiExecuteFunction(Launcher.CONTEXT.getBean(OperationLogAfterApiExecute.class)); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/AuditApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/AuditApi.java index ad1e6b246..4ce3c8132 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/AuditApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/AuditApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -18,7 +18,6 @@ import com.welab.wefe.board.service.service.account.AccountService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.AuditStatus; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.StringUtil; @@ -26,6 +25,7 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.AuditStatus; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/CaptchaApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/CaptchaApi.java index 6716450ab..089de13e4 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/CaptchaApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/CaptchaApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/EnableApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/EnableApi.java index af1ded1fb..d239f2130 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/EnableApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/EnableApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/ForgetPasswordApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/ForgetPasswordApi.java index a0d6d7a55..08ae9f95f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/ForgetPasswordApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/ForgetPasswordApi.java @@ -1,11 +1,11 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/LoginApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/LoginApi.java index f8843a794..b289ada83 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/LoginApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/LoginApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,9 +16,11 @@ package com.welab.wefe.board.service.api.account; -import com.welab.wefe.board.service.database.entity.AccountMySqlModel; +import com.welab.wefe.board.service.database.entity.AccountMysqlModel; +import com.welab.wefe.board.service.database.repository.AccountRepository; import com.welab.wefe.board.service.service.account.AccountService; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.board.service.util.BoardSM4Util; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; @@ -27,6 +29,7 @@ import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.AbstractApiOutput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.service.account.AccountInfo; import org.springframework.beans.factory.annotation.Autowired; /** @@ -37,6 +40,8 @@ public class LoginApi extends AbstractApi { @Autowired private AccountService accountService; + @Autowired + private AccountRepository accountRepository; @Autowired private GlobalConfigService globalConfigService; @@ -44,7 +49,9 @@ public class LoginApi extends AbstractApi { @Override protected ApiResult handle(Input input) throws StatusCodeWithException { - Output output = accountService.login(input.phoneNumber, input.password, input.key, input.code); + String token = accountService.login(input.phoneNumber, input.password, input.key, input.code); + AccountMysqlModel model = accountRepository.findByPhoneNumber(BoardSM4Util.encryptPhoneNumber(input.phoneNumber)); + Output output = new Output(token, model); /** * After successful login, check whether the system has been initialized @@ -60,7 +67,7 @@ protected ApiResult handle(Input input) throws StatusCodeWithException { } // If you are not a super administrator, you will be prompted that you cannot log in. 
else { - AccountMySqlModel superAdmin = accountService.findSuperAdmin(); + AccountInfo superAdmin = accountService.getSuperAdmin(); return fail("The system has not been initialized, please contact " + superAdmin.getNickname() + " Initialize the system."); } } @@ -135,6 +142,18 @@ public static class Output extends AbstractApiOutput { private Boolean adminRole; + public Output() { + } + + public Output(String token, AccountMysqlModel model) throws StatusCodeWithException { + this.id = model.getId(); + this.token = token; + this.phoneNumber = model.getPhoneNumber(); + this.nickname = model.getNickname(); + this.email = model.getEmail(); + this.superAdminRole = model.getSuperAdminRole(); + this.adminRole = model.getAdminRole(); + } //region getter/setter diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryApi.java index eaa5825f0..d1404a53b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -20,11 +20,11 @@ import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.dto.entity.AccountOutputModel; import com.welab.wefe.board.service.service.account.AccountService; -import com.welab.wefe.common.enums.AuditStatus; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.AuditStatus; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryMemberAccountsApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryMemberAccountsApi.java index 5a3e6df67..208a01c4d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryMemberAccountsApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryMemberAccountsApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryOnlineApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryOnlineApi.java index 51ded0c69..3ad328153 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryOnlineApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/QueryOnlineApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -51,9 +51,7 @@ public static class Input extends AbstractApiInput { */ @Check(require = true) private String memberId; - /** - * Account ID (if it is not empty, it means that you specify to query the online status of the account) - */ + @Check(name = "Account ID (if it is not empty, it means that you specify to query the online status of the account)") private String accountId; public String getMemberId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/RegisterApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/RegisterApi.java index 7400cff94..a6b94771c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/RegisterApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/RegisterApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,16 +16,18 @@ package com.welab.wefe.board.service.api.account; +import com.welab.wefe.board.service.constant.Config; import com.welab.wefe.board.service.dto.vo.AccountInputModel; import com.welab.wefe.board.service.service.account.AccountService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.BoardUserSource; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.Launcher; import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.ApiResult; import com.welab.wefe.common.web.service.CaptchaService; +import com.welab.wefe.common.wefe.enums.BoardUserSource; import org.springframework.beans.factory.annotation.Autowired; /** @@ -54,12 +56,14 @@ public static class Input extends AccountInputModel { @Override public void checkAndStandardize() throws StatusCodeWithException { super.checkAndStandardize(); - - // Verification code verification - if (!CaptchaService.verify(key, code)) { - throw new StatusCodeWithException("验证码错误!", StatusCode.PARAMETER_VALUE_INVALID); + if (Launcher.getBean(Config.class).getEnvName().isProductionEnv()) { + // Verification code verification + if (!CaptchaService.verify(key, code)) { + throw new StatusCodeWithException("验证码错误!", StatusCode.PARAMETER_VALUE_INVALID); + } } + } //region getter/setter diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/ResetPasswordApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/ResetPasswordApi.java index 3e0ff5eba..9ff95f88e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/ResetPasswordApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/ResetPasswordApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -44,6 +44,11 @@ public static class Input extends AbstractApiInput { @Check(name = "用户唯一标识", require = true) private String id; + @Check(name = "操作者的密码", require = true) + private String operatorPassword; + + // region getter/setter + public String getId() { return id; } @@ -51,6 +56,16 @@ public String getId() { public void setId(String id) { this.id = id; } + + public String getOperatorPassword() { + return operatorPassword; + } + + public void setOperatorPassword(String operatorPassword) { + this.operatorPassword = operatorPassword; + } + + // endregion } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/SendForgetPasswordVerificationCodeApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/SendForgetPasswordVerificationCodeApi.java new file mode 100644 index 000000000..64f38cf2f --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/SendForgetPasswordVerificationCodeApi.java @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.account; + +import com.welab.wefe.board.service.service.verificationcode.VerificationCodeService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.dto.NoneApiOutput; +import com.welab.wefe.common.wefe.enums.VerificationCodeBusinessType; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +/** + * Send forget password verification code + * + * + * @author aaron.li + * @date 2021/11/11 09:45 + **/ +@Api(path = "account/send_forget_password_code", name = "send verification code", login = false) +public class SendForgetPasswordVerificationCodeApi extends AbstractApi { + + @Autowired + private VerificationCodeService verificationCodeService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { + verificationCodeService.send(input.phoneNumber, VerificationCodeBusinessType.accountForgetPassword); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + private String phoneNumber; + + public String getPhoneNumber() { + return phoneNumber; + } + + public void setPhoneNumber(String phoneNumber) { + this.phoneNumber = phoneNumber; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/SuperAdminChangeApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/SuperAdminChangeApi.java index 29de8e91a..cea9c0387 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/SuperAdminChangeApi.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/SuperAdminChangeApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,7 +16,7 @@ package com.welab.wefe.board.service.api.account; -import com.welab.wefe.board.service.database.entity.AccountMySqlModel; +import com.welab.wefe.board.service.database.entity.AccountMysqlModel; import com.welab.wefe.board.service.database.repository.AccountRepository; import com.welab.wefe.board.service.service.account.AccountService; import com.welab.wefe.common.StatusCode; @@ -42,7 +42,7 @@ public class SuperAdminChangeApi extends AbstractApi handle(SuperAdminChangeApi.Input input) throws StatusCodeWithException { - AccountMySqlModel account = accountRepository.findById(input.getId()).orElse(null); + AccountMysqlModel account = accountRepository.findById(input.getId()).orElse(null); if (account == null) { throw new StatusCodeWithException("指定用户不存在", StatusCode.DATA_NOT_FOUND); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/UpdateApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/UpdateApi.java index 21930eb3b..e0903072e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/UpdateApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/UpdateApi.java @@ -1,12 +1,12 @@ /* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/UpdatePasswordApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/UpdatePasswordApi.java index 829ffb2be..003ed0e94 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/UpdatePasswordApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/UpdatePasswordApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/VerificationCodeSendChannelApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/VerificationCodeSendChannelApi.java new file mode 100644 index 000000000..53f7b3293 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/VerificationCodeSendChannelApi.java @@ -0,0 +1,53 @@ +/** + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.account; + + +import com.welab.wefe.board.service.service.verificationcode.VerificationCodeService; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.*; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * Get verification code send channel + */ +@Api(path = "account/verification_code_send_channel", name = "Get verification code send channel", login = false) +public class VerificationCodeSendChannelApi extends AbstractApi { + + @Autowired + private VerificationCodeService verificationCodeService; + + @Override + protected ApiResult handle(NoneApiInput input) throws Exception { + Output output = new Output(); + output.setChannel(verificationCodeService.getSendChannel()); + return success(output); + } + + public static class Output { + private String channel; + + public String getChannel() { + return channel; + } + + public void setChannel(String channel) { + this.channel = channel; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/account-register.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/account-register.http index df2e00ee8..d4ce1a0b0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/account-register.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/account-register.http @@ -1,13 +1,15 @@ ### 正常注册 -POST {{baseUrl}}/account/register +POST http://localhost:8080/board-service/account/register Content-Type: application/json { "phone_number": "13100000001", "nickname": "小甜甜", "password": "password", - "email": "email@email.com" + "email": "email@email.com", + "code": "test", + "key": "key" } > {% @@ -20,7 +22,7 @@ client.test("Request executed successfully", function() { ### 手机号冲突 -POST {{baseUrl}}/account/register +POST 
http://localhost:8080/board-service/account/register Content-Type: application/json { @@ -40,7 +42,7 @@ client.test("Request executed successfully", function() { ### 手机号错误 -POST {{baseUrl}}/account/register +POST http://localhost:8080/board-service/account/register Content-Type: application/json { @@ -60,7 +62,7 @@ client.test("Request executed successfully", function() { ### 邮箱错误 -POST {{baseUrl}}/account/register +POST http://localhost:8080/board-service/account/register Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/account.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/account.http index a2f0bbf36..f019315c6 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/account.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/account.http @@ -1,11 +1,13 @@ ### 登录 -POST {{baseUrl}}/account/login +POST http://localhost:8080/board-service/account/login Content-Type: application/json { "phone_number": "13100000001", - "password": "password" + "password": "password", + "code": "code", + "key": "key" } > {% @@ -21,7 +23,7 @@ client.global.set("token", response.body.data.token); ### 修改密码 -POST {{baseUrl}}/account/update_password +POST http://localhost:8080/board-service/account/update_password Content-Type: application/json token: {{token}} @@ -33,7 +35,7 @@ token: {{token}} ### 修改密码后再次登录 -POST {{baseUrl}}/account/login +POST http://localhost:8080/board-service/account/login Content-Type: application/json { @@ -53,7 +55,7 @@ client.global.set("token", response.body.data.token); ### 再把密码改回去 -POST {{baseUrl}}/account/update_password +POST http://localhost:8080/board-service/account/update_password Content-Type: application/json token: {{token}} @@ -64,7 +66,7 @@ token: {{token}} ### 分页查询 -POST {{baseUrl}}/account/query +POST http://localhost:8080/board-service/account/query 
Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/enable.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/enable.http index a81c77d2e..ebc27dbf4 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/enable.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/enable.http @@ -1,10 +1,10 @@ -POST {{baseUrl}}/account/enable +POST http://localhost:8080/board-service/account/enable Content-Type: application/json { - "id":"9514861dd5a24ad8bf2f6f77a412867b", - "enable":false + "id": "9514861dd5a24ad8bf2f6f77a412867b", + "enable": false } ### \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/reset-password.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/reset-password.http index 654bf0358..eec0fb453 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/reset-password.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/reset-password.http @@ -1,9 +1,9 @@ -POST {{baseUrl}}/account/reset/password +POST http://localhost:8080/board-service/account/reset/password Content-Type: application/json token: {{token}} { - "id":"6ac57a3898714273a2bb38f6ba959c78" + "id": "6ac57a3898714273a2bb38f6ba959c78" } ### \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/update.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/update.http index 075c3d3a3..e562acead 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/update.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/account/test/update.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/account/update +POST 
http://localhost:8080/board-service/account/update Content-Type: application/json token: {{token}} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/AddApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/AddApi.java index 120b8da27..2495f23ef 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/AddApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/AddApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/BlacklistApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/BlacklistApi.java index 9271e7115..f4de25bef 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/BlacklistApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/BlacklistApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/BlacklistMemberApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/BlacklistMemberApi.java index 75ef0812f..cc5b2ab91 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/BlacklistMemberApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/BlacklistMemberApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/DeleteApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/DeleteApi.java index 35bceff18..d7394c1c0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/DeleteApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/DeleteApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/test/blacklist.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/test/blacklist.http index 23bafc9c8..0c1851c77 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/test/blacklist.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/blacklist/test/blacklist.http @@ -1,5 +1,5 @@ ### 查询全部数据集 -POST {{baseUrl}}/blacklist/list +POST http://localhost:8080/board-service/blacklist/list Content-Type: application/json token:9a90763c-1313-41c8-b2a2-b93d0dd66d93 @@ -15,13 +15,15 @@ client.test("Request executed successfully", function() { ### -POST {{baseUrl}}/blacklist/add +POST http://localhost:8080/board-service/blacklist/add Content-Type: application/json token:4c1d7dcf-89ee-44e4-8649-91f1055062b5 { - "memberIds":["123"], - "remark":"test" + "memberIds": [ + "123" + ], + "remark": "test" } ### \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/AddChatLastAccountApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/AddChatLastAccountApi.java index a05cb2040..d870c6d86 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/AddChatLastAccountApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/AddChatLastAccountApi.java @@ -1,12 +1,12 @@ -/** 
+/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/DeleteChatLastAccountApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/DeleteChatLastAccountApi.java index 95fd9e57a..4cf657051 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/DeleteChatLastAccountApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/DeleteChatLastAccountApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/QueryChatDetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/QueryChatDetailApi.java index 757c36b4f..ce2e9933b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/QueryChatDetailApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/QueryChatDetailApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -57,9 +57,7 @@ public static class Input extends PagingInput { */ @Check(require = true) private String toAccountId; - /** - * query timestamp limit - */ + @Check(name = "query timestamp limit") private Long limitCreateTime; public String getFromAccountId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/QueryChatLastAccountApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/QueryChatLastAccountApi.java index e461d4cda..f571cbf7c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/QueryChatLastAccountApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/QueryChatLastAccountApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/ResendMessageApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/ResendMessageApi.java index 7eebbf84e..c678502b4 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/ResendMessageApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/ResendMessageApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/SendMessageApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/SendMessageApi.java index 1e89457c7..5fbff269a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/SendMessageApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/SendMessageApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -79,9 +79,7 @@ public static class Input extends AbstractApiInput { */ @Check(require = true) private String content; - /** - * Message ID used by the front end - */ + @Check(name = "Message ID used by the front end") private String messageId; public String getToMemberId() { @@ -134,13 +132,9 @@ public void setMessageId(String messageId) { } public static class Output extends AbstractApiOutput { - /** - * Message ID used by the front end - */ + @Check(name = "Message ID used by the front end") private String messageId; - /** - * Back end database message ID - */ + @Check(name = "Back end database message ID") private String memberChatId; public String getMessageId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/UnreadMessageIncreaseOneApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/UnreadMessageIncreaseOneApi.java index be44ecd95..d14cb77a8 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/UnreadMessageIncreaseOneApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/UnreadMessageIncreaseOneApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/UpdateToReadApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/UpdateToReadApi.java index 5f8fc1115..7adfaefbd 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/UpdateToReadApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/chat/UpdateToReadApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/component/ListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/component/ListApi.java index 4d4eef656..76772e1fc 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/component/ListApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/component/ListApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,14 +17,14 @@ package com.welab.wefe.board.service.api.component; import com.welab.wefe.board.service.dto.entity.component.ComponentOutputModel; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.FederatedLearningType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; import java.util.Arrays; import java.util.List; @@ -42,6 +42,8 @@ protected ApiResult> handle(Input input) throws Statu List list = Arrays.stream(ComponentType.values()) .filter(x -> input.getFederatedLearningType() == null || x.getFederatedLearningTypes() == null || x.getFederatedLearningTypes().contains(input.federatedLearningType)) + // 排除深度学习组件 + .filter(x -> !x.isDeepLearningComponents()) // Exclude the relevant components of the validation data set, which have not been developed yet. 
.filter(x -> !x.name().contains("ValidationDataSetLoader")) .map(x -> new ComponentOutputModel(x.name(), x.getLabel(), x.getDesc())).collect(Collectors.toList()); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/component/test/list.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/component/test/list.http index ae08eea3b..bd6547653 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/component/test/list.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/component/test/list.http @@ -1,11 +1,11 @@ ### 查全部 -POST {{baseUrl}}/component/list +POST http://localhost:8080/board-service/component/list Content-Type: application/json {} ### 查横向 -POST {{baseUrl}}/component/list +POST http://localhost:8080/board-service/component/list Content-Type: application/json { @@ -14,7 +14,7 @@ Content-Type: application/json ### 查纵向 -POST {{baseUrl}}/component/list +POST http://localhost:8080/board-service/component/list Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/ModelExportController.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/ModelExportController.java index 945a58379..674172c56 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/ModelExportController.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/ModelExportController.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,10 +18,11 @@ import com.welab.wefe.board.service.service.modelexport.ModelExportService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.ModelExportLanguage; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; import com.welab.wefe.common.web.CurrentAccount; +import com.welab.wefe.common.web.service.account.AccountInfo; +import com.welab.wefe.common.wefe.enums.ModelExportLanguage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -60,7 +61,7 @@ public void download(HttpServletRequest httpServletRequest, HttpServletResponse token = httpServletRequest.getParameter("token"); httpServletResponse.setCharacterEncoding("UTF-8"); out = httpServletResponse.getWriter(); - CurrentAccount.Info info = CurrentAccount.get(token); + AccountInfo info = CurrentAccount.get(token); if (null == info) { out.write(JObject.create().append("code", StatusCode.PARAMETER_VALUE_INVALID.getCode()) .append("message", "未登录").toString()); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/ModelExportToFileApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/ModelExportToFileApi.java new file mode 100644 index 000000000..786ae716a --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/ModelExportToFileApi.java @@ -0,0 +1,91 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.data_output_info; + +import com.alibaba.fastjson.JSON; +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.board.service.service.ServingService; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.FileUtil; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.ResponseEntity; + +import java.io.File; +import java.util.TreeMap; + +/** + * @author hunter.zhao + * @date 2022/3/7 + */ +@Api(path = "data_output_info/model_export_to_file", name = "导出模型到文件中") +public class ModelExportToFileApi extends AbstractApi> { + + + @Autowired + ServingService servingService; + + @Override + protected ApiResult> handle(Input input) throws Exception { + + TreeMap body = servingService.setBody(input.getTaskId(), input.getRole()); + + File file = WeFeFileSystem + .getBaseDir(WeFeFileSystem.UseType.Temp) + .resolve(input.getTaskId() + ".json") + .toFile(); + + FileUtil.writeTextToFile(JSON.toJSONString(body), file.toPath(), false); + + return file(file); + } + + + public static class Input extends 
AbstractApiInput { + + @Check(name = "taskId", require = true) + private String taskId; + + @Check(name = "模型角色", require = true) + private JobMemberRole role; + + //region getter/setter + + public String getTaskId() { + return taskId; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public JobMemberRole getRole() { + return role; + } + + public void setRole(JobMemberRole role) { + this.role = role; + } + + + //endregion + + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/SyncModelToServingApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/SyncModelToServingApi.java index 998fe72e4..bb103cdbd 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/SyncModelToServingApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/SyncModelToServingApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,13 +17,13 @@ package com.welab.wefe.board.service.api.data_output_info; import com.welab.wefe.board.service.service.ServingService; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; /** @@ -46,9 +46,6 @@ public static class Input extends AbstractApiInput { @Check(name = "taskId", require = true) private String taskId; -// @Check(name = "type", require = true) -// private TaskResultType type = TaskResultType.model_train; - @Check(name = "模型角色", require = true) private JobMemberRole role; @@ -62,14 +59,6 @@ public void setTaskId(String taskId) { this.taskId = taskId; } -// public TaskResultType getType() { -// return type; -// } -// -// public void setType(TaskResultType type) { -// this.type = type; -// } - public JobMemberRole getRole() { return role; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/test/export.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/test/export.http new file mode 100644 index 000000000..8546ecc0e --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/test/export.http @@ -0,0 +1,13 @@ + +### + +POST http://localhost:8080/board-service/data_output_info/model_export_to_file +Content-Type: application/json +token:{{token}} + +{ + "taskId": "test", + "role": "promoter" +} + +### \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/test/model_push.http 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/test/model_push.http index 3d3361340..bf7e2b1f9 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/test/model_push.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_output_info/test/model_push.http @@ -1,12 +1,12 @@ ### -POST {{baseUrl}}/data_output_info/model_push +POST http://localhost:8080/board-service/data_output_info/model_push Content-Type: application/json token:4c1d7dcf-89ee-44e4-8649-91f1055062b5 { - "id":"a475edb2c3e34b33896e07da468bdf5e" + "id": "a475edb2c3e34b33896e07da468bdf5e" } ### \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/DataResourceQueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/DataResourceQueryApi.java new file mode 100644 index 000000000..15e961153 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/DataResourceQueryApi.java @@ -0,0 +1,129 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource; + +import com.welab.wefe.board.service.dto.base.PagingInput; +import com.welab.wefe.board.service.dto.base.PagingOutput; +import com.welab.wefe.board.service.service.data_resource.DataResourceService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DeepLearningJobType; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.List; + +/** + * @author Zane + */ +@Api(path = "data_resource/query", name = "query all kinds of data resource") +public class DataResourceQueryApi extends AbstractApi> { + + @Autowired + private DataResourceService dataResourceService; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException { + PagingOutput output = dataResourceService.query(input); + return success(output); + + } + + public static class Input extends PagingInput { + @Check(name = "资源Id") + private String id; + @Check(name = "过滤器名称") + private String name; + @Check(name = "标签") + private String tag; + @Check(name = "上传者") + private String creator; + @Check(name = "资源类型") + private List dataResourceType; + + /***********↓ TableDataSet ↓***********/ + @Check(name = "是否包含 Y 值") + private Boolean containsY; + + /***********↓ ImageDataSet ↓***********/ + @Check(name = "任务类型(分类、目标检测)") + private DeepLearningJobType forJobType; + + // region getter/setter + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getTag() { + return tag; + } + + public void 
setTag(String tag) { + this.tag = tag; + } + + public String getCreator() { + return creator; + } + + public void setCreator(String creator) { + this.creator = creator; + } + + public List getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(List dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public Boolean getContainsY() { + return containsY; + } + + public void setContainsY(Boolean containsY) { + this.containsY = containsY; + } + + public DeepLearningJobType getForJobType() { + return forJobType; + } + + public void setForJobType(DeepLearningJobType forJobType) { + this.forJobType = forJobType; + } + + // endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/ListTagsApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/ListTagsApi.java new file mode 100644 index 000000000..a429412e4 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/ListTagsApi.java @@ -0,0 +1,99 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource; + + +import com.welab.wefe.board.service.database.repository.data_resource.TableDataSetRepository; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +/** + * @author Zane + */ +@Api(path = "data_resource/tags", name = "all of the table data set tags") +public class ListTagsApi extends AbstractApi { + + @Autowired + TableDataSetRepository repo; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + + List list = CacheObjects + .getDataResourceTags(null) + .entrySet() + .stream() + .filter(x -> { + if (StringUtil.isEmpty(input.tag)) { + return true; + } + return x.getKey().contains(input.tag); + }) + .map(x -> new Item(x.getKey(), x.getValue())) + .collect(Collectors.toList()); + + list.sort(Comparator.comparingInt(x -> x.count)); + Collections.reverse(list); + + return success(new Output(list)); + } + + public static class Input extends AbstractApiInput { + @Check(name = "tag关键字,用于模糊搜索(联想输入)") + public String tag; + + @Check(name = "资源类型") + public List dataResourceType; + } + + public static class Output { + public List list; + + public Output() { + } + + public Output(List list) { + this.list = list; + } + } + + public static class Item { + public int count; + public String tagName; + + public Item() { + } + + public Item(String tagName, 
int count) { + this.count = count; + this.tagName = tagName; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/UsageInProjectListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/UsageInProjectListApi.java new file mode 100644 index 000000000..c3c5cd024 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/UsageInProjectListApi.java @@ -0,0 +1,61 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource; + +import com.welab.wefe.board.service.dto.entity.project.ProjectUsageDetailOutputModel; +import com.welab.wefe.board.service.service.data_resource.DataResourceService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; +import java.util.List; + +/** + * @author zane.luo + */ +@Api(path = "data_resource/usage_in_project_list", name = "list project by data resource usage") +public class UsageInProjectListApi extends AbstractApi> { + @Autowired + private DataResourceService dataResourceService; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException, IOException { + return success(dataResourceService.queryUsageInProject(input.getDataResourceId())); + } + + public static class Input extends AbstractApiInput { + @Check(name = "资源Id", require = true) + private String dataResourceId; + + //region getter/setter + + public String getDataResourceId() { + return dataResourceId; + } + + public void setDataResourceId(String dataResourceId) { + this.dataResourceId = dataResourceId; + } + + //endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterAddApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterAddApi.java new file mode 100644 index 000000000..48c61ac12 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterAddApi.java @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.bloom_filter; + +import com.welab.wefe.board.service.dto.vo.data_resource.BloomFilterAddInputModel; +import com.welab.wefe.board.service.dto.vo.data_resource.DataResourceAddOutputModel; +import com.welab.wefe.board.service.service.data_resource.add.BloomFilterAddService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +/** + * @author jacky.jiang + */ +@Api(path = "bloom_filter/add", name = "add bloom_filter") +public class BloomFilterAddApi extends AbstractApi { + + @Autowired + private BloomFilterAddService bloomfilterAddService; + + @Override + protected ApiResult handle(BloomFilterAddInputModel input) throws StatusCodeWithException, IOException { + DataResourceAddOutputModel bloomfilterTaskMysqlModel = bloomfilterAddService.add(input); + return success(bloomfilterTaskMysqlModel); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterDataResourceListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterDataResourceListApi.java new file mode 100644 
index 000000000..fa3e6d208 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterDataResourceListApi.java @@ -0,0 +1,97 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.bloom_filter; + + +import com.welab.wefe.board.service.dto.entity.BloomFilterDataResourceListOutputModel; +import com.welab.wefe.board.service.service.data_resource.bloom_filter.BloomFilterService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +/** + * @author jacky.jiang + */ +@Api(path = "data_resource/member/query", name = "query data_resource") +public class BloomFilterDataResourceListApi extends AbstractApi { + + @Autowired + private BloomFilterService bloomfilterService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { + return success(bloomfilterService.query(input)); + } + + public static class Input 
extends AbstractApiInput { + @Check(name = "工程 Id", require = true) + private String projectId; + + @Check(name = "成员 Id", require = true) + private String memberId; + + @Check(name = "成员类型", require = true) + private JobMemberRole role; + + @Check(name = "数据资源名称") + private String name; + + //region getter/setter + + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public String getMemberId() { + return memberId; + } + + public void setMemberId(String memberId) { + this.memberId = memberId; + } + + public JobMemberRole getRole() { + return role; + } + + public void setRole(JobMemberRole role) { + this.role = role; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + //endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterDeleteApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterDeleteApi.java new file mode 100644 index 000000000..be639126d --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterDeleteApi.java @@ -0,0 +1,61 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource.bloom_filter; + + +import com.welab.wefe.board.service.service.data_resource.bloom_filter.BloomFilterService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author jacky.jiang + */ +@Api(path = "bloom_filter/delete", name = "delete bloom_filter") +public class BloomFilterDeleteApi extends AbstractNoneOutputApi { + + @Autowired + private BloomFilterService bloomfilterService; + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + bloomfilterService.delete(input); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(name = "数据集 Id", require = true) + private String id; + + //region getter/setter + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + + //endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterDetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterDetailApi.java new file mode 100644 index 000000000..30b0a6a18 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterDetailApi.java @@ -0,0 +1,71 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.bloom_filter; + + +import com.welab.wefe.board.service.database.entity.data_resource.BloomFilterMysqlModel; +import com.welab.wefe.board.service.database.repository.data_resource.BloomFilterRepository; +import com.welab.wefe.board.service.dto.entity.data_resource.output.BloomFilterOutputModel; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Jacky.jiang + */ +@Api(path = "bloom_filter/detail", name = "get BloomFilter detail") +public class BloomFilterDetailApi extends AbstractApi { + + @Autowired + BloomFilterRepository bloomFilterRepository; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + + BloomFilterMysqlModel model = bloomFilterRepository.findById(input.id).orElse(null); + + if (model == null) { + return success(); + } + + BloomFilterOutputModel output = ModelMapper.map(model, BloomFilterOutputModel.class); + + return success(output); + + } + + public static class Input extends AbstractApiInput { + private String id; + + //region getter/setter + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + + //endregion + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterPreviewApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterPreviewApi.java new file mode 100644 index 000000000..5408f67f9 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterPreviewApi.java @@ -0,0 +1,363 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource.bloom_filter; + +import com.welab.wefe.board.service.constant.BloomfilterAddMethod; +import com.welab.wefe.board.service.constant.DataSetAddMethod; +import com.welab.wefe.board.service.database.entity.DataSourceMysqlModel; +import com.welab.wefe.board.service.dto.entity.data_set.DataSetColumnOutputModel; +import com.welab.wefe.board.service.service.data_resource.bloom_filter.BloomFilterService; +import com.welab.wefe.board.service.util.AbstractTableDataSetReader; +import com.welab.wefe.board.service.util.CsvTableDataSetReader; +import com.welab.wefe.board.service.util.ExcelTableDataSetReader; +import com.welab.wefe.board.service.util.JdbcManager; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.ListUtil; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.ColumnDataType; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.File; +import java.io.IOException; +import java.sql.Connection; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** + * @author Jacky.jiang + */ +@Api(path = "bloom_filter/preview", name = "preview bloom_filter rows") +public class BloomFilterPreviewApi extends AbstractApi { + + private static final Pattern MATCH_INTEGER_PATTERN = Pattern.compile("^-?\\d{1,9}$"); + private static final Pattern MATCH_LONG_PATTERN = Pattern.compile("^-?\\d{10,}$"); + private static final Pattern MATCH_DOUBLE_PATTERN 
= Pattern.compile("^-?\\d+\\.\\d+$"); + + @Autowired + BloomFilterService bloomfilterService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + + Output output = new Output(); + // Read data from the database for preview + if (BloomfilterAddMethod.Database.equals(input.getBloomfilterAddMethod())) { + // Test whether SQL can be queried normally + boolean result = bloomfilterService.testSqlQuery(input.getDataSourceId(), input.getSql()); + if (result) { + output = readFromDatabase(input.getDataSourceId(), input.getSql()); + } + } else { + String filename = input.getFilename(); + File file = bloomfilterService.getBloomfilterFile(input.getBloomfilterAddMethod(), filename); + try { + output = readFile(file); + } catch (IOException e) { + LOG.error(e.getClass().getSimpleName() + " " + e.getMessage(), e); + throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); + } + } + + //generateMetadata(output); + + return success(output); + } + + /** + * Parse the dataset file + */ + private Output readFile(File file) throws IOException, StatusCodeWithException { + + + Output output = new Output(); + LinkedHashMap metadata = new LinkedHashMap<>(); + + + // How to consume the first row of a column + Consumer> headRowConsumer = row -> { + + output.header.addAll(row); + + for (String name : output.header) { + DataSetColumnOutputModel column = new DataSetColumnOutputModel(); + column.setName(name); + metadata.put(name, column); + } + + }; + + // Data line consumer + DataRowConsumer dataRowConsumer = new DataRowConsumer(metadata, output); + + + try ( + AbstractTableDataSetReader reader = file.getName().endsWith("csv") + ? 
new CsvTableDataSetReader(file) + : new ExcelTableDataSetReader(file) + ) { + // Get column header + headRowConsumer.accept(reader.getHeader()); + // Read data row + reader.read(dataRowConsumer, 10000, 10_000); + } + + output.setMetadataList(new ArrayList<>(metadata.values())); + + + return output; + } + + /** + * Data line consumer + */ + private static class DataRowConsumer implements Consumer> { + + private final LinkedHashMap metadata; + private final Output output; + + private boolean allColumnKnowDataType = false; + + + public DataRowConsumer(LinkedHashMap metadata, Output output) { + this.metadata = metadata; + this.output = output; + } + + @Override + public void accept(LinkedHashMap x) { + // The front end only previews 10 rows of data, too many interfaces will freeze. + if (output.rawDataList.size() < 10) { + output.rawDataList.add(x); + } + + if (allColumnKnowDataType) { + return; + } + + // Infer data type + boolean hasUnkonow = true; + for (String name : output.header) { + + DataSetColumnOutputModel column = metadata.get(name); + if (column.getDataType() == null) { + + Object value = x.get(name); + ColumnDataType dataType = inferDataType(String.valueOf(value)); + + if (dataType != null) { + column.setDataType(dataType); + } else { + hasUnkonow = true; + } + } + } + + if (!hasUnkonow) { + allColumnKnowDataType = true; + } + + } + + /** + * Infer data type + */ + private ColumnDataType inferDataType(String value) { + if (AbstractTableDataSetReader.isEmptyValue(value)) { + return null; + } + + if (MATCH_DOUBLE_PATTERN.matcher(value).find()) { + return ColumnDataType.Double; + } + + if (MATCH_LONG_PATTERN.matcher(value).find()) { + return ColumnDataType.Long; + } + + if (MATCH_INTEGER_PATTERN.matcher(value).find()) { + return ColumnDataType.Integer; + } + + return ColumnDataType.String; + } + } + + private Output readFromDatabase(String dataSourceId, String sql) throws StatusCodeWithException { + DataSourceMysqlModel model = 
bloomfilterService.getDataSourceById(dataSourceId); + if (model == null) { + throw new StatusCodeWithException("dataSourceId在数据库不存在", StatusCode.DATA_NOT_FOUND); + } + + Connection conn = JdbcManager.getConnection( + model.getDatabaseType(), + model.getHost(), + model.getPort(), + model.getUserName(), + model.getPassword(), + model.getDatabaseName() + ); + + // Get the column header of the data set + List header = JdbcManager.getRowHeaders(conn, sql); + if (header.stream().distinct().count() != header.size()) { + throw new StatusCodeWithException("数据集包含重复的字段,请处理后重新上传。", StatusCode.PARAMETER_VALUE_INVALID); + } + + // Convert uppercase Y to lowercase y + header = header.stream().map(x -> "Y".equals(x) ? "y" : x).collect(Collectors.toList()); + + boolean containsY = header.contains("y"); + int yIndex = header.indexOf("y"); + + // If there is a y column, move y to the second column (the first column is the primary key). + if (containsY) { + ListUtil.moveElement(header, yIndex, 1); + } + + Output output = new Output(); + LinkedHashMap metadata = new LinkedHashMap<>(); + output.setHeader(header); + + for (String name : output.header) { + DataSetColumnOutputModel column = new DataSetColumnOutputModel(); + column.setName(name); + metadata.put(name, column); + } + + // Data line consumer + DataRowConsumer dataRowConsumer = new DataRowConsumer(metadata, output); + + JdbcManager.readWithFieldRow(conn, sql, dataRowConsumer, 10); + + + output.setMetadataList(new ArrayList<>(metadata.values())); + + return output; + } + + + //region dto + + public static class Input extends AbstractApiInput { + + @Check(require = true, name = "文件名", messageOnEmpty = "请指定过滤器文件") + private String filename; + + @Check(require = true, name = "数据集添加方法") + private BloomfilterAddMethod bloomfilterAddMethod; + + @Check(name = "数据源id") + private String dataSourceId; + + @Check(name = "sql脚本") + private String sql; + + @Override + public void checkAndStandardize() throws StatusCodeWithException { + // If 
the source is a database, dataSourceId and sql must not be empty. + if (DataSetAddMethod.Database.equals(bloomfilterAddMethod)) { + if (StringUtils.isEmpty(dataSourceId)) { + throw new StatusCodeWithException("dataSourceId在数据库不存在", StatusCode.DATA_NOT_FOUND); + } + + if (StringUtils.isEmpty(sql)) { + throw new StatusCodeWithException("请填入sql查询语句", StatusCode.PARAMETER_CAN_NOT_BE_EMPTY); + } + } + } + + public String getFilename() { + return filename; + } + + public void setFilename(String filename) { + this.filename = filename; + } + + public BloomfilterAddMethod getBloomfilterAddMethod() { + return bloomfilterAddMethod; + } + + public void setBloomfilterAddMethod(BloomfilterAddMethod bloomfilterAddMethod) { + this.bloomfilterAddMethod = bloomfilterAddMethod; + } + + public String getDataSourceId() { + return dataSourceId; + } + + public void setDataSourceId(String dataSourceId) { + this.dataSourceId = dataSourceId; + } + + public String getSql() { + return sql; + } + + public void setSql(String sql) { + this.sql = sql; + } + } + + public static class Output { + @Check(name = "字段列表") + private List header = new ArrayList<>(); + @Check(name = "原始数据列表") + private List> rawDataList = new ArrayList<>(); + @Check(name = "元数据信息") + private List metadataList = new ArrayList<>(); + + + public List getHeader() { + return header; + } + + public void setHeader(List header) { + this.header = header; + } + + public List> getRawDataList() { + return rawDataList; + } + + public void setRawDataList(List> rawDataList) { + this.rawDataList = rawDataList; + } + + public List getMetadataList() { + return metadataList; + } + + public void setMetadataList(List metadataList) { + this.metadataList = metadataList; + } + } + + + //endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterUpdateApi.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterUpdateApi.java new file mode 100644 index 000000000..50ee5babe --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/bloom_filter/BloomFilterUpdateApi.java @@ -0,0 +1,43 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.bloom_filter; + +import com.welab.wefe.board.service.dto.vo.data_resource.BloomFilterUpdateInputModel; +import com.welab.wefe.board.service.service.data_resource.bloom_filter.BloomFilterService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Jacky.jiang + */ +@Api(path = "bloom_filter/update", name = "update bloom filter info") +public class BloomFilterUpdateApi extends AbstractNoneOutputApi { + + @Autowired + private BloomFilterService bloomFilterService; + + @Override + protected ApiResult handler(BloomFilterUpdateInputModel input) throws StatusCodeWithException { + bloomFilterService.update(input); + + return success(); + } + +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetAddApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetAddApi.java new file mode 100644 index 000000000..255065897 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetAddApi.java @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource.image_data_set; + +import com.welab.wefe.board.service.dto.vo.data_resource.DataResourceAddOutputModel; +import com.welab.wefe.board.service.dto.vo.data_resource.ImageDataSetAddInputModel; +import com.welab.wefe.board.service.service.data_resource.add.ImageDataSetAddService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +/** + * @author Zane + */ +@Api(path = "image_data_set/add", name = "add image data set") +public class ImageDataSetAddApi extends AbstractApi { + + @Autowired + private ImageDataSetAddService imageDataSetAddService; + + @Override + protected ApiResult handle(ImageDataSetAddInputModel input) throws StatusCodeWithException, IOException { + DataResourceAddOutputModel output = imageDataSetAddService.add(input); + return success(output); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetDeleteApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetDeleteApi.java new file mode 100644 index 000000000..27e5087f1 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetDeleteApi.java @@ -0,0 +1,61 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.image_data_set; + + +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "image_data_set/delete", name = "delete data set") +public class ImageDataSetDeleteApi extends AbstractNoneOutputApi { + + @Autowired + private ImageDataSetService imageDataSetService; + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + imageDataSetService.delete(input); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(name = "数据集 Id", require = true) + private String id; + + //region getter/setter + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + + //endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetDetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetDetailApi.java new file mode 100644 index 000000000..fa8021e00 --- /dev/null +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetDetailApi.java @@ -0,0 +1,71 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.image_data_set; + + +import com.welab.wefe.board.service.database.entity.data_resource.ImageDataSetMysqlModel; +import com.welab.wefe.board.service.database.repository.data_resource.ImageDataSetRepository; +import com.welab.wefe.board.service.dto.entity.data_resource.output.ImageDataSetOutputModel; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "image_data_set/detail", name = "get a image data set detail") +public class ImageDataSetDetailApi extends AbstractApi { + + @Autowired + ImageDataSetRepository imageDataSetRepository; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + + ImageDataSetMysqlModel model = imageDataSetRepository.findById(input.id).orElse(null); + + if (model == null) { + return success(); + } + + 
ImageDataSetOutputModel output = ModelMapper.map(model, ImageDataSetOutputModel.class); + + return success(output); + + } + + public static class Input extends AbstractApiInput { + private String id; + + //region getter/setter + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + + //endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetDownloadApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetDownloadApi.java new file mode 100644 index 000000000..986d832d5 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetDownloadApi.java @@ -0,0 +1,53 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource.image_data_set; + + +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.ResponseEntity; + +import java.io.File; + +/** + * @author Zane + */ +@Api(path = "image_data_set/download", name = "download image data set file", login = false) +public class ImageDataSetDownloadApi extends AbstractApi> { + + @Autowired + private ImageDataSetService imageDataSetService; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException { + File file = imageDataSetService.download(input.dataSetId, input.jobId); + return file(file); + } + + public static class Input extends AbstractApiInput { + @Check(name = "数据集 Id", require = true) + public String dataSetId; + public String jobId; + + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetUpdateApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetUpdateApi.java new file mode 100644 index 000000000..48612dce7 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/ImageDataSetUpdateApi.java @@ -0,0 +1,43 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.image_data_set; + +import com.welab.wefe.board.service.dto.vo.data_resource.ImageDataSetUpdateInputModel; +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "image_data_set/update", name = "update data set info") +public class ImageDataSetUpdateApi extends AbstractNoneOutputApi { + + @Autowired + private ImageDataSetService imageDataSetService; + + @Override + protected ApiResult handler(ImageDataSetUpdateInputModel input) throws StatusCodeWithException { + imageDataSetService.update(input); + + return success(); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleDeleteApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleDeleteApi.java new file mode 100644 index 000000000..21f4784fb --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleDeleteApi.java @@ -0,0 +1,49 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.image_data_set.sample; + + +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetSampleService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "image_data_set_sample/delete", name = "delete image data set sample", login = false) +public class ImageDataSetSampleDeleteApi extends AbstractNoneOutputApi { + + @Autowired + private ImageDataSetSampleService imageDataSetSampleService; + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + imageDataSetSampleService.delete(input.id); + + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + public String id; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleDownloadApi.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleDownloadApi.java new file mode 100644 index 000000000..b2cc84669 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleDownloadApi.java @@ -0,0 +1,54 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.image_data_set.sample; + +import com.welab.wefe.board.service.database.entity.data_set.ImageDataSetSampleMysqlModel; +import com.welab.wefe.board.service.database.repository.ImageDataSetSampleRepository; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.ResponseEntity; + +import java.io.File; +import java.io.IOException; + +/** + * @author Zane + */ +@Api(path = "image_data_set_sample/download", name = "download image data set sample") +public class ImageDataSetSampleDownloadApi extends AbstractApi> { + + @Autowired + private ImageDataSetSampleRepository 
imageDataSetSampleRepository; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException, IOException { + ImageDataSetSampleMysqlModel sample = imageDataSetSampleRepository.findById(input.id).orElse(null); + File file = new File(sample.getFilePath()); + + return file(file); + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + public String id; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleQueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleQueryApi.java new file mode 100644 index 000000000..180012ca6 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleQueryApi.java @@ -0,0 +1,88 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource.image_data_set.sample; + +import com.welab.wefe.board.service.dto.base.PagingInput; +import com.welab.wefe.board.service.dto.base.PagingOutput; +import com.welab.wefe.board.service.dto.entity.data_set.ImageDataSetSampleOutputModel; +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetSampleService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "image_data_set_sample/query", name = "query image data set samples") +public class ImageDataSetSampleQueryApi extends AbstractApi> { + + @Autowired + private ImageDataSetSampleService imageDataSetSampleService; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException { + return success(imageDataSetSampleService.query(input)); + } + + public static class Input extends PagingInput { + + @Check(name = "数据集Id") + private String dataSetId; + + @Check(name = "标签名称") + private String label; + + @Check(name = "标签名称使用模糊匹配") + public boolean labelMatchWithContains = false; + + @Check(name = "是否已标注") + private Boolean labeled; + + //region getter/setter + + + public String getDataSetId() { + return dataSetId; + } + + public void setDataSetId(String dataSetId) { + this.dataSetId = dataSetId; + } + + public String getLabel() { + return label; + } + + public void setLabel(String label) { + this.label = label; + } + + public Boolean getLabeled() { + return labeled; + } + + public void setLabeled(Boolean labeled) { + this.labeled = labeled; + } + + + //endregion + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleStatisticsApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleStatisticsApi.java new file mode 100644 index 000000000..c52fad77c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleStatisticsApi.java @@ -0,0 +1,89 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.api.data_resource.image_data_set.sample; + + +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetSampleService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * @author zane + * @date 2021/11/15 + */ +@Api(path = "image_data_set_sample/statistics", name = "statistics the data set labels distribute", login = false) +public class ImageDataSetSampleStatisticsApi extends AbstractApi { + @Autowired + private ImageDataSetSampleService imageDataSetSampleService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { + Output output = imageDataSetSampleService.statistics(input.dataSetId); + return success(output); + } + + public static class Output { + + @Check(name = "按 label 统计 label 数量", desc = "例:一个样本中有三个 apple,apple 计数三次。") + public List countByLabel; + @Check(name = "按样本统计 label 数量", desc = "例:一个样本中有三个 apple,apple 计数一次。") + public List countBySample; + + public Output() { + } + + public Output(Map countByLabel, Map countBySample) { + this.countByLabel = mapToList(countByLabel); + this.countBySample = mapToList(countBySample); + } + + private List mapToList(Map map) { + return map + .entrySet() + .stream() + .map(x -> new Item(x.getKey(), x.getValue())) + .collect(Collectors.toList()); + } + + public static class Item { + public String label; + public int count; + + public Item() { + } + + public Item(String label, int count) { + this.label = label; + this.count = count; + } + } 
+ } + + public static class Input extends AbstractApiInput { + @Check(name = "数据集Id") + public String dataSetId; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleUpdateApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleUpdateApi.java new file mode 100644 index 000000000..94aa805c3 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/ImageDataSetSampleUpdateApi.java @@ -0,0 +1,51 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource.image_data_set.sample; + +import com.welab.wefe.board.service.dto.vo.data_set.image_data_set.LabelInfo; +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetSampleService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "image_data_set_sample/update", name = "update image data set sample info") +public class ImageDataSetSampleUpdateApi extends AbstractNoneOutputApi { + + @Autowired + private ImageDataSetSampleService imageDataSetSampleService; + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + imageDataSetSampleService.update(input); + + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + public String id; + @Check(require = true, name = "标注信息") + public LabelInfo labelInfo; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/download.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/download.http new file mode 100644 index 000000000..87366773c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/download.http @@ -0,0 +1,8 @@ + +### 修改数据集 +POST http://localhost:8080/board-service/image_data_set_sample/download +Content-Type: application/json + +{ + "id": "02ad0876389241e58a846069508ce1d3" +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/query.http new file mode 100644 index 000000000..c4a06cbe0 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/query.http @@ -0,0 +1,10 @@ + +### 修改数据集 +POST http://localhost:8080/board-service/image_data_set_sample/query +Content-Type: application/json +token: {{token}} + +{ + "label": "male", + "data_set_id": "d04839004d38421cbed66b6ef9791037" +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/statistics.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/statistics.http new file mode 100644 index 000000000..552223fc1 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/statistics.http @@ -0,0 +1,8 @@ + +### 统计数据集样本分布 +POST http://localhost:8080/board-service/image_data_set_sample/statistics +Content-Type: application/json + +{ + "data_set_id": "f39cce565d70441d89643babe77bb27b" +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/update.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/update.http new file mode 100644 index 000000000..21b8006ef --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/sample/test/update.http @@ -0,0 +1,25 @@ + +### 更新样本标注信息 +POST http://localhost:8080/board-service/image_data_set_sample/update +Content-Type: application/json + +{ + "id": "00d0a0bf6b38414ca12575e18aba88d7", + "label_info": { + "objects": [ + { + "label": "hello", + 
"points": [ + { + "x": 123, + "y": 123 + }, + { + "x": 123, + "y": 123 + } + ] + } + ] + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/add.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/add.http new file mode 100644 index 000000000..452b38700 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/add.http @@ -0,0 +1,36 @@ + +### 添加图片数据集-目标检测 +POST localhost:8080/board-service/image_data_set/add +Content-Type: application/json +token: {{token}} + +{ + "publicLevel": "Public", + "name": "zane test10", + "tags": [ + "12321" + ], + "description": "", + "public_member_list": "", + "filename": "fl_fruit_yippee.zip", + "for_job_type": "detection" +} + + +### 添加图片数据集-图像分类 +POST localhost:8080/board-service/image_data_set/add +Content-Type: application/json +token: {{token}} + +{ + "publicLevel": "Public", + "name": "zane test12", + "tags": [ + "12321" + ], + "description": "", + "public_member_list": "", + "filename": "flowers.zip", + "for_job_type": "classify" +} + diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/detail.http new file mode 100644 index 000000000..dede247fc --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/detail.http @@ -0,0 +1,9 @@ + +### 查询单个数据集 +POST http://localhost:8080/board-service/image_data_set/detail +Content-Type: application/json +token: {{token}} + +{ + "id": "37faca29272b462c95c29d631dffb342" +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/download.http 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/download.http new file mode 100644 index 000000000..23d7c78bf --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/download.http @@ -0,0 +1,10 @@ + +### 下载数据集文件 +POST http://localhost:8080/board-service/image_data_set/download +Content-Type: application/json +token: {{token}} + +{ + "data_set_id": "5217bd815d0b4612a93aff0c1fac3504", + "job_id": "777c3d2841ea4b078b8ed46aef4604c4" +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/query.http new file mode 100644 index 000000000..4d83c8c69 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/query.http @@ -0,0 +1,9 @@ + +### 查询全部数据集 +POST http://localhost:8080/board-service/image_data_set/query +Content-Type: application/json +token: {{token}} + +{ + "id": "d1eb80bdeb744af19f4ea43604ca68ab" +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/update.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/update.http new file mode 100644 index 000000000..47f24c8c9 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/image_data_set/test/update.http @@ -0,0 +1,15 @@ + +### 修改数据集 +POST http://localhost:8080/board-service/image_data_set/update +Content-Type: application/json +token: {{token}} + +{ + "id": "154c6d6d69124b14adeb8d148de425e7", + "name": "Euler图学习开源数据集", + "tags": [ + "图" + ], + "description": "Euler图学习平台自研算法对应的开源图数据与样本数据", + "publicLevel": "Public" +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/AddApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/AddApi.java new file mode 100644 index 000000000..94bfdfec5 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/AddApi.java @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource.table_data_set; + +import com.welab.wefe.board.service.dto.vo.data_resource.DataResourceAddOutputModel; +import com.welab.wefe.board.service.dto.vo.data_resource.TableDataSetAddInputModel; +import com.welab.wefe.board.service.service.data_resource.add.TableDataSetAddService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +/** + * @author Zane + */ +@Api(path = "table_data_set/add", name = "add data set") +public class AddApi extends AbstractApi { + + @Autowired + private TableDataSetAddService tableDataSetAddService; + + @Override + protected ApiResult handle(TableDataSetAddInputModel input) throws StatusCodeWithException, IOException { + DataResourceAddOutputModel output = tableDataSetAddService.add(input); + return success(output); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/DetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/DetailApi.java new file mode 100644 index 000000000..489f18f94 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/DetailApi.java @@ -0,0 +1,71 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.table_data_set; + + +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; +import com.welab.wefe.board.service.database.repository.data_resource.TableDataSetRepository; +import com.welab.wefe.board.service.dto.entity.data_resource.output.TableDataSetOutputModel; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "table_data_set/detail", name = "get data set detail") +public class DetailApi extends AbstractApi { + + @Autowired + TableDataSetRepository dataSetRepository; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + + TableDataSetMysqlModel model = dataSetRepository.findById(input.id).orElse(null); + + if (model == null) { + return success(); + } + + TableDataSetOutputModel output = ModelMapper.map(model, TableDataSetOutputModel.class); + + return success(output); + + } + + public static class Input extends AbstractApiInput { + private String id; + + //region getter/setter + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + + //endregion + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/ListServerLocalFilesApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/ListServerLocalFilesApi.java new file mode 100644 index 000000000..7e8d6d205 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/ListServerLocalFilesApi.java @@ -0,0 +1,94 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource.table_data_set; + +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiOutput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.dto.NoneApiInput; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; + +/** + * @author Johnny.lin + */ +@Api(path = "data_set/list_local_data_set_files", name = "query the files in the specified directory on the server") +public class ListServerLocalFilesApi extends AbstractApi { + + @Autowired + private Config config; + + private static final List SUPPORT_SUFFIX = new ArrayList(); + + static { + SUPPORT_SUFFIX.add("xls"); + SUPPORT_SUFFIX.add("xlsx"); + SUPPORT_SUFFIX.add("csv"); + } + + @Override + protected ApiResult handle(NoneApiInput input) throws StatusCodeWithException { + List files = new ArrayList<>(); + File file = WeFeFileSystem.getRootDir().toFile(); + LOG.info("file.exists(): " + file.exists()); + + File[] tempList = file.listFiles(); + for (File fileObj : tempList) { + if (fileObj.isFile()) { + LOG.info("file: " + fileObj); + + //File name, excluding path + String fileName = fileObj.getName(); + String suffix = fileName.substring(fileName.lastIndexOf(".") + 1); + + //Only XLS, xlsx, and CSV files are displayed + if (!SUPPORT_SUFFIX.contains(suffix.toLowerCase())) { + continue; + } + + files.add(fileName); + } + } + + Output output = new Output(); + output.setFiles(files); + return success(output); + } + + public static class Output extends AbstractApiOutput { + private List files; + + public List getFiles() { + return files; + } + + public void 
setFiles(List files) { + this.files = files; + } + + public Output() { + + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/PreviewApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/PreviewApi.java similarity index 90% rename from board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/PreviewApi.java rename to board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/PreviewApi.java index 226a57bd1..4f76111fa 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/PreviewApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/PreviewApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,18 +14,17 @@ * limitations under the License. 
*/ -package com.welab.wefe.board.service.api.dataset; +package com.welab.wefe.board.service.api.data_resource.table_data_set; import com.welab.wefe.board.service.constant.DataSetAddMethod; -import com.welab.wefe.board.service.database.entity.DataSourceMySqlModel; +import com.welab.wefe.board.service.database.entity.DataSourceMysqlModel; import com.welab.wefe.board.service.dto.entity.data_set.DataSetColumnOutputModel; -import com.welab.wefe.board.service.service.DataSetService; -import com.welab.wefe.board.service.util.AbstractDataSetReader; -import com.welab.wefe.board.service.util.CsvDataSetReader; -import com.welab.wefe.board.service.util.ExcelDataSetReader; +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; +import com.welab.wefe.board.service.util.AbstractTableDataSetReader; +import com.welab.wefe.board.service.util.CsvTableDataSetReader; +import com.welab.wefe.board.service.util.ExcelTableDataSetReader; import com.welab.wefe.board.service.util.JdbcManager; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.ColumnDataType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.ListUtil; @@ -33,6 +32,7 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.ColumnDataType; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -50,7 +50,7 @@ /** * @author Zane */ -@Api(path = "data_set/preview", name = "preview data set rows") +@Api(path = "table_data_set/preview", name = "preview data set rows") public class PreviewApi extends AbstractApi { private static final Pattern MATCH_INTEGER_PATTERN = Pattern.compile("^-?\\d{1,9}$"); @@ -58,7 +58,7 @@ public class PreviewApi extends AbstractApi private static 
final Pattern MATCH_DOUBLE_PATTERN = Pattern.compile("^-?\\d+\\.\\d+$"); @Autowired - DataSetService dataSetService; + TableDataSetService tableDataSetService; @Override protected ApiResult handle(Input input) throws StatusCodeWithException { @@ -67,13 +67,13 @@ protected ApiResult handle(Input input) throws StatusCodeWithException { // Read data from the database for preview if (DataSetAddMethod.Database.equals(input.getDataSetAddMethod())) { // Test whether SQL can be queried normally - boolean result = dataSetService.testSqlQuery(input.getDataSourceId(), input.getSql()); + boolean result = tableDataSetService.testSqlQuery(input.getDataSourceId(), input.getSql()); if (result) { output = readFromDatabase(input.getDataSourceId(), input.getSql()); } } else { - File file = dataSetService.getDataSetFile(input.getDataSetAddMethod(), input.getFilename()); + File file = tableDataSetService.getDataSetFile(input.getDataSetAddMethod(), input.getFilename()); try { output = readFile(file); } catch (IOException e) { @@ -115,9 +115,9 @@ private Output readFile(File file) throws IOException, StatusCodeWithException { try ( - AbstractDataSetReader reader = file.getName().endsWith("csv") - ? new CsvDataSetReader(file) - : new ExcelDataSetReader(file) + AbstractTableDataSetReader reader = file.getName().endsWith("csv") + ? 
new CsvTableDataSetReader(file) + : new ExcelTableDataSetReader(file) ) { // Get column header headRowConsumer.accept(reader.getHeader()); @@ -186,7 +186,7 @@ public void accept(LinkedHashMap x) { * Infer data type */ private ColumnDataType inferDataType(String value) { - if (AbstractDataSetReader.isEmptyValue(value)) { + if (AbstractTableDataSetReader.isEmptyValue(value)) { return null; } @@ -207,7 +207,7 @@ private ColumnDataType inferDataType(String value) { } private Output readFromDatabase(String dataSourceId, String sql) throws StatusCodeWithException { - DataSourceMySqlModel model = dataSetService.getDataSourceById(dataSourceId); + DataSourceMysqlModel model = tableDataSetService.getDataSourceById(dataSourceId); if (model == null) { throw new StatusCodeWithException("dataSourceId在数据库不存在", StatusCode.DATA_NOT_FOUND); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/TableDataSetDeleteApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/TableDataSetDeleteApi.java new file mode 100644 index 000000000..201983930 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/TableDataSetDeleteApi.java @@ -0,0 +1,61 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource.table_data_set; + + +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "table_data_set/delete", name = "delete data set") +public class TableDataSetDeleteApi extends AbstractNoneOutputApi { + + @Autowired + private TableDataSetService tableDataSetService; + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + tableDataSetService.delete(input); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(name = "数据集 Id", require = true) + private String id; + + //region getter/setter + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + + //endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/UpdateApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/UpdateApi.java new file mode 100644 index 000000000..8aa0ed752 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/UpdateApi.java @@ -0,0 +1,43 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.table_data_set; + +import com.welab.wefe.board.service.dto.vo.data_resource.TableDataSetUpdateInputModel; +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "table_data_set/update", name = "update data set info") +public class UpdateApi extends AbstractNoneOutputApi { + + @Autowired + private TableDataSetService tableDataSetService; + + @Override + protected ApiResult handler(TableDataSetUpdateInputModel input) throws StatusCodeWithException { + tableDataSetService.update(input); + + return success(); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/column/ListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/column/ListApi.java new file mode 100644 index 000000000..bd4249b95 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/column/ListApi.java @@ -0,0 +1,62 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.table_data_set.column; + +import com.welab.wefe.board.service.dto.base.PagingOutput; +import com.welab.wefe.board.service.dto.entity.data_set.DataSetColumnOutputModel; +import com.welab.wefe.board.service.service.DataSetColumnService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "table_data_set/column/list", name = "list of data set fields") +public class ListApi extends AbstractApi> { + + @Autowired + private DataSetColumnService service; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException { + return success(service.list(input.getDataSetId())); + } + + public static class Input extends AbstractApiInput { + + @Check(require = true, name = "数据集Id") + private String dataSetId; + + //region getter/setter + + public String getDataSetId() { + return dataSetId; + } + + public void setDataSetId(String dataSetId) { + this.dataSetId = dataSetId; + } + + + //endregion + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-add.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-add.http new file mode 100644 index 000000000..c87f5ac1f --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-add.http @@ -0,0 +1,4927 @@ + +### 添加数据集 +POST http://localhost:8080/board-service/table_data_set/add +Content-Type: application/json +token: {{token}} + +{ + "publicLevel": "Public", + "name": "zane test011", + "tags": [ + "12321" + ], + "description": "", + "public_member_list": "", + "filename": "01-600-y-5.csv", + "data_set_add_method": "HttpUpload", + "metadata_list": [ + { + "data_type": "Integer", + "name": "id", + "comment": "" + }, + { + "data_type": "Integer", + "name": "y", + "comment": "" + }, + { + "data_type": "Double", + "name": "x1", + "comment": "" + }, + { + "data_type": "Double", + "name": "x2", + "comment": "" + }, + { + "data_type": "Double", + "name": "x3", + "comment": "" + }, + { + "data_type": "Double", + "name": "x4", + "comment": "" + }, + { + "data_type": "Double", + "name": "x5", + "comment": "" + } + ], + "deduplication": true +} + + +### 添加数据集 +POST http://localhost:8080/board-service/table_data_set/add +Content-Type: application/json +token: {{token}} + +{ + "publicLevel": "Public", + "name": "zane test011", + "tags": [ + "12321" + ], + "description": "", + "public_member_list": "", + "filename": "data0.csv", + "data_set_add_method": "HttpUpload", + "metadata_list": [ + { + "data_type": "Integer", + "name": "id", + "comment": "" + }, + { + "data_type": "Double", + "name": "x0", + "comment": "" + }, + { + "data_type": "Double", + "name": "x1", + "comment": "" + }, + { + "data_type": "Double", + "name": "x2", + "comment": "" + }, + { + "data_type": "Integer", + "name": "y", + "comment": "" + } + ], + "deduplication": true +} + 
+ +### 添加数据集 +POST http://localhost:8080/board-service/table_data_set/add +Content-Type: application/json +token: {{token}} + +{ + "dataSourceId": "0ac9fb77f11448e49b44e68c76b51495", + "data_set_add_method": "HttpUpload", + "databaseName": "jacky数据源测试", + "databaseType": "Database", + "deduplication": true, + "description": "", + "filename": "cb1d6be1-25ce-478a-914d-649830d1b72a-missing_rate_50_ym_data.csv", + "for_job_type": "classify", + "metadata_list": [ + { + "comment": "", + "data_type": "Integer", + "name": "id", + "$index": 0 + }, + { + "comment": "", + "data_type": "Double", + "name": "2057", + "$index": 1 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1571", + "$index": 2 + }, + { + "comment": "", + "data_type": "Double", + "name": "2050", + "$index": 3 + }, + { + "comment": "", + "data_type": "Double", + "name": "k330", + "$index": 4 + }, + { + "comment": "", + "data_type": "Double", + "name": "k641", + "$index": 5 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1208", + "$index": 6 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1083", + "$index": 7 + }, + { + "comment": "", + "data_type": "Double", + "name": "k441", + "$index": 8 + }, + { + "comment": "", + "data_type": "Double", + "name": "2009", + "$index": 9 + }, + { + "comment": "", + "data_type": "Double", + "name": "k175", + "$index": 10 + }, + { + "comment": "", + "data_type": "Double", + "name": "k517", + "$index": 11 + }, + { + "comment": "", + "data_type": "Double", + "name": "k713", + "$index": 12 + }, + { + "comment": "", + "data_type": "Double", + "name": "323", + "$index": 13 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1142", + "$index": 14 + }, + { + "comment": "", + "data_type": "Double", + "name": "10", + "$index": 15 + }, + { + "comment": "", + "data_type": "Double", + "name": "k776", + "$index": 16 + }, + { + "comment": "", + "data_type": "Double", + "name": "k634", + "$index": 17 + }, + { + "comment": "", + 
"data_type": "Double", + "name": "k854", + "$index": 18 + }, + { + "comment": "", + "data_type": "Double", + "name": "k490", + "$index": 19 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1386", + "$index": 20 + }, + { + "comment": "", + "data_type": "Double", + "name": "k694", + "$index": 21 + }, + { + "comment": "", + "data_type": "Double", + "name": "k678", + "$index": 22 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1265", + "$index": 23 + }, + { + "comment": "", + "data_type": "Double", + "name": "k559", + "$index": 24 + }, + { + "comment": "", + "data_type": "Double", + "name": "k358", + "$index": 25 + }, + { + "comment": "", + "data_type": "Double", + "name": "287", + "$index": 26 + }, + { + "comment": "", + "data_type": "Double", + "name": "281", + "$index": 27 + }, + { + "comment": "", + "data_type": "Double", + "name": "296", + "$index": 28 + }, + { + "comment": "", + "data_type": "Double", + "name": "2019", + "$index": 29 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1690", + "$index": 30 + }, + { + "comment": "", + "data_type": "Double", + "name": "6", + "$index": 31 + }, + { + "comment": "", + "data_type": "Double", + "name": "k74", + "$index": 32 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1242", + "$index": 33 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1178", + "$index": 34 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1538", + "$index": 35 + }, + { + "comment": "", + "data_type": "Double", + "name": "322", + "$index": 36 + }, + { + "comment": "", + "data_type": "Double", + "name": "k186", + "$index": 37 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1670", + "$index": 38 + }, + { + "comment": "", + "data_type": "Double", + "name": "k481", + "$index": 39 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1745", + "$index": 40 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1552", + "$index": 41 + }, + { + 
"comment": "", + "data_type": "Double", + "name": "k58", + "$index": 42 + }, + { + "comment": "", + "data_type": "Double", + "name": "362", + "$index": 43 + }, + { + "comment": "", + "data_type": "Double", + "name": "302", + "$index": 44 + }, + { + "comment": "", + "data_type": "Double", + "name": "276", + "$index": 45 + }, + { + "comment": "", + "data_type": "Double", + "name": "271", + "$index": 46 + }, + { + "comment": "", + "data_type": "Double", + "name": "275", + "$index": 47 + }, + { + "comment": "", + "data_type": "Double", + "name": "k735", + "$index": 48 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1535", + "$index": 49 + }, + { + "comment": "", + "data_type": "Double", + "name": "k944", + "$index": 50 + }, + { + "comment": "", + "data_type": "Double", + "name": "301", + "$index": 51 + }, + { + "comment": "", + "data_type": "Double", + "name": "k921", + "$index": 52 + }, + { + "comment": "", + "data_type": "Double", + "name": "303", + "$index": 53 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1247", + "$index": 54 + }, + { + "comment": "", + "data_type": "Double", + "name": "k487", + "$index": 55 + }, + { + "comment": "", + "data_type": "Double", + "name": "k547", + "$index": 56 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1637", + "$index": 57 + }, + { + "comment": "", + "data_type": "Double", + "name": "285", + "$index": 58 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1555", + "$index": 59 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1517", + "$index": 60 + }, + { + "comment": "", + "data_type": "Double", + "name": "97", + "$index": 61 + }, + { + "comment": "", + "data_type": "Double", + "name": "98", + "$index": 62 + }, + { + "comment": "", + "data_type": "Double", + "name": "295", + "$index": 63 + }, + { + "comment": "", + "data_type": "Double", + "name": "293", + "$index": 64 + }, + { + "comment": "", + "data_type": "Double", + "name": "k354", + "$index": 65 + }, 
+ { + "comment": "", + "data_type": "Double", + "name": "k1407", + "$index": 66 + }, + { + "comment": "", + "data_type": "Double", + "name": "325", + "$index": 67 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1660", + "$index": 68 + }, + { + "comment": "", + "data_type": "Double", + "name": "306", + "$index": 69 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1343", + "$index": 70 + }, + { + "comment": "", + "data_type": "Double", + "name": "300", + "$index": 71 + }, + { + "comment": "", + "data_type": "Double", + "name": "307", + "$index": 72 + }, + { + "comment": "", + "data_type": "Double", + "name": "k496", + "$index": 73 + }, + { + "comment": "", + "data_type": "Double", + "name": "k583", + "$index": 74 + }, + { + "comment": "", + "data_type": "Double", + "name": "k26", + "$index": 75 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1578", + "$index": 76 + }, + { + "comment": "", + "data_type": "Double", + "name": "k254", + "$index": 77 + }, + { + "comment": "", + "data_type": "Double", + "name": "k963", + "$index": 78 + }, + { + "comment": "", + "data_type": "Double", + "name": "k546", + "$index": 79 + }, + { + "comment": "", + "data_type": "Double", + "name": "k498", + "$index": 80 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1381", + "$index": 81 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1514", + "$index": 82 + }, + { + "comment": "", + "data_type": "Double", + "name": "282", + "$index": 83 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1518", + "$index": 84 + }, + { + "comment": "", + "data_type": "Double", + "name": "k191", + "$index": 85 + }, + { + "comment": "", + "data_type": "Double", + "name": "2006", + "$index": 86 + }, + { + "comment": "", + "data_type": "Double", + "name": "304", + "$index": 87 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1577", + "$index": 88 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1046", + 
"$index": 89 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1288", + "$index": 90 + }, + { + "comment": "", + "data_type": "Double", + "name": "274", + "$index": 91 + }, + { + "comment": "", + "data_type": "Double", + "name": "8", + "$index": 92 + }, + { + "comment": "", + "data_type": "Double", + "name": "324", + "$index": 93 + }, + { + "comment": "", + "data_type": "Double", + "name": "315", + "$index": 94 + }, + { + "comment": "", + "data_type": "Double", + "name": "k170", + "$index": 95 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1141", + "$index": 96 + }, + { + "comment": "", + "data_type": "Double", + "name": "2329", + "$index": 97 + }, + { + "comment": "", + "data_type": "Double", + "name": "2330", + "$index": 98 + }, + { + "comment": "", + "data_type": "Double", + "name": "2331", + "$index": 99 + }, + { + "comment": "", + "data_type": "Double", + "name": "k859", + "$index": 100 + }, + { + "comment": "", + "data_type": "Double", + "name": "k199", + "$index": 101 + }, + { + "comment": "", + "data_type": "Double", + "name": "k786", + "$index": 102 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1402", + "$index": 103 + }, + { + "comment": "", + "data_type": "Double", + "name": "k92", + "$index": 104 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1368", + "$index": 105 + }, + { + "comment": "", + "data_type": "Double", + "name": "k97", + "$index": 106 + }, + { + "comment": "", + "data_type": "Double", + "name": "k265", + "$index": 107 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1544", + "$index": 108 + }, + { + "comment": "", + "data_type": "Double", + "name": "292", + "$index": 109 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1196", + "$index": 110 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1680", + "$index": 111 + }, + { + "comment": "", + "data_type": "Double", + "name": "k689", + "$index": 112 + }, + { + "comment": "", + "data_type": 
"Double", + "name": "k1449", + "$index": 113 + }, + { + "comment": "", + "data_type": "Double", + "name": "k926", + "$index": 114 + }, + { + "comment": "", + "data_type": "Double", + "name": "k190", + "$index": 115 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1205", + "$index": 116 + }, + { + "comment": "", + "data_type": "Double", + "name": "k460", + "$index": 117 + }, + { + "comment": "", + "data_type": "Double", + "name": "k528", + "$index": 118 + }, + { + "comment": "", + "data_type": "Double", + "name": "k971", + "$index": 119 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1639", + "$index": 120 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1043", + "$index": 121 + }, + { + "comment": "", + "data_type": "Double", + "name": "k210", + "$index": 122 + }, + { + "comment": "", + "data_type": "Double", + "name": "k317", + "$index": 123 + }, + { + "comment": "", + "data_type": "Double", + "name": "k651", + "$index": 124 + }, + { + "comment": "", + "data_type": "Double", + "name": "k122", + "$index": 125 + }, + { + "comment": "", + "data_type": "Double", + "name": "312", + "$index": 126 + }, + { + "comment": "", + "data_type": "Double", + "name": "305", + "$index": 127 + }, + { + "comment": "", + "data_type": "Double", + "name": "k180", + "$index": 128 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1418", + "$index": 129 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1416", + "$index": 130 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1665", + "$index": 131 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1501", + "$index": 132 + }, + { + "comment": "", + "data_type": "Double", + "name": "k644", + "$index": 133 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1314", + "$index": 134 + }, + { + "comment": "", + "data_type": "Double", + "name": "k986", + "$index": 135 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1113", + 
"$index": 136 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1185", + "$index": 137 + }, + { + "comment": "", + "data_type": "Double", + "name": "k860", + "$index": 138 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1593", + "$index": 139 + }, + { + "comment": "", + "data_type": "Double", + "name": "k364", + "$index": 140 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1021", + "$index": 141 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1550", + "$index": 142 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1366", + "$index": 143 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1507", + "$index": 144 + }, + { + "comment": "", + "data_type": "Double", + "name": "5", + "$index": 145 + }, + { + "comment": "", + "data_type": "Double", + "name": "k744", + "$index": 146 + }, + { + "comment": "", + "data_type": "Double", + "name": "k728", + "$index": 147 + }, + { + "comment": "", + "data_type": "Double", + "name": "k937", + "$index": 148 + }, + { + "comment": "", + "data_type": "Double", + "name": "k959", + "$index": 149 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1551", + "$index": 150 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1136", + "$index": 151 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1128", + "$index": 152 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1345", + "$index": 153 + }, + { + "comment": "", + "data_type": "Double", + "name": "k534", + "$index": 154 + }, + { + "comment": "", + "data_type": "Double", + "name": "294", + "$index": 155 + }, + { + "comment": "", + "data_type": "Double", + "name": "k726", + "$index": 156 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1648", + "$index": 157 + }, + { + "comment": "", + "data_type": "Double", + "name": "k889", + "$index": 158 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1424", + "$index": 159 + }, + { + "comment": 
"", + "data_type": "Double", + "name": "298", + "$index": 160 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1492", + "$index": 161 + }, + { + "comment": "", + "data_type": "Double", + "name": "k829", + "$index": 162 + }, + { + "comment": "", + "data_type": "Double", + "name": "k964", + "$index": 163 + }, + { + "comment": "", + "data_type": "Double", + "name": "316", + "$index": 164 + }, + { + "comment": "", + "data_type": "Double", + "name": "2156", + "$index": 165 + }, + { + "comment": "", + "data_type": "Double", + "name": "2155", + "$index": 166 + }, + { + "comment": "", + "data_type": "Double", + "name": "2076", + "$index": 167 + }, + { + "comment": "", + "data_type": "Double", + "name": "2073", + "$index": 168 + }, + { + "comment": "", + "data_type": "Double", + "name": "313", + "$index": 169 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1084", + "$index": 170 + }, + { + "comment": "", + "data_type": "Double", + "name": "297", + "$index": 171 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1235", + "$index": 172 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1133", + "$index": 173 + }, + { + "comment": "", + "data_type": "Double", + "name": "445", + "$index": 174 + }, + { + "comment": "", + "data_type": "Double", + "name": "209", + "$index": 175 + }, + { + "comment": "", + "data_type": "Double", + "name": "210", + "$index": 176 + }, + { + "comment": "", + "data_type": "Double", + "name": "84", + "$index": 177 + }, + { + "comment": "", + "data_type": "Double", + "name": "83", + "$index": 178 + }, + { + "comment": "", + "data_type": "Double", + "name": "310", + "$index": 179 + }, + { + "comment": "", + "data_type": "Double", + "name": "k885", + "$index": 180 + }, + { + "comment": "", + "data_type": "Double", + "name": "k659", + "$index": 181 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1017", + "$index": 182 + }, + { + "comment": "", + "data_type": "Double", + "name": "2069", + 
"$index": 183 + }, + { + "comment": "", + "data_type": "Double", + "name": "2067", + "$index": 184 + }, + { + "comment": "", + "data_type": "Double", + "name": "2068", + "$index": 185 + }, + { + "comment": "", + "data_type": "Double", + "name": "2071", + "$index": 186 + }, + { + "comment": "", + "data_type": "Double", + "name": "2070", + "$index": 187 + }, + { + "comment": "", + "data_type": "Double", + "name": "2072", + "$index": 188 + }, + { + "comment": "", + "data_type": "Double", + "name": "2333", + "$index": 189 + }, + { + "comment": "", + "data_type": "Double", + "name": "283", + "$index": 190 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1412", + "$index": 191 + }, + { + "comment": "", + "data_type": "Double", + "name": "290", + "$index": 192 + }, + { + "comment": "", + "data_type": "Double", + "name": "7", + "$index": 193 + }, + { + "comment": "", + "data_type": "Double", + "name": "k274", + "$index": 194 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1006", + "$index": 195 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1673", + "$index": 196 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1432", + "$index": 197 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1059", + "$index": 198 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1640", + "$index": 199 + }, + { + "comment": "", + "data_type": "Double", + "name": "k218", + "$index": 200 + }, + { + "comment": "", + "data_type": "Double", + "name": "277", + "$index": 201 + }, + { + "comment": "", + "data_type": "Double", + "name": "k292", + "$index": 202 + }, + { + "comment": "", + "data_type": "Double", + "name": "321", + "$index": 203 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1303", + "$index": 204 + }, + { + "comment": "", + "data_type": "Double", + "name": "284", + "$index": 205 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1064", + "$index": 206 + }, + { + "comment": "", + 
"data_type": "Double", + "name": "320", + "$index": 207 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1658", + "$index": 208 + }, + { + "comment": "", + "data_type": "Double", + "name": "k968", + "$index": 209 + }, + { + "comment": "", + "data_type": "Double", + "name": "361", + "$index": 210 + }, + { + "comment": "", + "data_type": "Double", + "name": "288", + "$index": 211 + }, + { + "comment": "", + "data_type": "Double", + "name": "k904", + "$index": 212 + }, + { + "comment": "", + "data_type": "Double", + "name": "357", + "$index": 213 + }, + { + "comment": "", + "data_type": "Double", + "name": "k283", + "$index": 214 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1348", + "$index": 215 + }, + { + "comment": "", + "data_type": "Double", + "name": "k132", + "$index": 216 + }, + { + "comment": "", + "data_type": "Double", + "name": "2332", + "$index": 217 + }, + { + "comment": "", + "data_type": "Double", + "name": "k8", + "$index": 218 + }, + { + "comment": "", + "data_type": "Double", + "name": "268", + "$index": 219 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1156", + "$index": 220 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1603", + "$index": 221 + }, + { + "comment": "", + "data_type": "Double", + "name": "317", + "$index": 222 + }, + { + "comment": "", + "data_type": "Double", + "name": "358", + "$index": 223 + }, + { + "comment": "", + "data_type": "Double", + "name": "314", + "$index": 224 + }, + { + "comment": "", + "data_type": "Double", + "name": "272", + "$index": 225 + }, + { + "comment": "", + "data_type": "Double", + "name": "279", + "$index": 226 + }, + { + "comment": "", + "data_type": "Double", + "name": "291", + "$index": 227 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1078", + "$index": 228 + }, + { + "comment": "", + "data_type": "Double", + "name": "278", + "$index": 229 + }, + { + "comment": "", + "data_type": "Double", + "name": "318", + "$index": 
230 + }, + { + "comment": "", + "data_type": "Double", + "name": "270", + "$index": 231 + }, + { + "comment": "", + "data_type": "Double", + "name": "k421", + "$index": 232 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1721", + "$index": 233 + }, + { + "comment": "", + "data_type": "Double", + "name": "311", + "$index": 234 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1565", + "$index": 235 + }, + { + "comment": "", + "data_type": "Double", + "name": "k371", + "$index": 236 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1447", + "$index": 237 + }, + { + "comment": "", + "data_type": "Double", + "name": "k974", + "$index": 238 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1720", + "$index": 239 + }, + { + "comment": "", + "data_type": "Double", + "name": "k640", + "$index": 240 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1391", + "$index": 241 + }, + { + "comment": "", + "data_type": "Double", + "name": "k331", + "$index": 242 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1212", + "$index": 243 + }, + { + "comment": "", + "data_type": "Double", + "name": "k770", + "$index": 244 + }, + { + "comment": "", + "data_type": "Double", + "name": "k24", + "$index": 245 + }, + { + "comment": "", + "data_type": "Double", + "name": "k448", + "$index": 246 + }, + { + "comment": "", + "data_type": "Double", + "name": "k208", + "$index": 247 + }, + { + "comment": "", + "data_type": "Double", + "name": "k272", + "$index": 248 + }, + { + "comment": "", + "data_type": "Double", + "name": "k333", + "$index": 249 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1270", + "$index": 250 + }, + { + "comment": "", + "data_type": "Double", + "name": "k782", + "$index": 251 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1045", + "$index": 252 + }, + { + "comment": "", + "data_type": "Double", + "name": "k673", + "$index": 253 + }, + { + "comment": "", + 
"data_type": "Double", + "name": "k817", + "$index": 254 + }, + { + "comment": "", + "data_type": "Double", + "name": "k251", + "$index": 255 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1573", + "$index": 256 + }, + { + "comment": "", + "data_type": "Double", + "name": "k710", + "$index": 257 + }, + { + "comment": "", + "data_type": "Double", + "name": "280", + "$index": 258 + }, + { + "comment": "", + "data_type": "Double", + "name": "k282", + "$index": 259 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1697", + "$index": 260 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1151", + "$index": 261 + }, + { + "comment": "", + "data_type": "Double", + "name": "k116", + "$index": 262 + }, + { + "comment": "", + "data_type": "Double", + "name": "k318", + "$index": 263 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1425", + "$index": 264 + }, + { + "comment": "", + "data_type": "Double", + "name": "k259", + "$index": 265 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1632", + "$index": 266 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1228", + "$index": 267 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1279", + "$index": 268 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1050", + "$index": 269 + }, + { + "comment": "", + "data_type": "Double", + "name": "k781", + "$index": 270 + }, + { + "comment": "", + "data_type": "Double", + "name": "k783", + "$index": 271 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1478", + "$index": 272 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1214", + "$index": 273 + }, + { + "comment": "", + "data_type": "Double", + "name": "k426", + "$index": 274 + }, + { + "comment": "", + "data_type": "Double", + "name": "k718", + "$index": 275 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1596", + "$index": 276 + }, + { + "comment": "", + "data_type": "Double", + "name": 
"k729", + "$index": 277 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1040", + "$index": 278 + }, + { + "comment": "", + "data_type": "Double", + "name": "k805", + "$index": 279 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1065", + "$index": 280 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1511", + "$index": 281 + }, + { + "comment": "", + "data_type": "Double", + "name": "k96", + "$index": 282 + }, + { + "comment": "", + "data_type": "Double", + "name": "k98", + "$index": 283 + }, + { + "comment": "", + "data_type": "Double", + "name": "k51", + "$index": 284 + }, + { + "comment": "", + "data_type": "Double", + "name": "k300", + "$index": 285 + }, + { + "comment": "", + "data_type": "Double", + "name": "k524", + "$index": 286 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1079", + "$index": 287 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1012", + "$index": 288 + }, + { + "comment": "", + "data_type": "Double", + "name": "k989", + "$index": 289 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1457", + "$index": 290 + }, + { + "comment": "", + "data_type": "Double", + "name": "k537", + "$index": 291 + }, + { + "comment": "", + "data_type": "Double", + "name": "k10", + "$index": 292 + }, + { + "comment": "", + "data_type": "Double", + "name": "k566", + "$index": 293 + }, + { + "comment": "", + "data_type": "Double", + "name": "k996", + "$index": 294 + }, + { + "comment": "", + "data_type": "Double", + "name": "k568", + "$index": 295 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1612", + "$index": 296 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1468", + "$index": 297 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1563", + "$index": 298 + }, + { + "comment": "", + "data_type": "Double", + "name": "k197", + "$index": 299 + }, + { + "comment": "", + "data_type": "Double", + "name": "k449", + "$index": 300 + }, + { + 
"comment": "", + "data_type": "Double", + "name": "k1694", + "$index": 301 + }, + { + "comment": "", + "data_type": "Double", + "name": "k674", + "$index": 302 + }, + { + "comment": "", + "data_type": "Double", + "name": "k696", + "$index": 303 + }, + { + "comment": "", + "data_type": "Double", + "name": "k591", + "$index": 304 + }, + { + "comment": "", + "data_type": "Double", + "name": "k578", + "$index": 305 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1702", + "$index": 306 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1663", + "$index": 307 + }, + { + "comment": "", + "data_type": "Double", + "name": "k195", + "$index": 308 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1171", + "$index": 309 + }, + { + "comment": "", + "data_type": "Double", + "name": "k702", + "$index": 310 + }, + { + "comment": "", + "data_type": "Double", + "name": "k618", + "$index": 311 + }, + { + "comment": "", + "data_type": "Double", + "name": "k704", + "$index": 312 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1700", + "$index": 313 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1414", + "$index": 314 + }, + { + "comment": "", + "data_type": "Double", + "name": "k125", + "$index": 315 + }, + { + "comment": "", + "data_type": "Double", + "name": "k325", + "$index": 316 + }, + { + "comment": "", + "data_type": "Double", + "name": "k555", + "$index": 317 + }, + { + "comment": "", + "data_type": "Double", + "name": "k348", + "$index": 318 + }, + { + "comment": "", + "data_type": "Double", + "name": "k613", + "$index": 319 + }, + { + "comment": "", + "data_type": "Double", + "name": "k772", + "$index": 320 + }, + { + "comment": "", + "data_type": "Double", + "name": "k607", + "$index": 321 + }, + { + "comment": "", + "data_type": "Double", + "name": "k771", + "$index": 322 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1199", + "$index": 323 + }, + { + "comment": "", + "data_type": 
"Double", + "name": "k1264", + "$index": 324 + }, + { + "comment": "", + "data_type": "Double", + "name": "k7", + "$index": 325 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1528", + "$index": 326 + }, + { + "comment": "", + "data_type": "Double", + "name": "k291", + "$index": 327 + }, + { + "comment": "", + "data_type": "Double", + "name": "k826", + "$index": 328 + }, + { + "comment": "", + "data_type": "Double", + "name": "k138", + "$index": 329 + }, + { + "comment": "", + "data_type": "Double", + "name": "k31", + "$index": 330 + }, + { + "comment": "", + "data_type": "Double", + "name": "k410", + "$index": 331 + }, + { + "comment": "", + "data_type": "Double", + "name": "k520", + "$index": 332 + }, + { + "comment": "", + "data_type": "Double", + "name": "k159", + "$index": 333 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1002", + "$index": 334 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1733", + "$index": 335 + }, + { + "comment": "", + "data_type": "Double", + "name": "349", + "$index": 336 + }, + { + "comment": "", + "data_type": "Double", + "name": "345", + "$index": 337 + }, + { + "comment": "", + "data_type": "Double", + "name": "346", + "$index": 338 + }, + { + "comment": "", + "data_type": "Double", + "name": "350", + "$index": 339 + }, + { + "comment": "", + "data_type": "Double", + "name": "348", + "$index": 340 + }, + { + "comment": "", + "data_type": "Double", + "name": "347", + "$index": 341 + }, + { + "comment": "", + "data_type": "Double", + "name": "327", + "$index": 342 + }, + { + "comment": "", + "data_type": "Double", + "name": "328", + "$index": 343 + }, + { + "comment": "", + "data_type": "Double", + "name": "331", + "$index": 344 + }, + { + "comment": "", + "data_type": "Double", + "name": "330", + "$index": 345 + }, + { + "comment": "", + "data_type": "Double", + "name": "332", + "$index": 346 + }, + { + "comment": "", + "data_type": "Double", + "name": "329", + "$index": 347 + }, + { + 
"comment": "", + "data_type": "Double", + "name": "356", + "$index": 348 + }, + { + "comment": "", + "data_type": "Double", + "name": "355", + "$index": 349 + }, + { + "comment": "", + "data_type": "Double", + "name": "353", + "$index": 350 + }, + { + "comment": "", + "data_type": "Double", + "name": "354", + "$index": 351 + }, + { + "comment": "", + "data_type": "Double", + "name": "352", + "$index": 352 + }, + { + "comment": "", + "data_type": "Double", + "name": "351", + "$index": 353 + }, + { + "comment": "", + "data_type": "Double", + "name": "k983", + "$index": 354 + }, + { + "comment": "", + "data_type": "Double", + "name": "k40", + "$index": 355 + }, + { + "comment": "", + "data_type": "Double", + "name": "339", + "$index": 356 + }, + { + "comment": "", + "data_type": "Double", + "name": "335", + "$index": 357 + }, + { + "comment": "", + "data_type": "Double", + "name": "334", + "$index": 358 + }, + { + "comment": "", + "data_type": "Double", + "name": "359", + "$index": 359 + }, + { + "comment": "", + "data_type": "Double", + "name": "333", + "$index": 360 + }, + { + "comment": "", + "data_type": "Double", + "name": "360", + "$index": 361 + }, + { + "comment": "", + "data_type": "Double", + "name": "342", + "$index": 362 + }, + { + "comment": "", + "data_type": "Double", + "name": "340", + "$index": 363 + }, + { + "comment": "", + "data_type": "Double", + "name": "338", + "$index": 364 + }, + { + "comment": "", + "data_type": "Double", + "name": "337", + "$index": 365 + }, + { + "comment": "", + "data_type": "Double", + "name": "341", + "$index": 366 + }, + { + "comment": "", + "data_type": "Double", + "name": "344", + "$index": 367 + }, + { + "comment": "", + "data_type": "Double", + "name": "343", + "$index": 368 + }, + { + "comment": "", + "data_type": "Double", + "name": "336", + "$index": 369 + }, + { + "comment": "", + "data_type": "Double", + "name": "k287", + "$index": 370 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1519", + 
"$index": 371 + }, + { + "comment": "", + "data_type": "Double", + "name": "k264", + "$index": 372 + }, + { + "comment": "", + "data_type": "Double", + "name": "9", + "$index": 373 + }, + { + "comment": "", + "data_type": "Double", + "name": "k820", + "$index": 374 + }, + { + "comment": "", + "data_type": "Double", + "name": "264", + "$index": 375 + }, + { + "comment": "", + "data_type": "Double", + "name": "265", + "$index": 376 + }, + { + "comment": "", + "data_type": "Double", + "name": "263", + "$index": 377 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1328", + "$index": 378 + }, + { + "comment": "", + "data_type": "Double", + "name": "k411", + "$index": 379 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1410", + "$index": 380 + }, + { + "comment": "", + "data_type": "Double", + "name": "k730", + "$index": 381 + }, + { + "comment": "", + "data_type": "Double", + "name": "k3", + "$index": 382 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1743", + "$index": 383 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1222", + "$index": 384 + }, + { + "comment": "", + "data_type": "Double", + "name": "k791", + "$index": 385 + }, + { + "comment": "", + "data_type": "Double", + "name": "k965", + "$index": 386 + }, + { + "comment": "", + "data_type": "Double", + "name": "k597", + "$index": 387 + }, + { + "comment": "", + "data_type": "Double", + "name": "k916", + "$index": 388 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1709", + "$index": 389 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1236", + "$index": 390 + }, + { + "comment": "", + "data_type": "Double", + "name": "k43", + "$index": 391 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1169", + "$index": 392 + }, + { + "comment": "", + "data_type": "Double", + "name": "k381", + "$index": 393 + }, + { + "comment": "", + "data_type": "Double", + "name": "89", + "$index": 394 + }, + { + "comment": "", + 
"data_type": "Double", + "name": "90", + "$index": 395 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1305", + "$index": 396 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1055", + "$index": 397 + }, + { + "comment": "", + "data_type": "Double", + "name": "k378", + "$index": 398 + }, + { + "comment": "", + "data_type": "Double", + "name": "267", + "$index": 399 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1100", + "$index": 400 + }, + { + "comment": "", + "data_type": "Double", + "name": "k82", + "$index": 401 + }, + { + "comment": "", + "data_type": "Double", + "name": "k252", + "$index": 402 + }, + { + "comment": "", + "data_type": "Double", + "name": "k36", + "$index": 403 + }, + { + "comment": "", + "data_type": "Double", + "name": "4", + "$index": 404 + }, + { + "comment": "", + "data_type": "Double", + "name": "1", + "$index": 405 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1687", + "$index": 406 + }, + { + "comment": "", + "data_type": "Double", + "name": "k672", + "$index": 407 + }, + { + "comment": "", + "data_type": "Double", + "name": "k655", + "$index": 408 + }, + { + "comment": "", + "data_type": "Double", + "name": "k271", + "$index": 409 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1195", + "$index": 410 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1579", + "$index": 411 + }, + { + "comment": "", + "data_type": "Double", + "name": "X1", + "$index": 412 + }, + { + "comment": "", + "data_type": "Double", + "name": "X112", + "$index": 413 + }, + { + "comment": "", + "data_type": "Double", + "name": "X113", + "$index": 414 + }, + { + "comment": "", + "data_type": "Double", + "name": "X103", + "$index": 415 + }, + { + "comment": "", + "data_type": "Double", + "name": "X111", + "$index": 416 + }, + { + "comment": "", + "data_type": "Double", + "name": "X10", + "$index": 417 + }, + { + "comment": "", + "data_type": "Double", + "name": "X11", + "$index": 
418 + }, + { + "comment": "", + "data_type": "Double", + "name": "X109", + "$index": 419 + }, + { + "comment": "", + "data_type": "Double", + "name": "X108", + "$index": 420 + }, + { + "comment": "", + "data_type": "Double", + "name": "X100", + "$index": 421 + }, + { + "comment": "", + "data_type": "Double", + "name": "X101", + "$index": 422 + }, + { + "comment": "", + "data_type": "Double", + "name": "X107", + "$index": 423 + }, + { + "comment": "", + "data_type": "Double", + "name": "X106", + "$index": 424 + }, + { + "comment": "", + "data_type": "Double", + "name": "X102", + "$index": 425 + }, + { + "comment": "", + "data_type": "Double", + "name": "X105", + "$index": 426 + }, + { + "comment": "", + "data_type": "Double", + "name": "X104", + "$index": 427 + }, + { + "comment": "", + "data_type": "Double", + "name": "X110", + "$index": 428 + }, + { + "comment": "", + "data_type": "Double", + "name": "X3", + "$index": 429 + }, + { + "comment": "", + "data_type": "Double", + "name": "X114", + "$index": 430 + }, + { + "comment": "", + "data_type": "Double", + "name": "X67", + "$index": 431 + }, + { + "comment": "", + "data_type": "Double", + "name": "X72", + "$index": 432 + }, + { + "comment": "", + "data_type": "Double", + "name": "X71", + "$index": 433 + }, + { + "comment": "", + "data_type": "Double", + "name": "X70", + "$index": 434 + }, + { + "comment": "", + "data_type": "Double", + "name": "X7", + "$index": 435 + }, + { + "comment": "", + "data_type": "Double", + "name": "X69", + "$index": 436 + }, + { + "comment": "", + "data_type": "Double", + "name": "X68", + "$index": 437 + }, + { + "comment": "", + "data_type": "Double", + "name": "X66", + "$index": 438 + }, + { + "comment": "", + "data_type": "Double", + "name": "X90", + "$index": 439 + }, + { + "comment": "", + "data_type": "Double", + "name": "X65", + "$index": 440 + }, + { + "comment": "", + "data_type": "Double", + "name": "X64", + "$index": 441 + }, + { + "comment": "", + "data_type": "Double", + 
"name": "X63", + "$index": 442 + }, + { + "comment": "", + "data_type": "Double", + "name": "X62", + "$index": 443 + }, + { + "comment": "", + "data_type": "Double", + "name": "X61", + "$index": 444 + }, + { + "comment": "", + "data_type": "Double", + "name": "X60", + "$index": 445 + }, + { + "comment": "", + "data_type": "Double", + "name": "X73", + "$index": 446 + }, + { + "comment": "", + "data_type": "Double", + "name": "X74", + "$index": 447 + }, + { + "comment": "", + "data_type": "Double", + "name": "X75", + "$index": 448 + }, + { + "comment": "", + "data_type": "Double", + "name": "X76", + "$index": 449 + }, + { + "comment": "", + "data_type": "Double", + "name": "X77", + "$index": 450 + }, + { + "comment": "", + "data_type": "Double", + "name": "X78", + "$index": 451 + }, + { + "comment": "", + "data_type": "Double", + "name": "X79", + "$index": 452 + }, + { + "comment": "", + "data_type": "Double", + "name": "X8", + "$index": 453 + }, + { + "comment": "", + "data_type": "Double", + "name": "X80", + "$index": 454 + }, + { + "comment": "", + "data_type": "Double", + "name": "X81", + "$index": 455 + }, + { + "comment": "", + "data_type": "Double", + "name": "X82", + "$index": 456 + }, + { + "comment": "", + "data_type": "Double", + "name": "X83", + "$index": 457 + }, + { + "comment": "", + "data_type": "Double", + "name": "X84", + "$index": 458 + }, + { + "comment": "", + "data_type": "Double", + "name": "X85", + "$index": 459 + }, + { + "comment": "", + "data_type": "Double", + "name": "X86", + "$index": 460 + }, + { + "comment": "", + "data_type": "Double", + "name": "X88", + "$index": 461 + }, + { + "comment": "", + "data_type": "Double", + "name": "X89", + "$index": 462 + }, + { + "comment": "", + "data_type": "Double", + "name": "X6", + "$index": 463 + }, + { + "comment": "", + "data_type": "Double", + "name": "X59", + "$index": 464 + }, + { + "comment": "", + "data_type": "Double", + "name": "X58", + "$index": 465 + }, + { + "comment": "", + 
"data_type": "Double", + "name": "X42", + "$index": 466 + }, + { + "comment": "", + "data_type": "Double", + "name": "X29", + "$index": 467 + }, + { + "comment": "", + "data_type": "Double", + "name": "X31", + "$index": 468 + }, + { + "comment": "", + "data_type": "Double", + "name": "X32", + "$index": 469 + }, + { + "comment": "", + "data_type": "Double", + "name": "X33", + "$index": 470 + }, + { + "comment": "", + "data_type": "Double", + "name": "X34", + "$index": 471 + }, + { + "comment": "", + "data_type": "Double", + "name": "X28", + "$index": 472 + }, + { + "comment": "", + "data_type": "Double", + "name": "X35", + "$index": 473 + }, + { + "comment": "", + "data_type": "Double", + "name": "X36", + "$index": 474 + }, + { + "comment": "", + "data_type": "Double", + "name": "X37", + "$index": 475 + }, + { + "comment": "", + "data_type": "Double", + "name": "X38", + "$index": 476 + }, + { + "comment": "", + "data_type": "Double", + "name": "X39", + "$index": 477 + }, + { + "comment": "", + "data_type": "Double", + "name": "X4", + "$index": 478 + }, + { + "comment": "", + "data_type": "Double", + "name": "X40", + "$index": 479 + }, + { + "comment": "", + "data_type": "Double", + "name": "X41", + "$index": 480 + }, + { + "comment": "", + "data_type": "Double", + "name": "X43", + "$index": 481 + }, + { + "comment": "", + "data_type": "Double", + "name": "X57", + "$index": 482 + }, + { + "comment": "", + "data_type": "Double", + "name": "X44", + "$index": 483 + }, + { + "comment": "", + "data_type": "Double", + "name": "X45", + "$index": 484 + }, + { + "comment": "", + "data_type": "Double", + "name": "X46", + "$index": 485 + }, + { + "comment": "", + "data_type": "Double", + "name": "X47", + "$index": 486 + }, + { + "comment": "", + "data_type": "Double", + "name": "X48", + "$index": 487 + }, + { + "comment": "", + "data_type": "Double", + "name": "X49", + "$index": 488 + }, + { + "comment": "", + "data_type": "Double", + "name": "X5", + "$index": 489 + }, + { + 
"comment": "", + "data_type": "Double", + "name": "X50", + "$index": 490 + }, + { + "comment": "", + "data_type": "Double", + "name": "X51", + "$index": 491 + }, + { + "comment": "", + "data_type": "Double", + "name": "X52", + "$index": 492 + }, + { + "comment": "", + "data_type": "Double", + "name": "X53", + "$index": 493 + }, + { + "comment": "", + "data_type": "Double", + "name": "X54", + "$index": 494 + }, + { + "comment": "", + "data_type": "Double", + "name": "X55", + "$index": 495 + }, + { + "comment": "", + "data_type": "Double", + "name": "X56", + "$index": 496 + }, + { + "comment": "", + "data_type": "Double", + "name": "X9", + "$index": 497 + }, + { + "comment": "", + "data_type": "Double", + "name": "X87", + "$index": 498 + }, + { + "comment": "", + "data_type": "Double", + "name": "X91", + "$index": 499 + }, + { + "comment": "", + "data_type": "Double", + "name": "X125", + "$index": 500 + }, + { + "comment": "", + "data_type": "Double", + "name": "X2", + "$index": 501 + }, + { + "comment": "", + "data_type": "Double", + "name": "X19", + "$index": 502 + }, + { + "comment": "", + "data_type": "Double", + "name": "X18", + "$index": 503 + }, + { + "comment": "", + "data_type": "Double", + "name": "X92", + "$index": 504 + }, + { + "comment": "", + "data_type": "Double", + "name": "X16", + "$index": 505 + }, + { + "comment": "", + "data_type": "Double", + "name": "X15", + "$index": 506 + }, + { + "comment": "", + "data_type": "Double", + "name": "X14", + "$index": 507 + }, + { + "comment": "", + "data_type": "Double", + "name": "X13", + "$index": 508 + }, + { + "comment": "", + "data_type": "Double", + "name": "X128", + "$index": 509 + }, + { + "comment": "", + "data_type": "Double", + "name": "X127", + "$index": 510 + }, + { + "comment": "", + "data_type": "Double", + "name": "X126", + "$index": 511 + }, + { + "comment": "", + "data_type": "Double", + "name": "X124", + "$index": 512 + }, + { + "comment": "", + "data_type": "Double", + "name": "X21", + 
"$index": 513 + }, + { + "comment": "", + "data_type": "Double", + "name": "X123", + "$index": 514 + }, + { + "comment": "", + "data_type": "Double", + "name": "X122", + "$index": 515 + }, + { + "comment": "", + "data_type": "Double", + "name": "X121", + "$index": 516 + }, + { + "comment": "", + "data_type": "Double", + "name": "X120", + "$index": 517 + }, + { + "comment": "", + "data_type": "Double", + "name": "X12", + "$index": 518 + }, + { + "comment": "", + "data_type": "Double", + "name": "X119", + "$index": 519 + }, + { + "comment": "", + "data_type": "Double", + "name": "X30", + "$index": 520 + }, + { + "comment": "", + "data_type": "Double", + "name": "X118", + "$index": 521 + }, + { + "comment": "", + "data_type": "Double", + "name": "X117", + "$index": 522 + }, + { + "comment": "", + "data_type": "Double", + "name": "X116", + "$index": 523 + }, + { + "comment": "", + "data_type": "Double", + "name": "X115", + "$index": 524 + }, + { + "comment": "", + "data_type": "Double", + "name": "X20", + "$index": 525 + }, + { + "comment": "", + "data_type": "Double", + "name": "X17", + "$index": 526 + }, + { + "comment": "", + "data_type": "Double", + "name": "X26", + "$index": 527 + }, + { + "comment": "", + "data_type": "Double", + "name": "X93", + "$index": 528 + }, + { + "comment": "", + "data_type": "Double", + "name": "X94", + "$index": 529 + }, + { + "comment": "", + "data_type": "Double", + "name": "X22", + "$index": 530 + }, + { + "comment": "", + "data_type": "Double", + "name": "X95", + "$index": 531 + }, + { + "comment": "", + "data_type": "Double", + "name": "X96", + "$index": 532 + }, + { + "comment": "", + "data_type": "Double", + "name": "X23", + "$index": 533 + }, + { + "comment": "", + "data_type": "Double", + "name": "X24", + "$index": 534 + }, + { + "comment": "", + "data_type": "Double", + "name": "X97", + "$index": 535 + }, + { + "comment": "", + "data_type": "Double", + "name": "X98", + "$index": 536 + }, + { + "comment": "", + "data_type": 
"Double", + "name": "X99", + "$index": 537 + }, + { + "comment": "", + "data_type": "Double", + "name": "X27", + "$index": 538 + }, + { + "comment": "", + "data_type": "Double", + "name": "X25", + "$index": 539 + }, + { + "comment": "", + "data_type": "Double", + "name": "k541", + "$index": 540 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1092", + "$index": 541 + }, + { + "comment": "", + "data_type": "Double", + "name": "k311", + "$index": 542 + }, + { + "comment": "", + "data_type": "Double", + "name": "k57", + "$index": 543 + }, + { + "comment": "", + "data_type": "Double", + "name": "266", + "$index": 544 + }, + { + "comment": "", + "data_type": "Double", + "name": "mf_3585", + "$index": 545 + }, + { + "comment": "", + "data_type": "Double", + "name": "k387", + "$index": 546 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1243", + "$index": 547 + }, + { + "comment": "", + "data_type": "Double", + "name": "k688", + "$index": 548 + }, + { + "comment": "", + "data_type": "Double", + "name": "k757", + "$index": 549 + }, + { + "comment": "", + "data_type": "Double", + "name": "k898", + "$index": 550 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1587", + "$index": 551 + }, + { + "comment": "", + "data_type": "Double", + "name": "k882", + "$index": 552 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1191", + "$index": 553 + }, + { + "comment": "", + "data_type": "Double", + "name": "k881", + "$index": 554 + }, + { + "comment": "", + "data_type": "Double", + "name": "k299", + "$index": 555 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1166", + "$index": 556 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1219", + "$index": 557 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1726", + "$index": 558 + }, + { + "comment": "", + "data_type": "Double", + "name": "k813", + "$index": 559 + }, + { + "comment": "", + "data_type": "Double", + "name": "k90", + "$index": 
560 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1484", + "$index": 561 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1350", + "$index": 562 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1038", + "$index": 563 + }, + { + "comment": "", + "data_type": "Double", + "name": "k521", + "$index": 564 + }, + { + "comment": "", + "data_type": "Double", + "name": "k779", + "$index": 565 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1380", + "$index": 566 + }, + { + "comment": "", + "data_type": "Double", + "name": "k84", + "$index": 567 + }, + { + "comment": "", + "data_type": "Double", + "name": "k809", + "$index": 568 + }, + { + "comment": "", + "data_type": "Double", + "name": "k328", + "$index": 569 + }, + { + "comment": "", + "data_type": "Double", + "name": "k717", + "$index": 570 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1669", + "$index": 571 + }, + { + "comment": "", + "data_type": "Double", + "name": "k756", + "$index": 572 + }, + { + "comment": "", + "data_type": "Double", + "name": "k289", + "$index": 573 + }, + { + "comment": "", + "data_type": "Double", + "name": "k71", + "$index": 574 + }, + { + "comment": "", + "data_type": "Double", + "name": "k446", + "$index": 575 + }, + { + "comment": "", + "data_type": "Double", + "name": "k870", + "$index": 576 + }, + { + "comment": "", + "data_type": "Double", + "name": "k873", + "$index": 577 + }, + { + "comment": "", + "data_type": "Double", + "name": "k840", + "$index": 578 + }, + { + "comment": "", + "data_type": "Double", + "name": "k472", + "$index": 579 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1109", + "$index": 580 + }, + { + "comment": "", + "data_type": "Double", + "name": "k663", + "$index": 581 + }, + { + "comment": "", + "data_type": "Double", + "name": "3", + "$index": 582 + }, + { + "comment": "", + "data_type": "Double", + "name": "2", + "$index": 583 + }, + { + "comment": "", + "data_type": 
"Double", + "name": "2074", + "$index": 584 + }, + { + "comment": "", + "data_type": "Double", + "name": "2075", + "$index": 585 + }, + { + "comment": "", + "data_type": "Double", + "name": "mf_2332", + "$index": 586 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1032", + "$index": 587 + }, + { + "comment": "", + "data_type": "Double", + "name": "k707", + "$index": 588 + }, + { + "comment": "", + "data_type": "Double", + "name": "k857", + "$index": 589 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1101", + "$index": 590 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1239", + "$index": 591 + }, + { + "comment": "", + "data_type": "Double", + "name": "a9", + "$index": 592 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1443", + "$index": 593 + }, + { + "comment": "", + "data_type": "Double", + "name": "t35", + "$index": 594 + }, + { + "comment": "", + "data_type": "Double", + "name": "t3", + "$index": 595 + }, + { + "comment": "", + "data_type": "Double", + "name": "t30", + "$index": 596 + }, + { + "comment": "", + "data_type": "Double", + "name": "t31", + "$index": 597 + }, + { + "comment": "", + "data_type": "Double", + "name": "t32", + "$index": 598 + }, + { + "comment": "", + "data_type": "Double", + "name": "t33", + "$index": 599 + }, + { + "comment": "", + "data_type": "Double", + "name": "t34", + "$index": 600 + }, + { + "comment": "", + "data_type": "Double", + "name": "t36", + "$index": 601 + }, + { + "comment": "", + "data_type": "Double", + "name": "t7", + "$index": 602 + }, + { + "comment": "", + "data_type": "Double", + "name": "t37", + "$index": 603 + }, + { + "comment": "", + "data_type": "Double", + "name": "t29", + "$index": 604 + }, + { + "comment": "", + "data_type": "Double", + "name": "t39", + "$index": 605 + }, + { + "comment": "", + "data_type": "Double", + "name": "t4", + "$index": 606 + }, + { + "comment": "", + "data_type": "Double", + "name": "t40", + "$index": 607 + }, + { + 
"comment": "", + "data_type": "Double", + "name": "t5", + "$index": 608 + }, + { + "comment": "", + "data_type": "Double", + "name": "t6", + "$index": 609 + }, + { + "comment": "", + "data_type": "Double", + "name": "t38", + "$index": 610 + }, + { + "comment": "", + "data_type": "Double", + "name": "t27", + "$index": 611 + }, + { + "comment": "", + "data_type": "Double", + "name": "t28", + "$index": 612 + }, + { + "comment": "", + "data_type": "Double", + "name": "t26", + "$index": 613 + }, + { + "comment": "", + "data_type": "Double", + "name": "t1", + "$index": 614 + }, + { + "comment": "", + "data_type": "Double", + "name": "t10", + "$index": 615 + }, + { + "comment": "", + "data_type": "Double", + "name": "t11", + "$index": 616 + }, + { + "comment": "", + "data_type": "Double", + "name": "t12", + "$index": 617 + }, + { + "comment": "", + "data_type": "Double", + "name": "t13", + "$index": 618 + }, + { + "comment": "", + "data_type": "Double", + "name": "t14", + "$index": 619 + }, + { + "comment": "", + "data_type": "Double", + "name": "t15", + "$index": 620 + }, + { + "comment": "", + "data_type": "Double", + "name": "t16", + "$index": 621 + }, + { + "comment": "", + "data_type": "Double", + "name": "t17", + "$index": 622 + }, + { + "comment": "", + "data_type": "Double", + "name": "t18", + "$index": 623 + }, + { + "comment": "", + "data_type": "Double", + "name": "t19", + "$index": 624 + }, + { + "comment": "", + "data_type": "Double", + "name": "t2", + "$index": 625 + }, + { + "comment": "", + "data_type": "Double", + "name": "t20", + "$index": 626 + }, + { + "comment": "", + "data_type": "Double", + "name": "t22", + "$index": 627 + }, + { + "comment": "", + "data_type": "Double", + "name": "t23", + "$index": 628 + }, + { + "comment": "", + "data_type": "Double", + "name": "t24", + "$index": 629 + }, + { + "comment": "", + "data_type": "Double", + "name": "t25", + "$index": 630 + }, + { + "comment": "", + "data_type": "Double", + "name": "t21", + "$index": 
631 + }, + { + "comment": "", + "data_type": "Double", + "name": "t8", + "$index": 632 + }, + { + "comment": "", + "data_type": "Double", + "name": "a6", + "$index": 633 + }, + { + "comment": "", + "data_type": "Double", + "name": "a27", + "$index": 634 + }, + { + "comment": "", + "data_type": "Double", + "name": "a29", + "$index": 635 + }, + { + "comment": "", + "data_type": "Double", + "name": "a3", + "$index": 636 + }, + { + "comment": "", + "data_type": "Double", + "name": "a30", + "$index": 637 + }, + { + "comment": "", + "data_type": "Double", + "name": "a31", + "$index": 638 + }, + { + "comment": "", + "data_type": "Double", + "name": "a32", + "$index": 639 + }, + { + "comment": "", + "data_type": "Double", + "name": "a33", + "$index": 640 + }, + { + "comment": "", + "data_type": "Double", + "name": "a35", + "$index": 641 + }, + { + "comment": "", + "data_type": "Double", + "name": "a36", + "$index": 642 + }, + { + "comment": "", + "data_type": "Double", + "name": "a37", + "$index": 643 + }, + { + "comment": "", + "data_type": "Double", + "name": "a38", + "$index": 644 + }, + { + "comment": "", + "data_type": "Double", + "name": "a39", + "$index": 645 + }, + { + "comment": "", + "data_type": "Double", + "name": "a4", + "$index": 646 + }, + { + "comment": "", + "data_type": "Double", + "name": "a40", + "$index": 647 + }, + { + "comment": "", + "data_type": "Double", + "name": "a5", + "$index": 648 + }, + { + "comment": "", + "data_type": "Double", + "name": "a7", + "$index": 649 + }, + { + "comment": "", + "data_type": "Double", + "name": "a8", + "$index": 650 + }, + { + "comment": "", + "data_type": "Double", + "name": "t9", + "$index": 651 + }, + { + "comment": "", + "data_type": "Double", + "name": "a28", + "$index": 652 + }, + { + "comment": "", + "data_type": "Double", + "name": "a34", + "$index": 653 + }, + { + "comment": "", + "data_type": "Double", + "name": "a26", + "$index": 654 + }, + { + "comment": "", + "data_type": "Double", + "name": "a17", + 
"$index": 655 + }, + { + "comment": "", + "data_type": "Double", + "name": "a1", + "$index": 656 + }, + { + "comment": "", + "data_type": "Double", + "name": "a10", + "$index": 657 + }, + { + "comment": "", + "data_type": "Double", + "name": "a25", + "$index": 658 + }, + { + "comment": "", + "data_type": "Double", + "name": "a12", + "$index": 659 + }, + { + "comment": "", + "data_type": "Double", + "name": "a13", + "$index": 660 + }, + { + "comment": "", + "data_type": "Double", + "name": "a14", + "$index": 661 + }, + { + "comment": "", + "data_type": "Double", + "name": "a15", + "$index": 662 + }, + { + "comment": "", + "data_type": "Double", + "name": "a16", + "$index": 663 + }, + { + "comment": "", + "data_type": "Double", + "name": "a11", + "$index": 664 + }, + { + "comment": "", + "data_type": "Double", + "name": "a18", + "$index": 665 + }, + { + "comment": "", + "data_type": "Double", + "name": "a24", + "$index": 666 + }, + { + "comment": "", + "data_type": "Double", + "name": "a2", + "$index": 667 + }, + { + "comment": "", + "data_type": "Double", + "name": "a20", + "$index": 668 + }, + { + "comment": "", + "data_type": "Double", + "name": "a21", + "$index": 669 + }, + { + "comment": "", + "data_type": "Double", + "name": "a22", + "$index": 670 + }, + { + "comment": "", + "data_type": "Double", + "name": "a19", + "$index": 671 + }, + { + "comment": "", + "data_type": "Double", + "name": "a23", + "$index": 672 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1524", + "$index": 673 + }, + { + "comment": "", + "data_type": "Double", + "name": "k638", + "$index": 674 + }, + { + "comment": "", + "data_type": "Double", + "name": "k824", + "$index": 675 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1200", + "$index": 676 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1223", + "$index": 677 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1273", + "$index": 678 + }, + { + "comment": "", + "data_type": 
"Double", + "name": "k1251", + "$index": 679 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1154", + "$index": 680 + }, + { + "comment": "", + "data_type": "Double", + "name": "k115", + "$index": 681 + }, + { + "comment": "", + "data_type": "Double", + "name": "k852", + "$index": 682 + }, + { + "comment": "", + "data_type": "Double", + "name": "k818", + "$index": 683 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1299", + "$index": 684 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1322", + "$index": 685 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1330", + "$index": 686 + }, + { + "comment": "", + "data_type": "Double", + "name": "k47", + "$index": 687 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1357", + "$index": 688 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1377", + "$index": 689 + }, + { + "comment": "", + "data_type": "Double", + "name": "k114", + "$index": 690 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1146", + "$index": 691 + }, + { + "comment": "", + "data_type": "Double", + "name": "k442", + "$index": 692 + }, + { + "comment": "", + "data_type": "Double", + "name": "k909", + "$index": 693 + }, + { + "comment": "", + "data_type": "Double", + "name": "k911", + "$index": 694 + }, + { + "comment": "", + "data_type": "Double", + "name": "k80", + "$index": 695 + }, + { + "comment": "", + "data_type": "Double", + "name": "mf_1136", + "$index": 696 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1000", + "$index": 697 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1010", + "$index": 698 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1019", + "$index": 699 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1028", + "$index": 700 + }, + { + "comment": "", + "data_type": "Double", + "name": "k978", + "$index": 701 + }, + { + "comment": "", + "data_type": "Double", + "name": "mf_1487", 
+ "$index": 702 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1052", + "$index": 703 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1090", + "$index": 704 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1121", + "$index": 705 + }, + { + "comment": "", + "data_type": "Double", + "name": "k940", + "$index": 706 + }, + { + "comment": "", + "data_type": "Double", + "name": "k939", + "$index": 707 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1123", + "$index": 708 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1126", + "$index": 709 + }, + { + "comment": "", + "data_type": "Double", + "name": "k801", + "$index": 710 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1375", + "$index": 711 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1393", + "$index": 712 + }, + { + "comment": "", + "data_type": "Double", + "name": "k700", + "$index": 713 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1590", + "$index": 714 + }, + { + "comment": "", + "data_type": "Double", + "name": "k587", + "$index": 715 + }, + { + "comment": "", + "data_type": "Double", + "name": "k255", + "$index": 716 + }, + { + "comment": "", + "data_type": "Double", + "name": "k714", + "$index": 717 + }, + { + "comment": "", + "data_type": "Double", + "name": "k598", + "$index": 718 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1617", + "$index": 719 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1650", + "$index": 720 + }, + { + "comment": "", + "data_type": "Double", + "name": "k605", + "$index": 721 + }, + { + "comment": "", + "data_type": "Double", + "name": "k230", + "$index": 722 + }, + { + "comment": "", + "data_type": "Double", + "name": "k219", + "$index": 723 + }, + { + "comment": "", + "data_type": "Double", + "name": "k156", + "$index": 724 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1677", + "$index": 725 + }, + { + 
"comment": "", + "data_type": "Double", + "name": "k691", + "$index": 726 + }, + { + "comment": "", + "data_type": "Double", + "name": "k632", + "$index": 727 + }, + { + "comment": "", + "data_type": "Double", + "name": "k635", + "$index": 728 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1736", + "$index": 729 + }, + { + "comment": "", + "data_type": "Double", + "name": "k669", + "$index": 730 + }, + { + "comment": "", + "data_type": "Double", + "name": "k206", + "$index": 731 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1752", + "$index": 732 + }, + { + "comment": "", + "data_type": "Double", + "name": "k201", + "$index": 733 + }, + { + "comment": "", + "data_type": "Double", + "name": "k177", + "$index": 734 + }, + { + "comment": "", + "data_type": "Double", + "name": "k302", + "$index": 735 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1546", + "$index": 736 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1521", + "$index": 737 + }, + { + "comment": "", + "data_type": "Double", + "name": "k435", + "$index": 738 + }, + { + "comment": "", + "data_type": "Double", + "name": "k397", + "$index": 739 + }, + { + "comment": "", + "data_type": "Double", + "name": "k143", + "$index": 740 + }, + { + "comment": "", + "data_type": "Double", + "name": "k515", + "$index": 741 + }, + { + "comment": "", + "data_type": "Double", + "name": "k389", + "$index": 742 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1453", + "$index": 743 + }, + { + "comment": "", + "data_type": "Double", + "name": "k383", + "$index": 744 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1463", + "$index": 745 + }, + { + "comment": "", + "data_type": "Double", + "name": "k526", + "$index": 746 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1472", + "$index": 747 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1755", + "$index": 748 + }, + { + "comment": "", + "data_type": 
"Double", + "name": "k53", + "$index": 749 + }, + { + "comment": "", + "data_type": "Double", + "name": "k768", + "$index": 750 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1542", + "$index": 751 + }, + { + "comment": "", + "data_type": "Double", + "name": "k369", + "$index": 752 + }, + { + "comment": "", + "data_type": "Double", + "name": "k356", + "$index": 753 + }, + { + "comment": "", + "data_type": "Double", + "name": "k563", + "$index": 754 + }, + { + "comment": "", + "data_type": "Double", + "name": "k324", + "$index": 755 + }, + { + "comment": "", + "data_type": "Double", + "name": "k374", + "$index": 756 + }, + { + "comment": "", + "data_type": "Double", + "name": "k237", + "$index": 757 + }, + { + "comment": "", + "data_type": "Double", + "name": "k52", + "$index": 758 + }, + { + "comment": "", + "data_type": "Double", + "name": "k33", + "$index": 759 + }, + { + "comment": "", + "data_type": "Double", + "name": "k582", + "$index": 760 + }, + { + "comment": "", + "data_type": "Double", + "name": "k267", + "$index": 761 + }, + { + "comment": "", + "data_type": "Double", + "name": "k234", + "$index": 762 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1693", + "$index": 763 + }, + { + "comment": "", + "data_type": "Double", + "name": "k814", + "$index": 764 + }, + { + "comment": "", + "data_type": "Double", + "name": "k144", + "$index": 765 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1268", + "$index": 766 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1177", + "$index": 767 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1164", + "$index": 768 + }, + { + "comment": "", + "data_type": "Double", + "name": "k667", + "$index": 769 + }, + { + "comment": "", + "data_type": "Double", + "name": "k727", + "$index": 770 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1589", + "$index": 771 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1607", + "$index": 
772 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1537", + "$index": 773 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1611", + "$index": 774 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1661", + "$index": 775 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1426", + "$index": 776 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1401", + "$index": 777 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1367", + "$index": 778 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1344", + "$index": 779 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1221", + "$index": 780 + }, + { + "comment": "", + "data_type": "Double", + "name": "k778", + "$index": 781 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1302", + "$index": 782 + }, + { + "comment": "", + "data_type": "Double", + "name": "k816", + "$index": 783 + }, + { + "comment": "", + "data_type": "Double", + "name": "k514", + "$index": 784 + }, + { + "comment": "", + "data_type": "Double", + "name": "k386", + "$index": 785 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1754", + "$index": 786 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1213", + "$index": 787 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1545", + "$index": 788 + }, + { + "comment": "", + "data_type": "Double", + "name": "k914", + "$index": 789 + }, + { + "comment": "", + "data_type": "Double", + "name": "k588", + "$index": 790 + }, + { + "comment": "", + "data_type": "Double", + "name": "k266", + "$index": 791 + }, + { + "comment": "", + "data_type": "Double", + "name": "k915", + "$index": 792 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1610", + "$index": 793 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1077", + "$index": 794 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1054", + "$index": 795 + }, + { + "comment": "", 
+ "data_type": "Double", + "name": "k1676", + "$index": 796 + }, + { + "comment": "", + "data_type": "Double", + "name": "k1734", + "$index": 797 + }, + { + "comment": "", + "data_type": "Double", + "name": "k310", + "$index": 798 + }, + { + "comment": "", + "data_type": "Double", + "name": "y", + "$index": 799 + } + ], + "name": "zane ttt", + "publicLevel": "Public", + "public_member_list": "", + "sql": "", + "tags": [ + "学校" + ] +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-delete.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-delete.http similarity index 81% rename from board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-delete.http rename to board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-delete.http index 743ddcffd..207206c46 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-delete.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-delete.http @@ -1,6 +1,6 @@ ### 删除数据集 -POST {{baseUrl}}/data_set/delete +POST http://localhost:8080/board-service/data_set/delete Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-detail.http new file mode 100644 index 000000000..cc6fb9f46 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-detail.http @@ -0,0 +1,18 @@ + +### 查询单个数据集 +POST http://localhost:8080/board-service/data_set/detail +Content-Type: application/json + +{ + "id": "75c2b7d8e53f4400859a8fa72a099fdf" 
+} + + + +### 查询已删除的数据集 +POST http://localhost:8080/board-service/data_set/detail +Content-Type: application/json + +{ + "id": "a80a0056505c498d9a9be56aa06f7324" +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-preview.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-preview.http new file mode 100644 index 000000000..8a9048851 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-preview.http @@ -0,0 +1,10 @@ + +### 预览数据集文件 +POST http://localhost:8080/board-service/data_set/preview +Content-Type: application/json +token: {{token}} + +{ + "filename": "/Users/zane.luo/data/f03ec02a-ef5e-4fb5-8873-08fbc8ebabe7-horz_lr_provider.csv", + "data_set_add_method": "LocalFile" +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-query.http new file mode 100644 index 000000000..8d2812a5e --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-query.http @@ -0,0 +1,33 @@ + +### 查询全部数据集 +POST http://localhost:8080/board-service/table_data_set/query +Content-Type: application/json +token: {{token}} + +{} + + + +### 按名字查 +POST http://localhost:8080/board-service/data_set/query +Content-Type: application/json + +{ + "name": "xlsx" +} + +### 按 tag 查 +POST http://localhost:8080/board-service/data_set/query +Content-Type: application/json + +{ + "tag": "xlsx" +} + +### 按是否有 Y 值查 +POST http://localhost:8080/board-service/data_set/query +Content-Type: application/json + +{ + "contains_y": false +} \ No newline at end of file diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-tags.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-tags.http new file mode 100644 index 000000000..6a64ad787 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-tags.http @@ -0,0 +1,15 @@ + +### 查询所有标签 +POST http://localhost:8080/board-service/data_set/tags +Content-Type: application/json + +{} + + +### 根据标签名称模糊查询 +POST http://localhost:8080/board-service/data_set/tags +Content-Type: application/json + +{ + "tag": "s" +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-update.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-update.http similarity index 86% rename from board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-update.http rename to board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-update.http index 66b033122..eb86c105e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-update.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/table_data_set/test/dataset-update.http @@ -1,6 +1,6 @@ ### 修改数据集 -POST {{baseUrl}}/data_set/update +POST http://localhost:8080/board-service/table_data_set/update Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/test/query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/test/query.http new file mode 100644 index 000000000..3310428ad --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/test/query.http @@ -0,0 +1,47 @@ + 
+### 查询全部数据集 +POST http://localhost:8080/board-service/data_resource/query +Content-Type: application/json +token: {{token}} + +{ +} + +### +POST http://localhost:8080/board-service/data_resource/query +Content-Type: application/json +token: {{token}} + +{ + "data_resource_type": [ + "TableDataSet", + "ImageDataSet" + ] +} + +### +POST http://localhost:8080/board-service/data_resource/query +Content-Type: application/json +token: {{token}} + +{ + "data_resource_type": "ImageDataSet" +} + +### +POST http://localhost:8080/board-service/data_resource/query +Content-Type: application/json +token: {{token}} + +{ + "id": "", + "name": "", + "creator": "", + "tag": "", + "dataResourceType": [], + "page_index": 0, + "page_size": 20, + "containsY": "", + "forJobType": "" +} + diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/test/tags.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/test/tags.http new file mode 100644 index 000000000..98b5c85b2 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/test/tags.http @@ -0,0 +1,37 @@ + +### +POST http://localhost:8080/board-service/data_resource/tags +Content-Type: application/json +token: {{token}} + +{ +} + + +### +POST http://localhost:8080/board-service/data_resource/tags +Content-Type: application/json +token: {{token}} + +{ + "data_resource_type": "ImageDataSet" +} + +### +POST http://localhost:8080/board-service/data_resource/tags +Content-Type: application/json +token: {{token}} + +{ + "data_resource_type": "TableDataSet" +} + + +### +POST http://localhost:8080/board-service/data_resource/tags +Content-Type: application/json +token: {{token}} + +{ + "data_resource_type": "BloomFilter" +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/DataResourceUploadTaskDetailApi.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/DataResourceUploadTaskDetailApi.java new file mode 100644 index 000000000..fceec0c8a --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/DataResourceUploadTaskDetailApi.java @@ -0,0 +1,48 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.data_resource.upload_task; + +import com.welab.wefe.board.service.database.entity.data_resource.DataResourceUploadTaskMysqlModel; +import com.welab.wefe.board.service.service.data_resource.DataResourceUploadTaskService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author lonnie + */ +@Api(path = "data_resource/upload_task/detail", name = "get a data set upload task info") +public class DataResourceUploadTaskDetailApi extends AbstractApi { + + @Autowired + private DataResourceUploadTaskService dataResourceUploadTaskService; + + @Override + protected ApiResult handle(Input input) throws 
StatusCodeWithException { + return success(dataResourceUploadTaskService.findByDataResourceId(input.dataResourceId)); + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + public String dataResourceId; + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/DataResourceUploadTaskQueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/DataResourceUploadTaskQueryApi.java new file mode 100644 index 000000000..5e0341dd1 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/DataResourceUploadTaskQueryApi.java @@ -0,0 +1,46 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.data_resource.upload_task; + +import com.welab.wefe.board.service.dto.base.PagingInput; +import com.welab.wefe.board.service.dto.base.PagingOutput; +import com.welab.wefe.board.service.dto.entity.data_resource.output.DataResourceUploadTaskOutputModel; +import com.welab.wefe.board.service.service.data_resource.DataResourceUploadTaskService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author zane.luo + */ +@Api(path = "data_resource/upload_task/query", name = "query data set upload task list") +public class DataResourceUploadTaskQueryApi extends AbstractApi> { + + @Autowired + private DataResourceUploadTaskService dataResourceUploadTaskService; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException { + return success(dataResourceUploadTaskService.query(input)); + } + + public static class Input extends PagingInput { + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/test/detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/test/detail.http new file mode 100644 index 000000000..8278ac78c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/test/detail.http @@ -0,0 +1,7 @@ +POST http://localhost:8080/board-service/data_set_task/detail +Content-Type: application/json +token: ac266310-0c98-46ac-b9d0-d7d8ff9ff888 + +{ + "id": "" +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/test/query.http 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/test/query.http new file mode 100644 index 000000000..69aa058f3 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/data_resource/upload_task/test/query.http @@ -0,0 +1,7 @@ +POST http://localhost:8080/board-service/data_resource/upload_task/query +Content-Type: application/json +token: {{token}} + +{ + "id": "" +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/AddApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/AddApi.java deleted file mode 100644 index 685c052b9..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/AddApi.java +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.api.dataset; - -import com.welab.wefe.board.service.database.entity.data_set.DataSetTaskMysqlModel; -import com.welab.wefe.board.service.dto.vo.DataSetAddInputModel; -import com.welab.wefe.board.service.service.dataset.DataSetTaskService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -import java.io.IOException; - -/** - * @author Zane - */ -@Api(path = "data_set/add", name = "add data set") -public class AddApi extends AbstractApi { - - @Autowired - private DataSetTaskService dataSetTaskService; - - @Override - protected ApiResult handle(DataSetAddInputModel input) throws StatusCodeWithException, IOException { - DataSetTaskMysqlModel dataSetTaskMysqlModel = dataSetTaskService.add(input); - return success(dataSetTaskMysqlModel); - } - -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/DeleteApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/DeleteApi.java deleted file mode 100644 index 6d23805b4..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/DeleteApi.java +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.dataset; - -import com.welab.wefe.board.service.service.DataSetService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author Zane - */ -@Api(path = "data_set/delete", name = "delete data set") -public class DeleteApi extends AbstractNoneOutputApi { - - @Autowired - private DataSetService dataSetService; - - @Override - protected ApiResult handler(Input input) throws StatusCodeWithException { - dataSetService.delete(input); - return success(); - } - - public static class Input extends AbstractApiInput { - @Check(name = "数据集 Id", require = true) - private String id; - - //region getter/setter - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - - //endregion - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/DetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/DetailApi.java deleted file mode 100644 index e678041da..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/DetailApi.java +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.dataset; - -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; -import com.welab.wefe.board.service.database.repository.DataSetRepository; -import com.welab.wefe.board.service.dto.entity.data_set.DataSetOutputModel; -import com.welab.wefe.board.service.util.ModelMapper; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author Zane - */ -@Api(path = "data_set/detail", name = "get data set detail") -public class DetailApi extends AbstractApi { - - @Autowired - DataSetRepository dataSetRepository; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - - DataSetMysqlModel model = dataSetRepository.findById(input.id).orElse(null); - - if (model == null) { - return success(); - } - - DataSetOutputModel output = ModelMapper.map(model, DataSetOutputModel.class); - - return success(output); - - } - - public static class Input extends AbstractApiInput { - private String id; - - //region getter/setter - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - - //endregion - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/QueryApi.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/QueryApi.java deleted file mode 100644 index aa8290430..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/QueryApi.java +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.dataset; - -import com.welab.wefe.board.service.dto.base.PagingInput; -import com.welab.wefe.board.service.dto.base.PagingOutput; -import com.welab.wefe.board.service.dto.entity.data_set.DataSetOutputModel; -import com.welab.wefe.board.service.service.DataSetService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author Zane - */ -@Api(path = "data_set/query", name = "query data set") -public class QueryApi extends AbstractApi> { - - @Autowired - private DataSetService dataSetService; - - @Override - protected ApiResult> handle(Input input) throws StatusCodeWithException { - return success(dataSetService.query(input)); - } - - - public static class Input extends PagingInput { - - private String id; - - @Check(name = "数据集名称") - private 
String name; - - @Check(name = "标签") - private String tag; - - @Check(name = "是否包含 Y 值") - private Boolean containsY; - - @Check(name = "上传者") - private String creator; - - //region getter/setter - - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public String getTag() { - return tag; - } - - public void setTag(String tag) { - this.tag = tag; - } - - public Boolean getContainsY() { - return containsY; - } - - public void setContainsY(Boolean containsY) { - this.containsY = containsY; - } - - public String getCreator() { - return creator; - } - - public void setCreator(String creator) { - this.creator = creator; - } - - //endregion - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/ServerLocalFilesApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/ServerLocalFilesApi.java deleted file mode 100644 index e58a06939..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/ServerLocalFilesApi.java +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.api.dataset; - -import com.welab.wefe.board.service.constant.Config; -import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiOutput; -import com.welab.wefe.common.web.dto.ApiResult; -import com.welab.wefe.common.web.dto.NoneApiInput; -import org.springframework.beans.factory.annotation.Autowired; - -import java.io.File; -import java.util.ArrayList; -import java.util.List; - -/** - * @author Johnny.lin - */ -@Api(path = "data_set/list_local_data_set_files", name = "query the files in the specified directory on the server") -public class ServerLocalFilesApi extends AbstractApi { - - @Autowired - private Config config; - - private static final List SUPPORT_SUFFIX = new ArrayList(); - - static { - SUPPORT_SUFFIX.add("xls"); - SUPPORT_SUFFIX.add("xlsx"); - SUPPORT_SUFFIX.add("csv"); - } - - @Override - protected ApiResult handle(NoneApiInput input) throws StatusCodeWithException { - List files = new ArrayList<>(); - File file = new File(config.getFileUploadDir()); - LOG.info("file.exists(): " + file.exists()); - if (!file.exists() || !file.isDirectory()) { - throw new StatusCodeWithException(StatusCode.DIRECTORY_NOT_FOUND, config.getFileUploadDir()); - } - - File[] tempList = file.listFiles(); - for (File fileObj : tempList) { - if (fileObj.isFile()) { - LOG.info("file: " + fileObj); - - //File name, excluding path - String fileName = fileObj.getName(); - String suffix = fileName.substring(fileName.lastIndexOf(".") + 1); - - //Only XLS, xlsx, and CSV files are displayed - if (!SUPPORT_SUFFIX.contains(suffix.toLowerCase())) { - continue; - } - - files.add(fileName); - } - } - - Output output = new Output(); - output.setFiles(files); - return success(output); - } - - public static class Output extends AbstractApiOutput { - 
private List files; - - public List getFiles() { - return files; - } - - public void setFiles(List files) { - this.files = files; - } - - public Output() { - - } - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/TagListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/TagListApi.java deleted file mode 100644 index da4a2f45d..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/TagListApi.java +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.api.dataset; - -import com.welab.wefe.board.service.database.repository.DataSetRepository; -import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.util.StringUtil; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -import java.util.TreeMap; - -/** - * @author Zane - */ -@Api(path = "data_set/tags", name = "all of the data set tags") -public class TagListApi extends AbstractApi> { - - @Autowired - DataSetRepository repo; - - @Override - protected ApiResult> handle(Input input) throws StatusCodeWithException { - TreeMap map = (TreeMap) CacheObjects.getDataSetTags().clone(); - - // filter - if (StringUtil.isNotEmpty(input.tag)) { - for (Object tag : map.keySet().toArray()) { - if (!String.valueOf(tag).toLowerCase().contains(input.tag)) { - map.remove(tag); - } - } - } - - return success(map); - } - - public static class Input extends AbstractApiInput { - private String tag; - - //region getter/setter - - public String getTag() { - return tag; - } - - public void setTag(String tag) { - this.tag = tag; - } - - - //endregion - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/UpdateApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/UpdateApi.java deleted file mode 100644 index b17a01d5a..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/UpdateApi.java +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.dataset; - -import com.welab.wefe.board.service.dto.vo.DataSetBaseInputModel; -import com.welab.wefe.board.service.service.DataSetService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -import java.util.List; - -/** - * @author Zane - */ -@Api(path = "data_set/update", name = "update data set info") -public class UpdateApi extends AbstractNoneOutputApi { - - @Autowired - private DataSetService dataSetService; - - @Override - protected ApiResult handler(Input input) throws StatusCodeWithException { - dataSetService.update(input); - - return success(); - } - - public static class Input extends DataSetBaseInputModel { - - @Check(name = "数据集Id", require = true) - private String id; - - @Check(name = "数据集名称", require = true, regex = "^.{4,50}$") - private String name; - - @Check(name = "标签", require = true) - private List tags; - - @Check(name = "描述") - private String description; - - //region getter/setter - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public List getTags() { - return tags; - } - - public void setTags(List tags) { - 
this.tags = tags; - } - - public String getDescription() { - return description; - } - - public void setDescription(String description) { - this.description = description; - } - - //endregion - - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/UsageDetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/UsageDetailApi.java deleted file mode 100644 index db70245c2..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/UsageDetailApi.java +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.api.dataset; - -import com.welab.wefe.board.service.dto.entity.project.ProjectUsageDetailOutputModel; -import com.welab.wefe.board.service.service.DataSetService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -import java.io.IOException; -import java.util.List; - -/** - * @author Jacky.jiang - */ -@Api(path = "data_set/usage_detail", name = "list usage_detail") -public class UsageDetailApi extends AbstractApi> { - @Autowired - private DataSetService dataSetService; - - @Override - protected ApiResult> handle(Input input) throws StatusCodeWithException, IOException { - return success(dataSetService.queryUsageInProject(input.getDataSetId())); - } - - public static class Input extends AbstractApiInput { - @Check(name = "数据集ID", require = true) - private String dataSetId; - - //region getter/setter - - public String getDataSetId() { - return dataSetId; - } - - public void setDataSetId(String dataSetId) { - this.dataSetId = dataSetId; - } - - //endregion - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/column/ListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/column/ListApi.java deleted file mode 100644 index 691d9b3f6..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/column/ListApi.java +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.dataset.column; - -import com.welab.wefe.board.service.dto.base.PagingOutput; -import com.welab.wefe.board.service.dto.entity.data_set.DataSetColumnOutputModel; -import com.welab.wefe.board.service.service.DataSetColumnService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author Zane - */ -@Api(path = "data_set/column/list", name = "list of data set fields") -public class ListApi extends AbstractApi> { - - @Autowired - private DataSetColumnService service; - - @Override - protected ApiResult> handle(Input input) throws StatusCodeWithException { - return success(service.list(input.getDataSetId())); - } - - - public static class Input extends AbstractApiInput { - - @Check(require = true, name = "数据集Id") - private String dataSetId; - - //region getter/setter - - public String getDataSetId() { - return dataSetId; - } - - public void setDataSetId(String dataSetId) { - this.dataSetId = dataSetId; - } - - - //endregion - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-add.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-add.http deleted file 
mode 100644 index 018157743..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-add.http +++ /dev/null @@ -1,195 +0,0 @@ - -### 添加数据集-xlsx -POST {{baseUrl}}/data_set/add -Content-Type: application/json -token: ac266310-0c98-46ac-b9d0-d7d8ff9ff888 - -{ - "publicLevel": "Public", - "name": "zane test", - "tags": [ - "12321" - ], - "description": "", - "public_member_list": "", - "filename": "db359e88-99a4-4998-a7f2-8ab6bdf0d688-test-01-1600-15.xls", - "data_set_add_method": "HttpUpload", - "metadata_list": [ - { - "data_type": "Integer", - "name": "id", - "$index": 0, - "comment": "" - }, - { - "data_type": "Double", - "name": "x1", - "$index": 1, - "comment": "" - }, - { - "data_type": "Double", - "name": "x2", - "$index": 2, - "comment": "" - }, - { - "data_type": "Double", - "name": "x3", - "$index": 3, - "comment": "" - }, - { - "data_type": "Double", - "name": "x4", - "$index": 4, - "comment": "" - }, - { - "data_type": "Double", - "name": "x5", - "$index": 5, - "comment": "" - }, - { - "data_type": "Double", - "name": "x6", - "$index": 6, - "comment": "" - }, - { - "data_type": "Double", - "name": "x7", - "$index": 7, - "comment": "" - }, - { - "data_type": "Double", - "name": "x8", - "$index": 8, - "comment": "" - }, - { - "data_type": "Double", - "name": "x9", - "$index": 9, - "comment": "" - }, - { - "data_type": "Double", - "name": "x10", - "$index": 10, - "comment": "" - }, - { - "data_type": "Double", - "name": "x11", - "$index": 11, - "comment": "" - }, - { - "data_type": "Double", - "name": "x12", - "$index": 12, - "comment": "" - }, - { - "data_type": "Double", - "name": "x13", - "$index": 13, - "comment": "" - }, - { - "data_type": "Double", - "name": "x14", - "$index": 14, - "comment": "" - }, - { - "data_type": "Double", - "name": "x15", - "$index": 15, - "comment": "" - } - ], - "deduplication": true -} - - - - -### 添加数据集-xls -POST {{baseUrl}}/data_set/add -Content-Type: multipart/form-data; 
boundary=WebAppBoundary -token: {{token}} - ---WebAppBoundary ---WebAppBoundary -Content-Disposition: form-data; name="file"; filename="test02.xlsx" -Content-Type: application/vnd.openxmlformats-officedocument.spreadsheetml.sheet - -< /Users/zane.luo/test02.xlsx ---WebAppBoundary -Content-Disposition: form-data; name="name" - -xlsx测试数据集 ---WebAppBoundary -Content-Disposition: form-data; name="tags" - -xlsx,test ---WebAppBoundary -Content-Disposition: form-data; name="description" - -这是一个测试数据集,用于测试数据集上传功能。 ---WebAppBoundary -Content-Disposition: form-data; name="contains_y" - -false ---WebAppBoundary - -> {% - -client.test("Request executed successfully", function() { - client.assert(response.body.code === 0, "Response code is not 0"); -}); - -%} - - - - -### 添加数据集-csv -POST {{baseUrl}}/data_set/add -Content-Type: multipart/form-data; boundary=WebAppBoundary - ---WebAppBoundary -Content-Disposition: form-data; name="file"; filename="反欺诈·反查贷后字段开发☞数据样本.csv" -Content-Type: text/csv - -< /Users/Zane/Code/WeLab/Wefe/board/board-service/src/main/resources/test/data_set/data_set01.csv ---WebAppBoundary -Content-Disposition: form-data; name="name" - -xlsx测试数据集 ---WebAppBoundary -Content-Disposition: form-data; name="tags" - -xlsx,test ---WebAppBoundary -Content-Disposition: form-data; name="description" - -这是一个测试数据集,用于测试数据集上传功能。 ---WebAppBoundary -Content-Disposition: form-data; name="contains_y" - -false ---WebAppBoundary - -> {% - -client.test("Request executed successfully", function() { - client.assert(response.body.code === 0, "Response code is not 0"); -}); - -%} - diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-detail.http deleted file mode 100644 index a5e2b3337..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-detail.http +++ /dev/null @@ -1,18 +0,0 @@ - -### 
查询单个数据集 -POST {{baseUrl}}/data_set/detail -Content-Type: application/json - -{ - "id": "75c2b7d8e53f4400859a8fa72a099fdf" -} - - - -### 查询已删除的数据集 -POST {{baseUrl}}/data_set/detail -Content-Type: application/json - -{ - "id": "a80a0056505c498d9a9be56aa06f7324" -} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-preview.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-preview.http deleted file mode 100644 index 3159c25d3..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-preview.http +++ /dev/null @@ -1,10 +0,0 @@ - -### 预览数据集文件 -POST {{baseUrl}}/data_set/preview -Content-Type: application/json -token: {{token}} - -{ - "filename": "/Users/zane.luo/data/f03ec02a-ef5e-4fb5-8873-08fbc8ebabe7-horz_lr_provider.csv", - "data_set_add_method" : "LocalFile" -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-query.http deleted file mode 100644 index 8f63da312..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-query.http +++ /dev/null @@ -1,40 +0,0 @@ - -### 查询全部数据集 -POST {{baseUrl}}/data_set/query -Content-Type: application/json - -{} - -> {% - -client.test("Request executed successfully", function() { - client.assert(response.body.code === 0, "Response code is not 0"); -}); - -%} - - - -### 按名字查 -POST {{baseUrl}}/data_set/query -Content-Type: application/json - -{ - "name": "xlsx" -} - -### 按 tag 查 -POST {{baseUrl}}/data_set/query -Content-Type: application/json - -{ - "tag": "xlsx" -} - -### 按是否有 Y 值查 -POST {{baseUrl}}/data_set/query -Content-Type: application/json - -{ - "contains_y": false -} \ No newline at end of file diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-tags.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-tags.http deleted file mode 100644 index 6624ab282..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset/test/dataset-tags.http +++ /dev/null @@ -1,15 +0,0 @@ - -### 查询所有标签 -POST {{baseUrl}}/data_set/tags -Content-Type: application/json - -{} - - -### 根据标签名称模糊查询 -POST {{baseUrl}}/data_set/tags -Content-Type: application/json - -{ - "tag": "s" -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset_task/DetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset_task/DetailApi.java deleted file mode 100644 index 39c006508..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset_task/DetailApi.java +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.api.dataset_task; - -import com.welab.wefe.board.service.database.entity.data_set.DataSetTaskMysqlModel; -import com.welab.wefe.board.service.service.dataset.DataSetTaskService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author lonnie - */ -@Api(path = "data_set_task/detail", name = "get a data set upload task info") -public class DetailApi extends AbstractApi { - - @Autowired - private DataSetTaskService dataSetTaskService; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - return success(dataSetTaskService.findById(input.getId())); - } - - public static class Input extends AbstractApiInput { - @Check(name = "id唯一标识", require = true) - private String id; - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - } - -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset_task/QueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset_task/QueryApi.java deleted file mode 100644 index c51aa2c4e..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset_task/QueryApi.java +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.dataset_task; - -import com.welab.wefe.board.service.dto.base.PagingInput; -import com.welab.wefe.board.service.dto.base.PagingOutput; -import com.welab.wefe.board.service.dto.entity.DataSetTaskOutputModel; -import com.welab.wefe.board.service.service.dataset.DataSetTaskService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author zane.luo - */ -@Api(path = "data_set_task/query", name = "query data set upload task list") -public class QueryApi extends AbstractApi> { - - @Autowired - private DataSetTaskService dataSetTaskService; - - @Override - protected ApiResult> handle(Input input) throws StatusCodeWithException { - return success(dataSetTaskService.query(input)); - } - - public static class Input extends PagingInput { - } - -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset_task/test/detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset_task/test/detail.http deleted file mode 100644 index adf6a6947..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dataset_task/test/detail.http +++ /dev/null @@ -1,7 +0,0 @@ -POST {{baseUrl}}/data_set_task/detail -Content-Type: application/json -token: ac266310-0c98-46ac-b9d0-d7d8ff9ff888 - -{ - "id": "" -} \ No 
newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/AddApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/AddApi.java index 176a23c69..a17922d1d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/AddApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/AddApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,7 +17,6 @@ package com.welab.wefe.board.service.api.datasource; import com.welab.wefe.board.service.service.DataSourceService; -import com.welab.wefe.common.enums.DatabaseType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; @@ -25,6 +24,7 @@ import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.AbstractApiOutput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DatabaseType; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/DeleteApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/DeleteApi.java index 7551449d3..5b7d70090 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/DeleteApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/DeleteApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/QueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/QueryApi.java index 2a499b1dc..790aff646 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/QueryApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/QueryApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -19,13 +19,13 @@ import com.welab.wefe.board.service.dto.base.PagingInput; import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.service.DataSourceService; -import com.welab.wefe.common.enums.DatabaseType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiOutput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DatabaseType; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/TestDBConnectApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/TestDBConnectApi.java index 9ef802f1f..d13e922d9 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/TestDBConnectApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/datasource/TestDBConnectApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dev/CreateTestDataSetApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dev/CreateTestDataSetApi.java index cbe5811d1..16e1e01aa 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dev/CreateTestDataSetApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dev/CreateTestDataSetApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,6 +16,7 @@ package com.welab.wefe.board.service.api.dev; +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; import com.welab.wefe.board.service.constant.Config; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.exception.StatusCodeWithException; @@ -61,14 +62,13 @@ protected ApiResult handle(Input input) throws StatusCodeWithException { } private String createCsv(Input input) throws IOException { - String fileName = config.getFileUploadDir() + "/" - + input.idType + "-" + String fileName = input.idType + "-" + input.features + "-" + input.rows + (input.hasY ? 
"-y" : "") + ".csv"; - File file = new File(fileName); + File file = WeFeFileSystem.getBaseDir(WeFeFileSystem.UseType.AddTableDataSet).resolve(fileName).toFile(); if (file.exists()) { FileUtils.deleteQuietly(file); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dev/test/create_data_set.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dev/test/create_data_set.http index e24e5640d..4a6423560 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/dev/test/create_data_set.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/dev/test/create_data_set.http @@ -1,5 +1,5 @@ ### 生成指定大小的数据集供测试使用 -POST {{baseUrl}}/test/create_data_set +POST http://localhost:8080/board-service/test/create_data_set Content-Type: application/json { @@ -12,7 +12,7 @@ Content-Type: application/json ### -POST {{baseUrl}}/test/create_data_set +POST http://localhost:8080/board-service/test/create_data_set Content-Type: application/json { @@ -25,7 +25,7 @@ Content-Type: application/json ### -POST {{baseUrl}}/test/create_data_set +POST http://localhost:8080/board-service/test/create_data_set Content-Type: application/json { @@ -36,7 +36,7 @@ Content-Type: application/json } ### -POST {{baseUrl}}/test/create_data_set +POST http://localhost:8080/board-service/test/create_data_set Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/FileUploadApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/FileUploadApi.java new file mode 100644 index 000000000..156fc903f --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/FileUploadApi.java @@ -0,0 +1,254 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.file; + +import com.welab.wefe.board.service.api.file.security.FileSecurityChecker; +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.FileUtil; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractWithFilesApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.apache.commons.io.FileUtils; +import org.springframework.web.multipart.MultipartFile; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Path; + +/** + * The front end uses the simple-uploader component + * doc:https://github.com/simple-uploader/Uploader/blob/develop/README_zh-CN.md#%E5%A4%84%E7%90%86-get-%E6%88%96%E8%80%85-test-%E8%AF%B7%E6%B1%82 + * + * @author zane.luo + */ +@Api(path = "file/upload", name = "upload file") +public class FileUploadApi extends AbstractApi { + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + + // 检查文件是否是支持的文件类型 + try { + FileSecurityChecker.checkIsAllowFileType(input.filename); + } catch (Exception e) { + return fail(e) + .setHttpCode(599); + } + + switch (input.method) { + case "POST": + return saveChunk(input); + + case "GET": + return checkChunk(input); + + default: + throw new 
StatusCodeWithException(StatusCode.UNEXPECTED_ENUM_CASE); + } + + + } + + /** + * Check if the chunk already exists + */ + private ApiResult checkChunk(Input input) { + Integer chunkNumber = input.getChunkNumber(); + if (chunkNumber == null) { + chunkNumber = 0; + } + + File outFile = WeFeFileSystem.getBaseDir(input.uploadFileUseType) + .resolve(input.getIdentifier()) + .resolve(chunkNumber + ".part") + .toFile(); + + if (outFile.exists()) { + return success() + .setMessage("该分片已存在"); + } else { + return success() + .setHttpCode(299) + .setMessage("该分片不存在"); + } + } + + /** + * save chunk + */ + private ApiResult saveChunk(Input input) throws StatusCodeWithException { + MultipartFile inputFile = input.getFirstFile(); + + Integer chunkNumber = input.getChunkNumber(); + if (chunkNumber == null) { + chunkNumber = 0; + } + + Path outputDir = WeFeFileSystem.getBaseDir(input.uploadFileUseType).resolve(input.getIdentifier()); + FileUtil.createDir(outputDir.toString()); + LOG.info("创建目录 " + outputDir.toFile().exists() + " :" + outputDir); + + File outFile = outputDir.resolve(chunkNumber + ".part").toFile(); + + try { + InputStream inputStream = inputFile.getInputStream(); + FileUtils.copyInputStreamToFile(inputStream, outFile); + } catch (IOException e) { + LOG.error(e.getMessage(), e); + throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); + } + + return success(new Output(inputFile.getSize())); + } + + public static class Output { + private long length; + + public Output(long length) { + this.length = length; + } + + public long getLength() { + return length; + } + + public void setLength(long length) { + this.length = length; + } + } + + public static class Input extends AbstractWithFilesApiInput { + private Long id; + @Check(name = "当前文件块,从1开始") + private Integer chunkNumber; + @Check(name = "分块大小") + private Long chunkSize; + @Check(name = "当前分块大小") + private Long currentChunkSize; + @Check(name = "总大小") + private Long totalSize; + @Check(name = 
"文件标识") + private String identifier; + @Check(name = "文件名") + private String filename; + @Check(name = "相对路径") + private String relativePath; + @Check(name = "总块数") + private Integer totalChunks; + @Check(name = "文件类型") + private String type; + @Check(name = "文件用途", require = true) + private WeFeFileSystem.UseType uploadFileUseType; + + //region getter/setter + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public Integer getChunkNumber() { + return chunkNumber; + } + + public void setChunkNumber(Integer chunkNumber) { + this.chunkNumber = chunkNumber; + } + + public Long getChunkSize() { + return chunkSize; + } + + public void setChunkSize(Long chunkSize) { + this.chunkSize = chunkSize; + } + + public Long getCurrentChunkSize() { + return currentChunkSize; + } + + public void setCurrentChunkSize(Long currentChunkSize) { + this.currentChunkSize = currentChunkSize; + } + + public Long getTotalSize() { + return totalSize; + } + + public void setTotalSize(Long totalSize) { + this.totalSize = totalSize; + } + + public String getIdentifier() { + return identifier; + } + + public void setIdentifier(String identifier) { + this.identifier = identifier; + } + + public String getFilename() { + return filename; + } + + public void setFilename(String filename) { + this.filename = filename; + } + + public String getRelativePath() { + return relativePath; + } + + public void setRelativePath(String relativePath) { + this.relativePath = relativePath; + } + + public Integer getTotalChunks() { + return totalChunks; + } + + public void setTotalChunks(Integer totalChunks) { + this.totalChunks = totalChunks; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public WeFeFileSystem.UseType getUploadFileUseType() { + return uploadFileUseType; + } + + public void setUploadFileUseType(WeFeFileSystem.UseType uploadFileUseType) { + this.uploadFileUseType = uploadFileUseType; 
+ } + +//endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/MergeApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/MergeApi.java index 1a69ce0c1..d1d9563a2 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/MergeApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/MergeApi.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -17,15 +17,15 @@ package com.welab.wefe.board.service.api.file; import com.welab.wefe.board.service.api.file.security.FileSecurityChecker; -import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; import org.apache.commons.io.FileUtils; -import org.springframework.beans.factory.annotation.Autowired; import java.io.File; import java.io.FileOutputStream; @@ -38,19 +38,20 @@ @Api(path = "file/merge", name = "Merge the chunks after the file is uploaded") public class MergeApi extends AbstractApi { - @Autowired - private Config config; - @Override protected ApiResult handle(Input input) throws Exception { String mergedFileName = UUID.randomUUID() + "-" + input.filename; - File dir = new File(config.getFileUploadDir() + File.separator + input.uniqueIdentifier); 
+ File dir = WeFeFileSystem.getBaseDir(input.uploadFileUseType) + .resolve(input.uniqueIdentifier) + .toFile(); File[] parts = dir.listFiles(); - File mergedFile = new File(config.getFileUploadDir() + File.separator + mergedFileName); + File mergedFile = WeFeFileSystem.getBaseDir(input.uploadFileUseType) + .resolve(mergedFileName) + .toFile(); if (mergedFile.exists()) { return success(new Output(mergedFileName)); @@ -58,7 +59,10 @@ protected ApiResult handle(Input input) throws Exception { try { for (int i = 1; i <= parts.length; i++) { - File part = new File(config.getFileUploadDir() + File.separator + input.uniqueIdentifier, i + ".part"); + File part = WeFeFileSystem.getBaseDir(input.uploadFileUseType) + .resolve(input.uniqueIdentifier) + .resolve(i + ".part") + .toFile(); // append chunk to the target file FileOutputStream stream = new FileOutputStream(mergedFile, true); @@ -98,6 +102,8 @@ public void setFilename(String filename) { public static class Input extends AbstractApiInput { private String filename; private String uniqueIdentifier; + @Check(name = "文件用途", require = true) + private WeFeFileSystem.UseType uploadFileUseType; public String getFilename() { return filename; @@ -114,5 +120,13 @@ public String getUniqueIdentifier() { public void setUniqueIdentifier(String uniqueIdentifier) { this.uniqueIdentifier = uniqueIdentifier; } + + public WeFeFileSystem.UseType getUploadFileUseType() { + return uploadFileUseType; + } + + public void setUploadFileUseType(WeFeFileSystem.UseType uploadFileUseType) { + this.uploadFileUseType = uploadFileUseType; + } } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/UploadApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/UploadApi.java deleted file mode 100644 index e915f4405..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/UploadApi.java +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright 2021 Tianmian Tech. 
All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.file; - -import com.welab.wefe.board.service.constant.Config; -import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.util.FileUtil; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractWithFilesApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.apache.commons.io.FileUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.web.multipart.MultipartFile; - -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Files; -import java.util.concurrent.atomic.AtomicBoolean; - -/** - * The front end uses the simple-uploader component - * doc:https://github.com/simple-uploader/Uploader/blob/develop/README_zh-CN.md#%E5%A4%84%E7%90%86-get-%E6%88%96%E8%80%85-test-%E8%AF%B7%E6%B1%82 - * - * @author zane.luo - */ -@Api(path = "file/upload", name = "upload file") -public class UploadApi extends AbstractApi { - - @Autowired - private Config config; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - switch (input.method) { - case "POST": - return saveChunk(input); - - case "GET": 
- return checkChunk(input); - - default: - throw new StatusCodeWithException(StatusCode.UNEXPECTED_ENUM_CASE); - } - - - } - - /** - * Check if the chunk already exists - */ - private ApiResult checkChunk(Input input) { - Integer chunkNumber = input.getChunkNumber(); - if (chunkNumber == null) { - chunkNumber = 0; - } - - File outFile = new File(config.getFileUploadDir() + File.separator + input.getIdentifier(), chunkNumber + ".part"); - if (outFile.exists()) { - return success() - .setMessage("该分片已存在"); - } else { - return success() - .setHttpCode(299) - .setMessage("该分片不存在"); - } - } - - /** - * save chunk - */ - private ApiResult saveChunk(Input input) throws StatusCodeWithException { - MultipartFile file = input.getFirstFile(); - - Integer chunkNumber = input.getChunkNumber(); - if (chunkNumber == null) { - chunkNumber = 0; - } - - File outFile = new File(config.getFileUploadDir() + File.separator + input.getIdentifier(), chunkNumber + ".part"); - - try { - InputStream inputStream = file.getInputStream(); - FileUtils.copyInputStreamToFile(inputStream, outFile); - } catch (IOException e) { - throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); - } - - return success(new Output(file.getSize())); - } - - public static class Output { - private long length; - - public Output(long length) { - this.length = length; - } - - public long getLength() { - return length; - } - - public void setLength(long length) { - this.length = length; - } - } - - public static class Input extends AbstractWithFilesApiInput { - private Long id; - @Check(name = "当前文件块,从1开始") - private Integer chunkNumber; - @Check(name = "分块大小") - private Long chunkSize; - @Check(name = "当前分块大小") - private Long currentChunkSize; - @Check(name = "总大小") - private Long totalSize; - @Check(name = "文件标识") - private String identifier; - @Check(name = "文件名") - private String filename; - @Check(name = "相对路径") - private String relativePath; - @Check(name = "总块数") - private Integer 
totalChunks; - @Check(name = "文件类型") - private String type; - - //region getter/setter - - public Long getId() { - return id; - } - - public void setId(Long id) { - this.id = id; - } - - public Integer getChunkNumber() { - return chunkNumber; - } - - public void setChunkNumber(Integer chunkNumber) { - this.chunkNumber = chunkNumber; - } - - public Long getChunkSize() { - return chunkSize; - } - - public void setChunkSize(Long chunkSize) { - this.chunkSize = chunkSize; - } - - public Long getCurrentChunkSize() { - return currentChunkSize; - } - - public void setCurrentChunkSize(Long currentChunkSize) { - this.currentChunkSize = currentChunkSize; - } - - public Long getTotalSize() { - return totalSize; - } - - public void setTotalSize(Long totalSize) { - this.totalSize = totalSize; - } - - public String getIdentifier() { - return identifier; - } - - public void setIdentifier(String identifier) { - this.identifier = identifier; - } - - public String getFilename() { - return filename; - } - - public void setFilename(String filename) { - this.filename = filename; - } - - public String getRelativePath() { - return relativePath; - } - - public void setRelativePath(String relativePath) { - this.relativePath = relativePath; - } - - public Integer getTotalChunks() { - return totalChunks; - } - - public void setTotalChunks(Integer totalChunks) { - this.totalChunks = totalChunks; - } - - public String getType() { - return type; - } - - public void setType(String type) { - this.type = type; - } - - - //endregion - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/CsvSecurityChecker.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/CsvSecurityChecker.java index 197fac125..e97676573 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/CsvSecurityChecker.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/CsvSecurityChecker.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/ExcelSecurityChecker.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/ExcelSecurityChecker.java index bf025e8f6..269158a22 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/ExcelSecurityChecker.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/ExcelSecurityChecker.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/FileSecurityChecker.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/FileSecurityChecker.java index 36714e3e4..f1c8af5f5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/FileSecurityChecker.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/FileSecurityChecker.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,15 @@ package com.welab.wefe.board.service.api.file.security; import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.FileUtil; import com.welab.wefe.common.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; -import java.io.IOException; +import java.util.Arrays; +import java.util.List; /** * @author zane @@ -31,14 +33,24 @@ public abstract class FileSecurityChecker { protected final static Logger LOG = LoggerFactory.getLogger(FileSecurityChecker.class); protected static final String[] keywords = {"<", ">", "\\"}; + /** + * 允许的文件类型 + */ + private static final List ALLOW_FILE_TYPES = Arrays.asList( + "xls", "xlsx", "csv", + "zip", "gz", "tgz", "7z", + "jpg", "jpeg", "png" + ); - protected abstract void doCheck(File file) throws IOException; + protected abstract void doCheck(File file) throws Exception; public static void check(File file) throws Exception { + // 为检查上传的文件是否安全 String suffix = StringUtil.substringAfterLast(file.getName(), "."); - try { + checkIsAllowFileType(file.getName()); + switch (suffix) { case "xls": case "xlsx": @@ -47,14 +59,39 @@ public static void check(File file) throws Exception { case "csv": new CsvSecurityChecker().doCheck(file); break; + case "zip": + case "gz": + case "tgz": + case "7z": + break; + case "jpg": + case "jpeg": + case "png": + new ImageSecurityChecker().doCheck(file); + break; default: StatusCode.PARAMETER_VALUE_INVALID.throwException("不支持的文件类型:" + suffix); } } catch (Exception e) { LOG.error(e.getMessage(), e); - FileUtil.deleteFile(file); + FileUtil.deleteFileOrDir(file); throw e; } } + + public static void 
checkIsAllowFileType(String filename) throws StatusCodeWithException { + if (StringUtil.isEmpty(filename)) { + StatusCode.PARAMETER_VALUE_INVALID.throwException("文件名不允许为空"); + } + + String suffix = StringUtil.substringAfterLast(filename, "."); + if (StringUtil.isEmpty(suffix)) { + StatusCode.PARAMETER_VALUE_INVALID.throwException("不支上传无文件后缀的文件"); + } + + if (!ALLOW_FILE_TYPES.contains(suffix)) { + StatusCode.PARAMETER_VALUE_INVALID.throwException("不支持的文件类型:" + suffix); + } + } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/ImageSecurityChecker.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/ImageSecurityChecker.java new file mode 100644 index 000000000..6c1d30493 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/file/security/ImageSecurityChecker.java @@ -0,0 +1,57 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.api.file.security; + +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.util.FileType; +import net.coobird.thumbnailator.Thumbnails; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; + +/** + * @author zane + * @date 2022/3/14 + */ +public class ImageSecurityChecker extends FileSecurityChecker { + + @Override + protected void doCheck(File file) throws Exception { + byte[] bytes = Files.readAllBytes(file.toPath()); + // 读取之后就可以删除文件了 + file.delete(); + + // 判断是否是图片 + if (!FileType.isImage(bytes)) { + StatusCode.PARAMETER_VALUE_INVALID.throwException("不支持的图片格式"); + } + + // 对图片文件进行缩放重绘,过滤掉内部可能包含的木马内容。 + try { + Thumbnails + .of(new ByteArrayInputStream(bytes)) + .scale(1) + .toFile(file); + } catch (IOException e) { + StatusCode + .PARAMETER_VALUE_INVALID + .throwException("图片已损坏:" + e.getMessage()); + } + + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/GetDerivedDataSetDetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/GetDerivedDataSetDetailApi.java index 74e48347c..323986ee8 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/GetDerivedDataSetDetailApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/GetDerivedDataSetDetailApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,15 +16,15 @@ package com.welab.wefe.board.service.api.gateway; -import com.welab.wefe.board.service.dto.entity.project.DerivedProjectDataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.project.data_set.DerivedProjectDataSetOutputModel; import com.welab.wefe.board.service.service.ProjectDataSetService; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/GetMemberJobProgressApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/GetMemberJobProgressApi.java index 8d6f1679c..2eac22135 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/GetMemberJobProgressApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/GetMemberJobProgressApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,13 +18,13 @@ import com.welab.wefe.board.service.dto.vo.JobProgressOutput; import com.welab.wefe.board.service.service.FlowJobService; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/RedirectApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/RedirectApi.java index 02d34e367..2d831ed5a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/RedirectApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/RedirectApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,17 +16,20 @@ package com.welab.wefe.board.service.api.gateway; -import com.welab.wefe.common.enums.JobMemberRole; + import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.ApiExecutor; import com.welab.wefe.common.web.Launcher; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; import com.welab.wefe.common.web.dto.GatewayMemberInfo; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import java.util.Map; +import java.util.TreeMap; /** * @author zane.luo @@ -38,14 +41,19 @@ public class RedirectApi extends AbstractApi { protected ApiResult handle(Input input) throws StatusCodeWithException { AbstractApi api = Launcher.CONTEXT.getBean(input.api, AbstractApi.class); - // join the requester's member information in the input input.data.put( "callerMemberInfo", new GatewayMemberInfo(input.callerMemberId, input.callerMemberName, input.callerMemberRole) ); - return api.execute("gateway", JObject.create(input.data)); + ApiResult result = api.execute("gateway", JObject.create(input.data)); + + // 由于这个 api 对象由 RedirectApi 调用,没有走 ApiExecutor,会导致没有响应日志,所以在这里补一个日志。 + Api annotation = api.getClass().getAnnotation(Api.class); + ApiExecutor.logResponse(annotation, result); + + return result; } @@ -56,6 +64,14 @@ public static class Input extends AbstractApiInput { private String api; private Map data; + @Override + public void checkAndStandardize() throws 
StatusCodeWithException { + super.checkAndStandardize(); + if (data == null) { + data = new TreeMap<>(); + } + } + //region getter/setter public String getCallerMemberId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/TestConnectApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/TestConnectApi.java index 490016214..3e245a69f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/TestConnectApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/TestConnectApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/RedirectApiTester.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/RedirectApiTester.java deleted file mode 100644 index a2da00463..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/RedirectApiTester.java +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.gateway.test; - -import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.board.service.service.GatewayService; -import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author zane.luo - */ -@Api(path = "gateway/test/redirect", name = "自己给自己发消息,用来测试", login = false) -public class RedirectApiTester extends AbstractApi { - - @Autowired - private GatewayService gatewayService; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - Class api = null; - try { - api = Class.forName(input.api); - } catch (ClassNotFoundException e) { - throw new StatusCodeWithException("api class error:" + input.api, StatusCode.PARAMETER_VALUE_INVALID); - } - - ApiResult result = gatewayService.sendToBoardRedirectApi( - CacheObjects.getMemberId(), - input.memberRole, - input.data, - api - ); - return success(new Output(result.success())); - } - - public static class Output { - public boolean success; - - public Output(boolean success) { - this.success = success; - } - } - - public static class Input extends AbstractApiInput { - public Object data; - public JobMemberRole memberRole; - 
public String api; - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/get-progress.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/get-progress.http index e5a1f9b51..0bb203cc3 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/get-progress.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/get-progress.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/gateway/get_job_progress +POST http://localhost:8080/board-service/gateway/get_job_progress Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/redirect.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/redirect.http index a1cd10242..588b9f270 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/redirect.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/redirect.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/gateway/redirect +POST http://localhost:8080/board-service/gateway/redirect Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/redirect_tester.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/redirect_tester.http index 93756af89..c54845c85 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/redirect_tester.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/gateway/test/redirect_tester.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/gateway/test/redirect +POST http://localhost:8080/board-service/gateway/test/redirect Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GetConfigProperties.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GetConfigProperties.java new file mode 100644 index 000000000..e22f370d0 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GetConfigProperties.java @@ -0,0 +1,100 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.global_config; + +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.CurrentAccount; +import com.welab.wefe.common.web.api.base.AbstractNoneInputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Value; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * @author zane + * @date 2021/10/29 + */ +@Api(path = "global_config/config_properties", name = "get config items in config.properties file", login = false) +public class GetConfigProperties extends AbstractNoneInputApi { + + @Value("${config.path}") + private String configFilePath; + + private static final List WHITE_LIST = Arrays.asList( + 
"flow.spark.submit.default.driver.memory", + "flow.spark.submit.default.driver.maxResultSize", + "flow.spark.submit.default.num.executors", + "flow.spark.submit.default.executor.memory", + "flow.spark.submit.default.executor.cores", + "flow.spark.default.num.slices" + ); + + @Override + protected ApiResult handle() throws StatusCodeWithException { + + if (!CurrentAccount.isAdmin()) { + StatusCode.PERMISSION_DENIED.throwException("仅管理员可查看系统相关配置"); + } + + Map configs = new LinkedHashMap<>(); + + Path path = Paths.get(configFilePath); + try { + Files + .lines(path) + .filter(x -> { + String trimed = x.trim(); + if (StringUtil.isEmpty(trimed)) { + return false; + } + if (trimed.startsWith("#")) { + return false; + } + return true; + }) + .forEach(x -> { + String key = StringUtil.substringBefore(x, "="); + String value = StringUtil.substringAfter(x, "="); + if (WHITE_LIST.contains(key)) { + configs.put(key, value); + } + }); + } catch (IOException e) { + LOG.error(e.getMessage(), e); + return fail(e); + } + + return success(new Output(configs)); + } + + public class Output { + public Map configs; + + public Output(Map configs) { + this.configs = configs; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GetGlobalConfigApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GetGlobalConfigApi.java index e76bee00e..0ad37f5bf 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GetGlobalConfigApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GetGlobalConfigApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,7 +17,7 @@ package com.welab.wefe.board.service.api.global_config; import com.alibaba.fastjson.JSONObject; -import com.welab.wefe.board.service.database.entity.GlobalConfigMySqlModel; +import com.welab.wefe.board.service.database.entity.GlobalConfigMysqlModel; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; @@ -46,7 +46,7 @@ protected ApiResult> handle(Input input) throws StatusCo Map output = new HashMap<>(); for (String group : input.groups) { - List list = globalConfigService.list(group); + List list = globalConfigService.list(group); JSONObject json = new JSONObject(); list.forEach(x -> json.put(x.getName(), x.getValue())); output.put(group, json); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GlobalConfigUpdateApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GlobalConfigUpdateApi.java index 550138e53..3846e49fe 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GlobalConfigUpdateApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/GlobalConfigUpdateApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,6 +18,7 @@ import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; @@ -42,6 +43,7 @@ protected ApiResult handler(Input input) throws StatusCodeWithException { } public static class Input extends AbstractApiInput { + @Check(name = "配置项组", require = true) public Map> groups; } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/test/config_properties.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/test/config_properties.http new file mode 100644 index 000000000..61c1c5091 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/test/config_properties.http @@ -0,0 +1,8 @@ + +### +POST http://localhost:8080/board-service/global_config/config_properties +Content-Type: application/json +token:{{token}} + +{ +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/test/update.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/test/update.http new file mode 100644 index 000000000..760b359fb --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/global_config/test/update.http @@ -0,0 +1,8 @@ + +### +POST http://localhost:8080/board-service/global_config/update 
+Content-Type: application/json +token:{{token}} + +{ +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/login.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/login.http new file mode 100644 index 000000000..1b51be144 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/login.http @@ -0,0 +1,16 @@ + +### 登录 +POST http://localhost:8080/board-service/account/login +Content-Type: application/json + +{ + "phone_number": "13100000001", + "password": "password", + "code": "code", + "key": "key" +} + +> {% +client.global.set("token", response.body.data.token); +%} + diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/CheckMemberRouteConnectApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/CheckMemberRouteConnectApi.java index 32c68faa4..15d31944a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/CheckMemberRouteConnectApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/CheckMemberRouteConnectApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/GetMemberMachineLearningEnvApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/GetMemberMachineLearningEnvApi.java new file mode 100644 index 000000000..a379bfb0a --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/GetMemberMachineLearningEnvApi.java @@ -0,0 +1,35 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.member; + +import com.welab.wefe.board.service.dto.kernel.machine_learning.Env; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractNoneInputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; + +/** + * @author Zane + */ +@Api(path = "member/env/machine_learning", name = "get member machine learning Env detail") +public class GetMemberMachineLearningEnvApi extends AbstractNoneInputApi { + + @Override + protected ApiResult handle() throws StatusCodeWithException { + return success(Env.get()); + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/InitializeApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/InitializeApi.java index 88eedf42a..28582b22a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/InitializeApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/InitializeApi.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -17,6 +17,7 @@ package com.welab.wefe.board.service.api.member; import com.welab.wefe.board.service.service.SystemInitializeService; +import com.welab.wefe.common.constant.SecretKeyType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.StandardFieldType; import com.welab.wefe.common.fieldvalidate.annotation.Check; @@ -65,6 +66,9 @@ public static class Input extends AbstractApiInput { @Check(name = "是否允许对外公开数据集基础信息", require = true) private Boolean memberAllowPublicDataSet; + @Check(name = "密钥类型") + private SecretKeyType secretKeyType = SecretKeyType.rsa; + //region getter/setter public String getMemberName() { @@ -99,6 +103,13 @@ public void setMemberAllowPublicDataSet(Boolean memberAllowPublicDataSet) { this.memberAllowPublicDataSet = memberAllowPublicDataSet; } + public SecretKeyType getSecretKeyType() { + return secretKeyType; + } + + public void setSecretKeyType(SecretKeyType secretKeyType) { + this.secretKeyType = secretKeyType; + } //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/IsInitializedApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/IsInitializedApi.java index e2dc74dd6..c4c778574 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/IsInitializedApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/IsInitializedApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/MemberAvailableCheckApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/MemberAvailableCheckApi.java new file mode 100644 index 000000000..300e97a59 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/MemberAvailableCheckApi.java @@ -0,0 +1,60 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.member; + +import com.welab.wefe.board.service.service.ServiceCheckService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.checkpoint.dto.MemberAvailableCheckOutput; +import com.welab.wefe.common.wefe.checkpoint.dto.ServiceAvailableCheckOutput; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author lonnie + */ +@Api(path = "member/available", name = "Check whether the member’s system services are available") +public class MemberAvailableCheckApi extends AbstractApi { + + @Autowired + private ServiceCheckService serviceCheckService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + MemberAvailableCheckOutput output = serviceCheckService.getMemberAvailableInfo(input.memberId); + if (input.fromGateway()) { + output.details.values().forEach(ServiceAvailableCheckOutput::cleanValues); + } + return success(output); + } + + public static class Input extends AbstractApiInput { + @Check(name = "成员id", require = true) + public String memberId; + + public Input() { + } + + public Input(String memberId) { + this.memberId = memberId; + } + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/MemberDetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/MemberDetailApi.java index fef630a3b..d5e8bec4c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/MemberDetailApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/MemberDetailApi.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance 
with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/ResetRsaKeyApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/ResetRsaKeyApi.java index 5fa3e304b..c19cfe165 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/ResetRsaKeyApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/ResetRsaKeyApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/ServiceStatusCheckApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/ServiceStatusCheckApi.java deleted file mode 100644 index 17779e9d8..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/ServiceStatusCheckApi.java +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.member; - -import com.welab.wefe.board.service.dto.vo.MemberServiceStatusOutput; -import com.welab.wefe.board.service.service.ServiceCheckService; -import com.welab.wefe.common.enums.MemberService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -import java.util.Map; - -/** - * @author lonnie - */ -@Api(path = "/member/service_status_check", name = "Check whether the member’s system services are normal", login = false) -public class ServiceStatusCheckApi extends AbstractApi { - - @Autowired - private ServiceCheckService serviceCheckService; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - - Output output = serviceCheckService.checkMemberServiceStatus(input); - - return success(output); - } - - public static class Input extends AbstractApiInput { - - @Check(name = "成员id", require = true) - private String memberId; - - private MemberService service; - - public Input() { - } - - public Input(String memberId) { - this.memberId = memberId; - } - - public String getMemberId() { - return memberId; - } - - public void setMemberId(String memberId) { - this.memberId = memberId; - } - - public MemberService getService() { 
- return service; - } - - public void setService(MemberService service) { - this.service = service; - } - } - - public static class Output { - private boolean allStatusIsSuccess; - private Map status; - - public Output() { - } - - public Output(Map status) { - this.allStatusIsSuccess = status - .values() - .stream() - .allMatch(x -> x.isSuccess()); - - this.status = status; - } - - - public boolean isAllStatusIsSuccess() { - return allStatusIsSuccess; - } - - public void setAllStatusIsSuccess(boolean allStatusIsSuccess) { - this.allStatusIsSuccess = allStatusIsSuccess; - } - - public Map getStatus() { - return status; - } - - public void setStatus(Map status) { - this.status = status; - } - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/SyncMemberToUnionApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/SyncMemberToUnionApi.java index 4ce017a67..bea5a1cac 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/SyncMemberToUnionApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/SyncMemberToUnionApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/UpdateMemberInfoApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/UpdateMemberInfoApi.java index 5e67095ff..7cae99e60 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/UpdateMemberInfoApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/UpdateMemberInfoApi.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -43,7 +43,6 @@ protected ApiResult handler(Input input) throws StatusCodeWithException { public static class Input extends InitializeApi.Input { - @Check( name = "成员logo", blockReactionaryKeyword = false, @@ -51,7 +50,6 @@ public static class Input extends InitializeApi.Input { blockXss = false ) private String memberLogo; - @Check(name = "成员隐身状态") private Boolean memberHidden; @@ -74,14 +72,6 @@ public void checkAndStandardize() throws StatusCodeWithException { //region getter/setter - public String getMemberLogo() { - return memberLogo; - } - - public void setMemberLogo(String memberLogo) { - this.memberLogo = memberLogo; - } - public Boolean getMemberHidden() { return memberHidden; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/UpdateMemberLogoApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/UpdateMemberLogoApi.java index 9f1293e31..542cf346b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/UpdateMemberLogoApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/UpdateMemberLogoApi.java @@ -5,7 +5,7 @@ * you may 
not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/Initialize.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/Initialize.http index 6685ab9d7..478ef3df3 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/Initialize.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/Initialize.http @@ -1,6 +1,6 @@ ### 登录 -POST {{baseUrl}}/account/login +POST http://localhost:8080/board-service/account/login Content-Type: application/json { @@ -20,7 +20,7 @@ client.test("Request executed successfully", function() { ### 初始化系统 -POST {{baseUrl}}/system/initialize +POST http://localhost:8080/board-service/system/initialize Content-Type: application/json token: {{token}} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/available.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/available.http new file mode 100644 index 000000000..4f984ed70 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/available.http @@ -0,0 +1,18 @@ +POST http://localhost:8080/board-service/member/available +Content-Type: application/json +token:{{token}} + +{ + "member_id": "290007c2a71d470ba00f486b18875d31" +} + + + +### 查别人 +POST http://localhost:8080/board-service/member/available +Content-Type: application/json +token:{{token}} + +{ + "member_id": "cbd2c82da50d4408877061e3e981f8ae" +} \ No newline at end of file diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/member-detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/member-detail.http index 2d3908838..887c1aaad 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/member-detail.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/member-detail.http @@ -1,5 +1,5 @@ ### 获取 member 信息 -GET {{baseUrl}}/member/detail +GET http://localhost:8080/board-service/member/detail Content-Type: application/json token: {{token}} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/member-update.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/member-update.http index 6add8dc4e..c4f0fd9bd 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/member-update.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/member-update.http @@ -1,6 +1,6 @@ ### 更新 member 信息 -POST {{baseUrl}}/member/update +POST http://localhost:8080/board-service/member/update Content-Type: application/json token: {{token}} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/service_status_check.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/service_status_check.http deleted file mode 100644 index c23dc84de..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/member/test/service_status_check.http +++ /dev/null @@ -1,16 +0,0 @@ -POST {{baseUrl}}/member/service_status_check -Content-Type: application/json - -{ - "member_id": "23bd0dc471514fd0b268fc2d0799cb2c" -} - - - -### 查别人 -POST {{baseUrl}}/member/service_status_check -Content-Type: application/json - -{ - "member_id": "601e3c3f150a4bfd9e9d6f2be9563b26" -} \ No newline at end of file diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/DetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/DetailApi.java index a9e2c9824..0f206368c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/DetailApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/DetailApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -23,12 +23,12 @@ import com.welab.wefe.board.service.database.entity.MessageMysqlModel; import com.welab.wefe.board.service.database.repository.MessageRepository; import com.welab.wefe.board.service.dto.entity.MessageOutputModel; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/QueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/QueryApi.java index 6feaffcb8..51d8c5224 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/QueryApi.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/QueryApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/ReadApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/ReadApi.java index 2280c3511..c81022b17 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/ReadApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/ReadApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-add.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-add.http index 970cf0fd4..a69a4c1a1 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-add.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-add.http @@ -1,5 +1,5 @@ ### 添加消息 -POST {{baseUrl}}/message/add +POST http://localhost:8080/board-service/message/add Content-Type: application/json { @@ -11,7 +11,7 @@ Content-Type: application/json ### 添加消息 -POST {{baseUrl}}/message/add +POST http://localhost:8080/board-service/message/add Content-Type: application/json { @@ -23,7 +23,7 @@ Content-Type: application/json ### 添加消息 -POST {{baseUrl}}/message/add +POST http://localhost:8080/board-service/message/add Content-Type: application/json { @@ -35,7 +35,7 @@ Content-Type: application/json ### 添加消息 -POST {{baseUrl}}/message/add +POST http://localhost:8080/board-service/message/add Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-detail.http index 6bfa89e6c..fa893d9f4 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-detail.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-detail.http @@ -1,6 +1,6 @@ ### 获取消息详情 -POST {{baseUrl}}/message/detail +POST http://localhost:8080/board-service/message/detail Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-query.http index 5b1c7bbe9..65f256bdb 100644 --- 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-query.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-query.http @@ -1,13 +1,13 @@ ### 查询全部消息 -POST {{baseUrl}}/message/query +POST http://localhost:8080/board-service/message/query Content-Type: application/json {} ### 查未读消息 -POST {{baseUrl}}/message/query +POST http://localhost:8080/board-service/message/query Content-Type: application/json { @@ -16,7 +16,7 @@ Content-Type: application/json ### 按 level 查 -POST {{baseUrl}}/message/query +POST http://localhost:8080/board-service/message/query Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-read.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-read.http index bf98cd5dd..1b670cdcc 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-read.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/message/test/message-read.http @@ -1,6 +1,6 @@ ### 将消息标记为已读 -POST {{baseUrl}}/message/read +POST http://localhost:8080/board-service/message/read Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadCallModelResultApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadCallModelResultApi.java new file mode 100644 index 000000000..d6f18ff6c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadCallModelResultApi.java @@ -0,0 +1,69 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.model.deep_learning; + +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.board.service.component.Components; +import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; +import com.welab.wefe.board.service.dto.entity.job.TaskResultOutputModel; +import com.welab.wefe.board.service.service.TaskService; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.FileUtil; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.ResponseEntity; + +import java.io.File; +import java.util.UUID; + +/** + * @author zane + * @date 2022/2/14 + */ +@Api(path = "model/deep_learning/call/result/download", name = "下载模型推理结果") +public class DownloadCallModelResultApi extends AbstractApi> { + + @Autowired + private TaskService taskService; + + @Override + protected ApiResult> handle(Input input) throws Exception { + TaskMySqlModel task = taskService.findOne(input.taskId); + if (task == null) { + StatusCode.PARAMETER_VALUE_INVALID.throwException("task 不存在:" + input.taskId); + } + + TaskResultOutputModel result = Components.get(task.getTaskType()).getTaskResult(task.getTaskId(), "infer"); + + File file = WeFeFileSystem + 
.getBaseDir(WeFeFileSystem.UseType.Temp) + .resolve(UUID.randomUUID() + ".json") + .toFile(); + + FileUtil.writeTextToFile(result.getResult().toJSONString(), file.toPath(), false); + + return file(file); + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + public String taskId; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadDataSetImageApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadDataSetImageApi.java new file mode 100644 index 000000000..fda4336c2 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadDataSetImageApi.java @@ -0,0 +1,58 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.api.model.deep_learning; + +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.http.ResponseEntity; + +import java.io.File; + +/** + * @author zane + * @date 2022/2/14 + */ +@Api(path = "model/deep_learning/call/download/image", name = "下载推理图片") +public class DownloadDataSetImageApi extends AbstractApi> { + + @Override + protected ApiResult> handle(Input input) throws Exception { + File file = WeFeFileSystem.CallDeepLearningModel + .getImageSimpleDir(input.taskId, input.inferSessionId) + .resolve(input.filename) + .toFile(); + + if (!file.exists()) { + StatusCode.FILE_DOES_NOT_EXIST.throwException("文件不存在"); + } + + return file(file); + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + public String taskId; + @Check(require = true) + public String inferSessionId; + @Check(require = true) + public String filename; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadDataSetZipApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadDataSetZipApi.java new file mode 100644 index 000000000..f9ef0fbdf --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadDataSetZipApi.java @@ -0,0 +1,47 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.model.deep_learning; + +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.http.ResponseEntity; + +import java.io.File; + +/** + * @author zane + * @date 2022/2/14 + */ +@Api(path = "model/deep_learning/call/download/zip", name = "下载需要批量推理的zip文件", login = false) +public class DownloadDataSetZipApi extends AbstractApi> { + + @Override + protected ApiResult> handle(Input input) throws Exception { + File zipFile = WeFeFileSystem.CallDeepLearningModel.getZipFile(input.taskId, input.inferSessionId); + return file(zipFile); + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + public String taskId; + @Check(require = true) + public String inferSessionId; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadModelApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadModelApi.java new file mode 100644 index 000000000..98bed4be2 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/DownloadModelApi.java @@ -0,0 +1,155 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.model.deep_learning; + +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; +import com.welab.wefe.board.service.dto.globalconfig.DeepLearningConfigModel; +import com.welab.wefe.board.service.service.TaskService; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.TimeSpan; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.ResponseEntity; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.math.BigDecimal; +import 
java.math.RoundingMode; + +/** + * @author zane + * @date 2022/2/14 + */ +@Api(path = "model/deep_learning/download", name = "下载模型") +public class DownloadModelApi extends AbstractApi> { + + @Autowired + private TaskService taskService; + + @Autowired + private GlobalConfigService globalConfigService; + + @Override + protected ApiResult> handle(Input input) throws Exception { + TaskMySqlModel task = taskService.findOne(input.taskId); + DeepLearningConfigModel deepLearningConfig = globalConfigService.getDeepLearningConfig(); + if (deepLearningConfig == null || StringUtil.isEmpty(deepLearningConfig.paddleVisualDlBaseUrl)) { + StatusCode.RPC_ERROR.throwException("尚未设置VisualFL服务地址,请在[全局设置][计算引擎设置]中设置VisualFL服务地址。"); + } + + String url = deepLearningConfig.paddleVisualDlBaseUrl + "/serving_model/download?task_id=" + task.getTaskId() + "&job_id=" + task.getJobId(); + + File file = WeFeFileSystem.CallDeepLearningModel.getModelFile(input.taskId); + try { + long start = System.currentTimeMillis(); + download(url, file); + + LOG.info("从VisualFL下载模型耗时:" + TimeSpan.fromMs(System.currentTimeMillis() - start) + " taskId:" + input.taskId); + } catch (StatusCodeWithException e) { + LOG.error(e.getClass().getSimpleName() + " " + e.getMessage(), e); + throw e; + } catch (Exception e) { + LOG.error("下载模型失败:" + e.getMessage(), e); + StatusCode.RPC_ERROR.throwException("下载模型失败:" + e.getMessage()); + } + + return file(file); + } + + private void download(String url, File file) throws IOException, StatusCodeWithException { + // 创建Http请求配置参数 + RequestConfig requestConfig = RequestConfig.custom() + // 获取连接超时时间 + .setConnectionRequestTimeout(10 * 1000) + // 请求超时时间 + .setConnectTimeout(10 * 1000) + // 响应超时时间 + .setSocketTimeout(10_000) + .build(); + CloseableHttpClient client = HttpClients.custom().setDefaultRequestConfig(requestConfig).build(); + HttpGet httpGet = new HttpGet(url); + try (CloseableHttpResponse response = client.execute(httpGet)) { + int code = 
response.getStatusLine().getStatusCode(); + if (code != 200) { + StatusCode.RPC_ERROR.throwException("下载模型失败(" + code + "):" + response.getStatusLine().getReasonPhrase()); + } + InputStream is = response.getEntity().getContent(); + if (file.exists()) { + file.delete(); + } + file.getParentFile().mkdirs(); + FileOutputStream fileout = new FileOutputStream(file); + /** + * 根据实际运行效果 设置缓冲区大小 + */ + byte[] buffer = new byte[10 * 1024]; + int ch = 0; + long downloadSize = 0; + while ((ch = is.read(buffer)) != -1) { + fileout.write(buffer, 0, ch); + downloadSize += ch; + if (downloadSize % 1024 == 0) { + LOG.info("模型下载进度:" + getSizeString(downloadSize)); + } + } + is.close(); + fileout.flush(); + fileout.close(); + LOG.info("模型下载完毕:" + getSizeString(downloadSize)); + } catch (Exception e) { + throw e; + } finally { + httpGet.releaseConnection(); + } + } + + private String getSizeString(long byteSize) { + if (byteSize < 1024) { + return byteSize + "byte"; + } + if (byteSize < 1024 * 1024) { + return BigDecimal.valueOf(byteSize) + .divide(BigDecimal.valueOf(1024), 2, RoundingMode.FLOOR) + "KB"; + } + return BigDecimal.valueOf(byteSize) + .divide(BigDecimal.valueOf(1024 * 1024), 2, RoundingMode.FLOOR) + "MB"; + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + public String taskId; + } + + @Override + public boolean canParallel() { + return false; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/README.md b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/README.md new file mode 100644 index 000000000..e34d24cf5 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/README.md @@ -0,0 +1,8 @@ +# 深度学习模型在线推理 +这里的 api 提供的是非生产环境的在线推理功能,如果需要将模型部署到生产环境,请将模型下载后在 serving 服务中导入,serving 服务是一个专门针对生产环境设计的应用。 + +推理流程: +1. 上传文件:file/upload +2. 开始预测:/model/deep_learning/call/start +3. 
获取预测结果:flow/job/task/detail +4. 下载原始图片: \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/StartCallModelApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/StartCallModelApi.java new file mode 100644 index 000000000..452ac8465 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/StartCallModelApi.java @@ -0,0 +1,116 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.api.model.deep_learning; + +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; +import com.welab.wefe.board.service.sdk.PaddleVisualService; +import com.welab.wefe.board.service.service.TaskService; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.FileUtil; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.File; +import java.util.UUID; + +/** + * @author zane + * @date 2022/2/14 + */ +@Api(path = "model/deep_learning/call/start", name = "调用深度学习模型") +public class StartCallModelApi extends AbstractApi { + + @Autowired + private TaskService taskService; + @Autowired + private PaddleVisualService paddleVisualService; + + @Override + protected ApiResult handle(StartCallModelApi.Input input) throws Exception { + File rawFile = WeFeFileSystem.CallDeepLearningModel.getRawFile(input.filename); + + TaskMySqlModel task = taskService.findOne(input.taskId); + if (task == null) { + rawFile.delete(); + StatusCode.PARAMETER_VALUE_INVALID.throwException("此task不存在:" + input.taskId); + } + + // 考虑到文件名冲突、并发问题,这里使用UUID作为本次预测的标识号。 + String inferSessionId = UUID.randomUUID().toString().replace("-", ""); + + // 如果是单张图片 + if (FileUtil.isImage(rawFile)) { + WeFeFileSystem.CallDeepLearningModel.moveSingleImageToSessionDir(rawFile, input.taskId, inferSessionId); + } else { + 
WeFeFileSystem.CallDeepLearningModel.moveZipFileToSessionDir(rawFile, input.taskId, inferSessionId); + } + + File zipFile = WeFeFileSystem.CallDeepLearningModel.zipImageSimpleDir(input.taskId, inferSessionId); + + // 调用VisualFL开始推理 + JObject dataSetInfo = JObject.create(); + dataSetInfo.put("download_url", buildZipDownloadUrl(input.taskId, inferSessionId)); + dataSetInfo.put("name", zipFile.getName()); + dataSetInfo.put("infer_session_id", inferSessionId); + + JSONObject json = JSON.parseObject(task.getTaskConf()); + json.put("data_set", dataSetInfo); + + JObject response = paddleVisualService.infer(json); + + return success(new Output(inferSessionId, response)); + } + + private String buildZipDownloadUrl(String taskId, String inferSessionId) { + Api annotation = DownloadDataSetZipApi.class.getAnnotation(Api.class); + + return Launcher.getBean(GlobalConfigService.class) + .getBoardConfig() + .intranetBaseUri + + "/" + + annotation.path() + + "?taskId=" + taskId + + "&inferSessionId=" + inferSessionId; + } + + public static class Output { + public String inferSessionId; + public JObject response; + + public Output(String inferSessionId, JObject response) { + this.inferSessionId = inferSessionId; + this.response = response; + } + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + public String taskId; + + @Check(require = true, messageOnEmpty = "请指定数据集文件") + public String filename; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/test/download-model.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/test/download-model.http new file mode 100644 index 000000000..78faed9fc --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/test/download-model.http @@ -0,0 +1,7 @@ +POST http://localhost:8080/board-service/model/deep_learning/download +Content-Type: application/json +token: 
{{token}} + +{ + "task_id": "c825d03e1e424b51bf99fe61d94ec947_promoter_PaddleClassify_16452527021516948" +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/test/start.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/test/start.http new file mode 100644 index 000000000..24aa95014 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/model/deep_learning/test/start.http @@ -0,0 +1,8 @@ +POST http://localhost:8080/board-service/model/deep_learning/call/start +Content-Type: application/json +token: {{token}} + +{ + "taskId": "ed94d6b25f344ad3849bf42b5d86ff89_promoter_PaddleClassify_16487121262871114", + "filename": "fruit.zip" +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/CheckAccountExistApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/CheckAccountExistApi.java index 26754abcb..42ec56b85 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/CheckAccountExistApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/CheckAccountExistApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/CreateOnlineDemoAccountApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/CreateOnlineDemoAccountApi.java index f05996c86..4fd6bd312 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/CreateOnlineDemoAccountApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/CreateOnlineDemoAccountApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -19,11 +19,11 @@ import com.welab.wefe.board.service.base.OnlineDemoApi; import com.welab.wefe.board.service.dto.vo.AccountInputModel; import com.welab.wefe.board.service.service.account.AccountService; -import com.welab.wefe.common.enums.BoardUserSource; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.BoardUserSource; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/TianmiantechCallApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/TianmiantechCallApi.java index 59f53aea6..998bf6ac0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/TianmiantechCallApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/TianmiantechCallApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/TianmiantechPageApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/TianmiantechPageApi.java index a2ef61a4b..2bc9865d3 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/TianmiantechPageApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/TianmiantechPageApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/account-create.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/account-create.http index 34dbc4b84..12f454a5c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/account-create.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/account-create.http @@ -1,6 +1,6 @@ ### 正常注册 -POST {{baseUrl}}/account/online_demo/create +POST http://localhost:8080/board-service/account/online_demo/create Content-Type: application/json { @@ -12,7 +12,7 @@ Content-Type: application/json ### 手机号冲突 -POST {{baseUrl}}/account/online_demo/create +POST http://localhost:8080/board-service/account/online_demo/create Content-Type: application/json { @@ -25,7 +25,7 @@ Content-Type: application/json ### 手机号错误 -POST {{baseUrl}}/account/online_demo/create +POST http://localhost:8080/board-service/account/online_demo/create Content-Type: application/json { @@ -37,7 +37,7 @@ Content-Type: application/json ### 邮箱错误 -POST {{baseUrl}}/account/online_demo/create +POST http://localhost:8080/board-service/account/online_demo/create Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/account-exist.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/account-exist.http index 42e1c7314..f0fb3cd93 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/account-exist.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/account-exist.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/account/online_demo/exist +POST http://localhost:8080/board-service/account/online_demo/exist Content-Type: application/json { diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/tianmiantech-call.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/tianmiantech-call.http index 87d79557c..c361b213b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/tianmiantech-call.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/tianmiantech-call.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/tianmiantech/call_api +POST http://localhost:8080/board-service/tianmiantech/call_api Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/tianmiantech-page.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/tianmiantech-page.http index ddd442a5b..6505fc7dc 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/tianmiantech-page.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/online_demo/test/tianmiantech-page.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/tianmiantech/page_url +POST http://localhost:8080/board-service/tianmiantech/page_url Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/operation/LogQueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/operation/LogQueryApi.java new file mode 100644 index 000000000..4c7667560 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/operation/LogQueryApi.java @@ -0,0 +1,68 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.operation; + +import com.welab.wefe.board.service.dto.base.PagingInput; +import com.welab.wefe.board.service.dto.base.PagingOutput; +import com.welab.wefe.board.service.dto.entity.OperationLogOutputModel; +import com.welab.wefe.board.service.service.OperationLogService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author eval + **/ +@Api(path = "log/query", name = "query log") +public class LogQueryApi extends AbstractApi> { + + @Autowired + OperationLogService service; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException { + return success(service.query(input)); + } + + public static class Input extends PagingInput { + @Check(name = "请求接口") + public String logInterface; + @Check(name = "操作人员Id") + public String operatorId; + private Long startTime; + private Long endTime; + + public Long getStartTime() { + return startTime; + } + + public void setStartTime(Long startTime) { + this.startTime = startTime; + } + + public Long getEndTime() { + return endTime; + } + + public void setEndTime(Long endTime) { + this.endTime = endTime; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/operation/QueryApi.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/operation/QueryApi.java deleted file mode 100644 index f0ef748c3..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/operation/QueryApi.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.operation; - -import com.welab.wefe.board.service.dto.base.PagingInput; -import com.welab.wefe.board.service.dto.base.PagingOutput; -import com.welab.wefe.board.service.dto.entity.OperationLogOutputModel; -import com.welab.wefe.board.service.service.OperationLogService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author eval - **/ -@Api(path = "log/query", name = "query log") -public class QueryApi extends AbstractApi> { - - @Autowired - OperationLogService service; - - @Override - protected ApiResult> handle(Input input) throws StatusCodeWithException { - return success(service.query(input)); - } - - public static class Input extends PagingInput { - private String action; - private String operatorPhone; - private Long startTime; - private Long endTime; - - public String getAction() { - 
return action; - } - - public void setAction(String action) { - this.action = action; - } - - public String getOperatorPhone() { - return operatorPhone; - } - - public void setOperatorPhone(String operatorPhone) { - this.operatorPhone = operatorPhone; - } - - public Long getStartTime() { - return startTime; - } - - public void setStartTime(Long startTime) { - this.startTime = startTime; - } - - public Long getEndTime() { - return endTime; - } - - public void setEndTime(Long endTime) { - this.endTime = endTime; - } - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/operation/test/operation_log-query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/operation/test/operation_log-query.http index 83c2cd66a..ca60d065c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/operation/test/operation_log-query.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/operation/test/operation_log-query.http @@ -1,10 +1,10 @@ ### 查询用户操作日志 -POST {{baseUrl}}/log/query +POST http://localhost:8080/board-service/log/query Content-Type: application/json token: {{token}} { - "action":"model_result", + "action": "model_result", "operator_phone": "15914412294", "start_time": 1597567046000, "end_time": 1597627277000 diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/AddDataSetApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/AddDataSetApi.java index 2c04156fa..6db5c1490 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/AddDataSetApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/AddDataSetApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -31,7 +31,7 @@ /** * @author zan.luo */ -@Api(path = "project/data_set/add", name = "add data set to project") +@Api(path = "project/data_resource/add", name = "add data set to project") public class AddDataSetApi extends AbstractNoneOutputApi { @Autowired @@ -50,7 +50,9 @@ public static class Input extends AbstractApiInput { private String projectId; @Check(name = "数据集列表", require = true) - private List dataSetList; + private List dataResourceList; + + // region getter/setter public String getProjectId() { return projectId; @@ -60,13 +62,16 @@ public void setProjectId(String projectId) { this.projectId = projectId; } - public List getDataSetList() { - return dataSetList; + public List getDataResourceList() { + return dataResourceList; } - public void setDataSetList(List dataSetList) { - this.dataSetList = dataSetList; + public void setDataResourceList(List dataResourceList) { + this.dataResourceList = dataResourceList; } + + // endregion + } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/AuditDataSetApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/AuditDataSetApi.java index 6d81314c8..d3af0dae4 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/AuditDataSetApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/AuditDataSetApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian 
Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,7 +18,6 @@ import com.welab.wefe.board.service.service.ProjectDataSetAuditService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.AuditStatus; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.StringUtil; @@ -26,12 +25,13 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.AuditStatus; import org.springframework.beans.factory.annotation.Autowired; /** * @author zane.luo */ -@Api(path = "project/data_set/audit", name = "audit the data set authorization application in the project") +@Api(path = "project/data_resource/audit", name = "audit the data set authorization application in the project") public class AuditDataSetApi extends AbstractNoneOutputApi { @Autowired diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/ListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/ListApi.java index a682c2f2b..75ce5b578 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/ListApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/ListApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. 
All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,7 +16,7 @@ package com.welab.wefe.board.service.api.project.dataset; -import com.welab.wefe.board.service.dto.entity.project.ProjectDataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.project.data_set.ProjectDataResourceOutputModel; import com.welab.wefe.board.service.service.ProjectDataSetService; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; @@ -24,6 +24,7 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DataResourceType; import org.springframework.beans.factory.annotation.Autowired; import java.util.List; @@ -31,7 +32,7 @@ /** * @author zane.luo */ -@Api(path = "project/data_set/list", name = "list all of the project data sets") +@Api(path = "project/data_resource/list", name = "list all of the project data sets") public class ListApi extends AbstractApi { @Autowired @@ -39,7 +40,7 @@ public class ListApi extends AbstractApi { @Override protected ApiResult handle(Input input) throws StatusCodeWithException { - List list = projectDataSetService.list(input.projectId, input.memberId); + List list = projectDataSetService.list(input.projectId, input.dataResourceType, input.memberId); return success(new Output(list)); } @@ -47,6 +48,9 @@ public static class Input extends AbstractApiInput { @Check(name = 
"项目Id", require = true) private String projectId; + @Check(name = "数据集类型") + private DataResourceType dataResourceType; + @Check(name = "成员Id", desc = "当此参数为空时,返回项目中所有数据集") private String memberId; @@ -61,6 +65,15 @@ public void setProjectId(String projectId) { this.projectId = projectId; } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + public String getMemberId() { return memberId; } @@ -73,17 +86,17 @@ public void setMemberId(String memberId) { } public static class Output { - private List list; + private List list; - public Output(List list) { + public Output(List list) { this.list = list; } - public List getList() { + public List getList() { return list; } - public void setList(List list) { + public void setList(List list) { this.list = list; } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/QueryDerivedDataSetApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/QueryDerivedDataSetApi.java index 4c756003d..e9413b467 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/QueryDerivedDataSetApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/QueryDerivedDataSetApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,14 +18,15 @@ import com.welab.wefe.board.service.dto.base.PagingInput; import com.welab.wefe.board.service.dto.base.PagingOutput; -import com.welab.wefe.board.service.dto.entity.project.DerivedProjectDataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.project.data_set.DerivedProjectDataSetOutputModel; import com.welab.wefe.board.service.service.ProjectDataSetService; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.DataResourceType; import org.springframework.beans.factory.annotation.Autowired; /** @@ -46,6 +47,9 @@ public static class Input extends PagingInput { @Check(name = "项目Id", require = true) private String projectId; + @Check(name = "数据集类型", require = true) + private DataResourceType dataResourceType; + @Check(name = "来源") private ComponentType sourceType; @@ -69,6 +73,14 @@ public void setProjectId(String projectId) { this.projectId = projectId; } + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + public ComponentType getSourceType() { return sourceType; } diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/RawDataSetListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/RawDataSetListApi.java index a07fb14e4..eeef8bd88 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/RawDataSetListApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/RawDataSetListApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,15 +16,17 @@ package com.welab.wefe.board.service.api.project.dataset; -import com.welab.wefe.board.service.dto.entity.project.ProjectDataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.project.data_set.ProjectDataResourceOutputModel; import com.welab.wefe.board.service.service.ProjectDataSetService; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DeepLearningJobType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; import java.util.List; @@ -40,7 +42,7 @@ public class RawDataSetListApi extends AbstractApi handle(Input input) throws StatusCodeWithException { - List list = projectDataSetService.listRawDataSet(input.projectId, input.memberId, input.memberRole, input.containsY); + List list = projectDataSetService.listRawDataSet(input.projectId, input.dataResourceType, input.memberId, input.memberRole, input.containsY, input.forJobType); return success(new Output(list)); } @@ -51,12 +53,18 @@ public static class Input extends AbstractApiInput { @Check(name = "成员Id", require = true, desc = "当此参数为空时,返回项目中所有数据集") private String memberId; + @Check(name = "数据集类型", require = true) + private DataResourceType dataResourceType; + @Check(name = "成员角色", require = true) private JobMemberRole memberRole; @Check(name = "是否包含Y") private Boolean containsY; + @Check(name = "目标任务类型") + private DeepLearningJobType forJobType; + //region getter/setter public String getProjectId() { @@ -75,6 +83,14 @@ public void setMemberId(String memberId) { this.memberId = memberId; } + 
public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + public JobMemberRole getMemberRole() { return memberRole; } @@ -91,21 +107,29 @@ public void setContainsY(Boolean containsY) { this.containsY = containsY; } + public DeepLearningJobType getForJobType() { + return forJobType; + } + + public void setForJobType(DeepLearningJobType forJobType) { + this.forJobType = forJobType; + } + //endregion } public static class Output { - private List list; + private List list; - public Output(List list) { + public Output(List list) { this.list = list; } - public List getList() { + public List getList() { return list; } - public void setList(List list) { + public void setList(List list) { this.list = list; } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/RemoveDataSetApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/RemoveDataSetApi.java index 8d463ea8a..e2bf521d8 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/RemoveDataSetApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/RemoveDataSetApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,19 +17,19 @@ package com.welab.wefe.board.service.api.project.dataset; import com.welab.wefe.board.service.service.ProjectService; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; /** * @author zane.luo */ -@Api(path = "project/data_set/remove", name = "Delete the data set in the project") +@Api(path = "project/data_resource/remove", name = "Delete the data set in the project") public class RemoveDataSetApi extends AbstractNoneOutputApi { @Autowired diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/add.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/add.http index d3c2b6d53..3d27b6d8c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/add.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/add.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/data_set/add +POST http://localhost:8080/board-service/project/data_set/add Content-Type: application/json { @@ -10,4 +10,20 @@ Content-Type: application/json "data_set_id": "bc7a824f64da488ab4e849504f3cd270" } ] +} + +### + +POST http://localhost:8080/board-service/project/data_set/add +Content-Type: application/json + +{ + "project_id": "76c0bab0068f4ff9989592b3185e6ea9", + "dataSetList": [ + { + "member_role": "provider", + "member_id": "cbd2c82da50d4408877061e3e981f8ae", + "data_set_type": "ImageDataSet" + } + ] } \ No newline at end of file diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/audit.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/audit.http index 9f93f7bbf..2b3845149 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/audit.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/audit.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/data_set/audit +POST http://localhost:8080/board-service/project/data_set/audit Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/derived_data_set_list.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/derived_data_set_list.http index 9a7cba980..5438f404c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/derived_data_set_list.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/derived_data_set_list.http @@ -1,6 +1,9 @@ -POST {{baseUrl}}/project/derived_data_set/query +POST http://localhost:8080/board-service/project/derived_data_set/query Content-Type: application/json +token: {{token}} { - "projectId": "f6920e302e724da78f488883dfde5bd5" + "projectId": "1028fefa736b425eadfda5705cde5504", + "data_resource_type": "TableDataSet", + "data_set_id": "c22a8fc4f7571d3ff7f3b6454de04cc8" } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/list.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/list.http index 6ce03240d..ec3d66f4e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/list.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/list.http @@ -1,4 +1,4 @@ -POST 
{{baseUrl}}/project/data_set/list +POST http://localhost:8080/board-service/project/data_set/list Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/raw_data_set_list.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/raw_data_set_list.http index d36935415..ca0999577 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/raw_data_set_list.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/dataset/test/raw_data_set_list.http @@ -1,9 +1,19 @@ -POST {{baseUrl}}/project/raw_data_set/list +POST http://localhost:8080/board-service/project/raw_data_set/list Content-Type: application/json { - "projectId": "a76c2b3024da49e4a79b3701495e9cd3", - "memberRole": "promoter", - "memberId": "d3c9199e15154d9eac22690a55abc0f4", - "containsY": true + "allList": "", + "contains_y": "", + "dataResourceType": [ + "ImageDataSet" + ], + "data_resource_type": "ImageDataSet", + "data_set_id": "", + "list": "", + "member_id": "290007c2a71d470ba00f486b18875d31", + "member_role": "promoter", + "name": "", + "page_index": 0, + "page_size": 20, + "project_id": "35ef81d771e24bd09c77a432652061b3" } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/AddFlowApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/AddFlowApi.java index 4a5d220e1..ea90abd42 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/AddFlowApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/AddFlowApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,13 +17,14 @@ package com.welab.wefe.board.service.api.project.flow; import com.welab.wefe.board.service.service.ProjectFlowService; -import com.welab.wefe.common.enums.FederatedLearningType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DeepLearningJobType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; import org.springframework.beans.factory.annotation.Autowired; /** @@ -47,6 +48,8 @@ public static class Input extends AbstractApiInput { @Check(name = "联邦类型(横向/纵向)", require = true) private FederatedLearningType federatedLearningType; + @Check(name = "深度学习任务类型(目标检测、图像分类)") + private DeepLearningJobType deepLearningJobType; @Check(name = "流程名", require = true) private String name; @@ -57,7 +60,7 @@ public static class Input extends AbstractApiInput { @Check(name = "模板Id", desc = "如果是基于模板创建流程,则指定模板Id") private String templateId; - @Check(name = "流程Id", hiddenForFrontEnd = true) + @Check(name = "流程Id", donotShow = true) private String flowId; /** * is oot model @@ -75,6 +78,14 @@ public void setFederatedLearningType(FederatedLearningType federatedLearningType this.federatedLearningType = federatedLearningType; } + public DeepLearningJobType getDeepLearningJobType() { + return deepLearningJobType; + } + + public void setDeepLearningJobType(DeepLearningJobType 
deepLearningJobType) { + this.deepLearningJobType = deepLearningJobType; + } + public String getProjectId() { return projectId; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/AddOotFlowApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/AddOotFlowApi.java index b1f827bc4..60fb8daea 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/AddOotFlowApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/AddOotFlowApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/CopyFlowApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/CopyFlowApi.java index e6cbcaa68..8906494c7 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/CopyFlowApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/CopyFlowApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/DeleteApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/DeleteApi.java index 30a52fb6a..815292141 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/DeleteApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/DeleteApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/DetailFlowApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/DetailFlowApi.java index e7b53c18a..f54953872 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/DetailFlowApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/DetailFlowApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -23,13 +23,13 @@ import com.welab.wefe.board.service.service.ModelOotRecordService; import com.welab.wefe.board.service.service.ProjectFlowService; import com.welab.wefe.board.service.service.ProjectService; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; import org.springframework.beans.factory.annotation.Autowired; /** @@ -56,7 +56,7 @@ protected ApiResult handle(Input input) throws Sta ProjectFlowDetailOutputModel output = ModelMapper.map(flow, ProjectFlowDetailOutputModel.class); output.setProject(projectService.detail(flow.getProjectId())); output.setParamsIsNullFlowNodes(projectFlowService.getParamsIsNullFlowNodes(input.flowId)); - output.setIsCreator(CacheObjects.isCurrentMember(flow.getCreatedBy())); + output.setIsCreator(CacheObjects.isCurrentMemberAccount(flow.getCreatedBy())); // OOT model ModelOotRecordMysqlModel modelOotRecordMysqlModel = modelOotRecordService.findByFlowId(input.flowId); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/FlowDataSetInfoApi.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/FlowDataSetInfoApi.java index 4a33503ba..81fb9daf7 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/FlowDataSetInfoApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/FlowDataSetInfoApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,7 +16,7 @@ package com.welab.wefe.board.service.api.project.flow; -import com.welab.wefe.board.service.dto.kernel.JobDataSet; +import com.welab.wefe.board.service.dto.kernel.machine_learning.JobDataSet; import com.welab.wefe.board.service.service.ProjectFlowNodeService; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/FlowQueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/FlowQueryApi.java new file mode 100644 index 000000000..f282f015a --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/FlowQueryApi.java @@ -0,0 +1,80 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.project.flow; + +import com.welab.wefe.board.service.dto.base.PagingInput; +import com.welab.wefe.board.service.dto.base.PagingOutput; +import com.welab.wefe.board.service.dto.entity.project.ProjectFlowListOutputModel; +import com.welab.wefe.board.service.service.ProjectFlowService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.List; + +/** + * @author winter.zou + */ +@Api(path = "project/flow/query", name = "query flow list") +public class FlowQueryApi extends AbstractApi> { + + @Autowired + ProjectFlowService flowService; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException { + return success(flowService.query(input)); + } + + public static class Input extends PagingInput { + + @Check(name = "是否已被删除") + private boolean deleted = false; + @Check(name = "项目ID 主键") + private String projectId; + @Check(name = "flow id 列表") + private List flowIdList; + + + public boolean isDeleted() { + return deleted; + } + + public void setDeleted(boolean deleted) { + this.deleted = deleted; + } + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public List getFlowIdList() { + 
return flowIdList; + } + + public void setFlowIdList(List flowIdList) { + this.flowIdList = flowIdList; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/GetProgressApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/GetProgressApi.java index 9129833ce..bb9c197bf 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/GetProgressApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/GetProgressApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/ListFlowNodeApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/ListFlowNodeApi.java new file mode 100644 index 000000000..b58689866 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/ListFlowNodeApi.java @@ -0,0 +1,81 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.project.flow; + +import com.welab.wefe.board.service.dto.entity.job.ProjectFlowNodeOutputModel; +import com.welab.wefe.board.service.service.ProjectFlowService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.AbstractApiOutput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; +import java.util.List; + +/** + * @author aaron.li + **/ +@Api(path = "project/flow_node/list", name = "query flow node list by flow id", login = false) +public class ListFlowNodeApi extends AbstractApi { + + @Autowired + private ProjectFlowService projectFlowService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { + List projectFlowNodeOutputModelList = projectFlowService.getFlowNodes(input.flowId); + Output output = new Output(); + output.setList(projectFlowNodeOutputModelList); + return success(output); + } + + public static class Input extends AbstractApiInput { + public Input(String flowId) { + this.flowId = flowId; + } + + @Check(name = "flow id", require = true) + private String flowId; + + public String getFlowId() { + return flowId; + } + + public void setFlowId(String flowId) { + this.flowId = 
flowId; + } + + } + + + public static class Output extends AbstractApiOutput { + + private List list; + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryDataIoTaskConfigApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryDataIoTaskConfigApi.java index ec45bb627..25723eb48 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryDataIoTaskConfigApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryDataIoTaskConfigApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,7 +17,6 @@ package com.welab.wefe.board.service.api.project.flow; import com.welab.wefe.board.service.service.TaskService; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; @@ -25,6 +24,7 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; import java.io.IOException; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryDataIoTaskFeaturesApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryDataIoTaskFeaturesApi.java index e51ea85d6..46366a893 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryDataIoTaskFeaturesApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryDataIoTaskFeaturesApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -19,6 +19,7 @@ import com.welab.wefe.board.service.dto.entity.DataIoTaskFeatureInfoOutputModel; import com.welab.wefe.board.service.service.TaskService; import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; @@ -50,17 +51,11 @@ protected ApiResult handle(Input input) throws StatusCodeWithException, } public static class Input extends AbstractApiInput { - /** - * Process ID (non OOT mode) - */ - private String flowId; - /** - * Job ID (OOT mode) - */ + @Check(name = "Process ID (non OOT mode)") + protected String flowId; + @Check(name = "Job ID (OOT mode)") private String jobId; - /** - * The member ID to query. If it is blank, it means to query all members under the jobid - */ + @Check(desc = "The member ID to query. If it is blank, it means to query all members under the jobid") private String memberId; public String getJobId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryFlowListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryFlowListApi.java deleted file mode 100644 index 5ad6345cd..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryFlowListApi.java +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.project.flow; - -import com.welab.wefe.board.service.dto.base.PagingInput; -import com.welab.wefe.board.service.dto.base.PagingOutput; -import com.welab.wefe.board.service.dto.entity.project.ProjectFlowListOutputModel; -import com.welab.wefe.board.service.service.ProjectFlowService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author winter.zou - */ -@Api(path = "project/flow/query", name = "query flow list") -public class QueryFlowListApi extends AbstractApi> { - - @Autowired - ProjectFlowService flowService; - - @Override - protected ApiResult> handle(Input input) throws StatusCodeWithException { - return success(flowService.query(input)); - } - - public static class Input extends PagingInput { - - @Check(name = "是否已被删除") - private boolean deleted = false; - @Check(name = "项目ID 主键") - private String projectId; - - - public boolean isDeleted() { - return deleted; - } - - public void setDeleted(boolean deleted) { - this.deleted = deleted; - } - - public String getProjectId() { - return projectId; - } - - public void setProjectId(String projectId) { - this.projectId = projectId; - } - - } -} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryFlowNodeListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryFlowNodeListApi.java deleted file mode 100644 index 987f1d3ab..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryFlowNodeListApi.java +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.api.project.flow; - -import com.welab.wefe.board.service.dto.entity.job.ProjectFlowNodeOutputModel; -import com.welab.wefe.board.service.service.ProjectFlowService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.AbstractApiOutput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -import java.io.IOException; -import java.util.List; - -/** - * @author aaron.li - **/ -@Api(path = "project/flow_node/query", name = "query flow node list by flow id", login = false) -public class QueryFlowNodeListApi extends AbstractApi { - - @Autowired - private ProjectFlowService projectFlowService; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { - List projectFlowNodeOutputModelList = projectFlowService.getFlowNodes(input.flowId); - Output output = new Output(); - output.setList(projectFlowNodeOutputModelList); - return success(output); - } - - public static class Input extends AbstractApiInput { - public Input(String flowId) { - this.flowId = flowId; - } - - @Check(name = "flow id", require = true) - private String flowId; - - public String getFlowId() { - return flowId; - } - - public void setFlowId(String flowId) { - this.flowId = flowId; - } - - } - - - public static class Output extends AbstractApiOutput { - - private List list; - - public List getList() { - return list; - } - - public void setList(List list) { - this.list = list; - } - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryFlowTemplateApi.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryFlowTemplateApi.java index 377b1807e..0af963ccf 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryFlowTemplateApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/QueryFlowTemplateApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,24 +16,23 @@ package com.welab.wefe.board.service.api.project.flow; -import java.util.List; -import java.util.stream.Collectors; - -import javax.persistence.EnumType; -import javax.persistence.Enumerated; - -import org.springframework.beans.factory.annotation.Autowired; - import com.welab.wefe.board.service.api.project.flow.QueryFlowTemplateApi.TemplateListOutput; import com.welab.wefe.board.service.database.entity.flow.FlowTemplateMySqlModel; import com.welab.wefe.board.service.service.FlowTemplateService; -import com.welab.wefe.board.service.util.ModelMapper; -import com.welab.wefe.common.enums.FederatedLearningType; import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractNoneInputApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiOutput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import 
com.welab.wefe.common.wefe.enums.FederatedLearningType; +import org.springframework.beans.factory.annotation.Autowired; + +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import java.util.List; +import java.util.stream.Collectors; /** * @author winter.zou @@ -74,18 +73,14 @@ public void setTemplates(List templates) { public static class TemplateOutput { private String id; - /** - * template name - */ + @Check(name = "template name") private String name; - /** - * template name - */ + @Check(name = "template name") private String description; private String enname; - + @Enumerated(EnumType.STRING) private FederatedLearningType federatedLearningType; @@ -121,12 +116,12 @@ public void setDescription(String description) { this.description = description; } - public FederatedLearningType getFederatedLearningType() { - return federatedLearningType; - } + public FederatedLearningType getFederatedLearningType() { + return federatedLearningType; + } - public void setFederatedLearningType(FederatedLearningType federatedLearningType) { - this.federatedLearningType = federatedLearningType; - } + public void setFederatedLearningType(FederatedLearningType federatedLearningType) { + this.federatedLearningType = federatedLearningType; + } } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/SaveFlowTemplateApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/SaveFlowTemplateApi.java index 47c7cd2aa..5930a3038 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/SaveFlowTemplateApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/SaveFlowTemplateApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/StartFlowApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/StartFlowApi.java index 553187035..da026a039 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/StartFlowApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/StartFlowApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -49,7 +49,7 @@ public static class Input extends AbstractApiInput { @Check(name = "终止节点", desc = "为空时表示执行全流程") private String endNodeId; - @Check(name = "jobId", hiddenForFrontEnd = true) + @Check(name = "jobId", donotShow = true) private String jobId; @Check(name = "arbiterMemberId", desc = "arbiter成员id") diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/UpdateFlowBaseInfoApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/UpdateFlowBaseInfoApi.java index fee787ddd..b7f5e66b6 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/UpdateFlowBaseInfoApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/UpdateFlowBaseInfoApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,13 +17,13 @@ package com.welab.wefe.board.service.api.project.flow; import com.welab.wefe.board.service.service.ProjectFlowService; -import com.welab.wefe.common.enums.FederatedLearningType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/UpdateFlowGraphApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/UpdateFlowGraphApi.java index 1a341bc83..bcae090d9 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/UpdateFlowGraphApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/UpdateFlowGraphApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/add.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/add.http index 62c82e8ae..0ac43640a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/add.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/add.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/flow/add +POST http://localhost:8080/board-service/project/flow/add Content-Type: application/json { @@ -6,3 +6,16 @@ Content-Type: application/json "name": "namename", "desc": "描述描述" } + + +### +POST http://localhost:8080/board-service/project/flow/add +Content-Type: application/json +token:{{token}} + +{ + "project_id": "a27218842ab8474099cf42844ddc1428", + "federatedLearningType": "horizontal", + "name": "新流程-14:47:39", + "desc": "" +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/copy.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/copy.http index 25f7d555f..bf3af9a92 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/copy.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/copy.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/flow/copy +POST http://localhost:8080/board-service/project/flow/copy Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/delete.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/delete.http index eda905b7c..3027135ee 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/delete.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/delete.http @@ -1,4 +1,4 
@@ -POST {{baseUrl}}/project/flow/delete +POST http://localhost:8080/board-service/project/flow/delete Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/detail.http index 2cf2d150a..eb6ce7a06 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/detail.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/detail.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/flow/detail +POST http://localhost:8080/board-service/project/flow/detail Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/flow_finished.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/flow_finished.http index 54d980bc4..2ca5f9e8f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/flow_finished.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/flow_finished.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/flow/finished +POST http://localhost:8080/board-service/project/flow/finished Content-Type: application/json token: {{token}} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/get-progress.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/get-progress.http index 981dd056d..de8807768 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/get-progress.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/get-progress.http @@ -1,8 +1,8 @@ -POST {{baseUrl}}/project/flow/get_progress +POST http://localhost:8080/board-service/project/flow/get_progress Content-Type: 
application/json { - "flowIdList":[ + "flowIdList": [ "81576f08c5fc48cca66f0f7c7834e7be" ] } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/list-nodes.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/list-nodes.http new file mode 100644 index 000000000..86006cac9 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/list-nodes.http @@ -0,0 +1,10 @@ +POST http://localhost:8080/board-service/project/flow_node/list +Content-Type: application/json +token: {{token}} + + +{ + "flowId": "c5dbbe315fe1416e86ca58f4015e9be2" +} + +### diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/query.http index 19d354579..c74c293ef 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/query.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/query.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/flow/query +POST http://localhost:8080/board-service/project/flow/query Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/start.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/start.http index cbb96ce56..36592bace 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/start.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/start.http @@ -1,8 +1,9 @@ -POST {{baseUrl}}/flow/start +POST http://localhost:8080/board-service/flow/start Content-Type: application/json +token:{{token}} { - "flow_id": "34bb11fb249c4181a62c9b0743a7cd9f", + "flow_id": "f4847e9317164b86a564f98d71841272", "use_cache": false } diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_base_info.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_base_info.http index e373bd83d..cb33876e1 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_base_info.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_base_info.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/flow/update/base_info +POST http://localhost:8080/board-service/project/flow/update/base_info Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_graph.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_graph.http index fc60765cb..8ae59ab0e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_graph.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_graph.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/flow/update/graph +POST http://localhost:8080/board-service/project/flow/update/graph Content-Type: application/json token: {{token}} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_graph2.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_graph2.http index eeaae0ad6..8389f04cf 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_graph2.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/flow/test/update_graph2.http @@ -1,6 +1,6 @@ ### 这个包含离群节点 -POST {{baseUrl}}/project/flow/update/graph +POST http://localhost:8080/board-service/project/flow/update/graph Content-Type: application/json token: {{token}} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/HashOptionsEnumApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/HashOptionsEnumApi.java new file mode 100644 index 000000000..4fa2a635b --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/HashOptionsEnumApi.java @@ -0,0 +1,43 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.project.fusion; + +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.HashOptions; + +import java.util.EnumSet; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/hash_options_enum", name = "任务状态", desc = "任务状态") +public class HashOptionsEnumApi extends AbstractApi> { + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException { + EnumSet hashOptions = EnumSet.allOf(HashOptions.class); + return success(hashOptions); + } + + public static class Input extends AbstractApiInput { + + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/DownloadBFApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/DownloadBFApi.java new file mode 100644 index 000000000..622a6342f --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/DownloadBFApi.java @@ -0,0 +1,71 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.api.project.fusion.actuator.psi; + + + +import com.welab.wefe.board.service.fusion.actuator.psi.ServerActuator; +import com.welab.wefe.board.service.fusion.manager.ActuatorManager; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.fusion.core.dto.PsiActuatorMeta; + +import java.io.IOException; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/psi/download_bloom_filter", + name = "download bloomfilter", + desc = "download bloomfilter", + login = false, + rsaVerify = true +) +public class DownloadBFApi extends AbstractApi { + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { + ServerActuator actuator = (ServerActuator) ActuatorManager.get(input.getBusinessId()); + if (actuator == null) { + LOG.error("Actuator not found,businessId is {}", input.getBusinessId()); + throw new StatusCodeWithException("Actuator not found", StatusCode.DATA_NOT_FOUND); + } + + return success(actuator.getActuatorParam()); + } + + public static class Input extends AbstractApiInput { + @Check(name = "businessId", require = true) + String businessId; + + public Input(String businessId) { + this.businessId = businessId; + } + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/PsiCryptoApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/PsiCryptoApi.java new file mode 100644 index 
000000000..0e1be91cf --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/PsiCryptoApi.java @@ -0,0 +1,87 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.project.fusion.actuator.psi; + + + +import com.welab.wefe.board.service.dto.fusion.PsiMeta; +import com.welab.wefe.board.service.fusion.actuator.psi.ServerActuator; +import com.welab.wefe.board.service.fusion.manager.ActuatorManager; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; + +import java.io.IOException; +import java.util.List; + +/** + * @author hunter.zhao + */ +@Api( + path = "fusion/psi/crypto", + name = "psi crypto", + desc = "psi crypto", + login = false, + rsaVerify = true +) +public class PsiCryptoApi extends AbstractApi { + + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { + ServerActuator actuator = (ServerActuator) ActuatorManager.get(input.getBusinessId()); + if (actuator == null) { + LOG.error("Actuator not found,businessId is {}", 
input.getBusinessId()); + throw new StatusCodeWithException("Actuator not found", StatusCode.DATA_NOT_FOUND); + } + + return success(PsiMeta.of(actuator.compute(input.getBs()))); + } + + + public static class Input extends AbstractApiInput { + @Check(name = "businessId", require = true) + String businessId; + + @Check(name = "bs", blockReactionaryKeyword = false) + List bs; + + public Input(String businessId, List bs) { + this.businessId = businessId; + this.bs = bs; + } + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public List getBs() { + return bs; + } + + public void setBs(List bs) { + this.bs = bs; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/ReceiveResultApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/ReceiveResultApi.java new file mode 100644 index 000000000..cda2ef2e2 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/ReceiveResultApi.java @@ -0,0 +1,86 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.api.project.fusion.actuator.psi; + + + +import com.welab.wefe.board.service.fusion.actuator.psi.ServerActuator; +import com.welab.wefe.board.service.fusion.manager.ActuatorManager; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; + +import java.util.List; + +/** + * @author hunter.zhao + */ +@Api( + path = "fusion/receive/result", + name = "receive result", + desc = "receive result", + login = false, + rsaVerify = true +) +public class ReceiveResultApi extends AbstractNoneOutputApi { + + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + ServerActuator actuator = (ServerActuator) ActuatorManager.get(input.getBusinessId()); + if (actuator == null) { + LOG.error("Actuator not found,businessId is {}", input.getBusinessId()); + throw new StatusCodeWithException("Actuator not found", StatusCode.DATA_NOT_FOUND); + } + + actuator.receiveResult(input.getRs()); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(name = "businessId", require = true) + String businessId; + + @Check(name = "rs", blockReactionaryKeyword = false) + List rs; + + public Input(String businessId, List rs) { + this.businessId = businessId; + this.rs = rs; + } + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public List getRs() { + return rs; + } + + public void setRs(List rs) { + this.rs = rs; + } + + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/ServerCloseApi.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/ServerCloseApi.java new file mode 100644 index 000000000..469b65ac7 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/ServerCloseApi.java @@ -0,0 +1,94 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.project.fusion.actuator.psi; + + +import com.welab.wefe.board.service.fusion.actuator.psi.ServerActuator; +import com.welab.wefe.board.service.fusion.manager.ActuatorManager; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.fusion.core.enums.PSIActuatorStatus; + +/** + * @author hunter.zhao + */ +@Api( + path = "fusion/server/close", + name = "server close", + desc = "server close", + login = false, + rsaVerify = true +) +public class ServerCloseApi extends AbstractNoneOutputApi { + @Override + protected ApiResult handler(ServerCloseApi.Input input) throws StatusCodeWithException { + ServerActuator actuator = (ServerActuator) 
ActuatorManager.get(input.getBusinessId()); + if (actuator == null) { + LOG.error("Actuator not found,businessId is {}", input.getBusinessId()); + throw new StatusCodeWithException("Actuator not found", StatusCode.DATA_NOT_FOUND); + } + + actuator.status = PSIActuatorStatus.valueOf(input.getStatus()); + actuator.error = input.getError(); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(name = "businessId", require = true) + String businessId; + + @Check(name = "任务状态", require = true) + String status; + + @Check(name = "错误信息") + String error; + + public Input(String businessId, String status, String error) { + this.businessId = businessId; + this.status = status; + this.error = error; + } + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public String getStatus() { + return status; + } + + public void setStatus(String status) { + this.status = status; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/ServerSynStatusApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/ServerSynStatusApi.java new file mode 100644 index 000000000..badea09d4 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/actuator/psi/ServerSynStatusApi.java @@ -0,0 +1,65 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.project.fusion.actuator.psi; + +import com.welab.wefe.board.service.fusion.actuator.psi.ServerActuator; +import com.welab.wefe.board.service.fusion.manager.ActuatorManager; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/psi/server_is_ready", + name = "query server status", + desc = "query server status", + login = false, + rsaVerify = true +) +public class ServerSynStatusApi extends AbstractApi { + + @Override + protected ApiResult handle(Input input) throws Exception { + ServerActuator actuator = (ServerActuator) ActuatorManager.get(input.getBusinessId()); + if (actuator == null) { + return success(JObject.create().append("ready", false)); + } + + return success(JObject.create().append("ready", true)); + } + + public static class Input extends AbstractApiInput { + @Check(name = "businessId", require = true) + String businessId; + + public Input(String businessId) { + this.businessId = businessId; + } + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + } +} \ No newline at end of file diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/member/QueryProvidersApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/member/QueryProvidersApi.java new file mode 100644 index 000000000..470b2cef4 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/member/QueryProvidersApi.java @@ -0,0 +1,65 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.project.fusion.member; + + +import com.welab.wefe.board.service.database.entity.job.ProjectMemberMySqlModel; +import com.welab.wefe.board.service.dto.entity.project.ProjectMemberOutputModel; +import com.welab.wefe.board.service.service.ProjectMemberService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; +import java.util.List; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/query/providers", + name = "query provider list", + desc = "query provider list" +) +public class QueryProvidersApi extends AbstractApi> { + @Autowired + ProjectMemberService projectMemberService; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException, IOException { + List memberMySqlModelList = + projectMemberService.listFormalProjectProviders(input.getProjectId()); + return success(ModelMapper.maps(memberMySqlModelList, ProjectMemberOutputModel.class)); + } + + public static class Input extends AbstractApiInput { + @Check(name = "project_id", require = true) + String projectId; + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/ResultExportApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/ResultExportApi.java new file mode 100644 index 000000000..9362720e7 --- /dev/null +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/ResultExportApi.java @@ -0,0 +1,146 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.project.fusion.result; + + +import com.welab.wefe.board.service.service.fusion.FusionResultService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DatabaseType; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/result/export", name = "结果导出", desc = "结果导出") +public class ResultExportApi extends AbstractApi { + + @Autowired + FusionResultService fusionResultService; + + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + return success(fusionResultService.export(input)); + } + + public static class Input extends AbstractApiInput { + @Check(name = "指定操作的businessId", require = true) + private String businessId; + + @Check(messageOnEmpty = "数据库类型不能为空", require = true) + private DatabaseType databaseType; + + @Check(messageOnEmpty = "IP不能为空", require = 
true) + private String host; + + @Check(messageOnEmpty = "端口不能为空", require = true) + private Integer port; + + @Check(messageOnEmpty = "数据库名称不能为空", require = true) + private String databaseName; + + @Check(name = "用户名") + private String userName; + + @Check(name = "密码") + private String password; + + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public DatabaseType getDatabaseType() { + return databaseType; + } + + public void setDatabaseType(DatabaseType databaseType) { + this.databaseType = databaseType; + } + + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public Integer getPort() { + return port; + } + + public void setPort(Integer port) { + this.port = port; + } + + public String getDatabaseName() { + return databaseName; + } + + public void setDatabaseName(String databaseName) { + this.databaseName = databaseName; + } + + public String getUserName() { + return userName; + } + + public void setUserName(String userName) { + this.userName = userName; + } + + public String getPassword() { + return password; + } + + public void setPassword(String password) { + this.password = password; + } + } + + + public static class Output { + + private String tableName; + + public Output(String tableName) { + this.tableName = tableName; + } + + //region getter/setter + + public String getTableName() { + return tableName; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + + //endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/ResultExportProgressApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/ResultExportProgressApi.java new file mode 100644 index 000000000..9a5976556 --- /dev/null +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/ResultExportProgressApi.java @@ -0,0 +1,56 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.project.fusion.result; + + +import com.welab.wefe.board.service.dto.fusion.FusionResultExportProgress; +import com.welab.wefe.board.service.fusion.manager.ExportManager; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/result/export_progress", name = "结果导出", desc = "结果导出", login = false) +public class ResultExportProgressApi extends AbstractApi { + + + @Override + protected ApiResult handle(Input input) throws Exception { + return success(ExportManager.get(input.getBusinessId())); + } + + public static class Input extends AbstractApiInput { + @Check(name = "指定操作的businessId", require = true) + private String businessId; + + //region + + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + //endregion + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/ResultPreviewApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/ResultPreviewApi.java new file mode 100644 index 000000000..dcbfbb068 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/ResultPreviewApi.java @@ -0,0 +1,137 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.api.project.fusion.result; + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.database.entity.fusion.FusionTaskMySqlModel; +import com.welab.wefe.board.service.service.fusion.FusionResultStorageService; +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.common.data.storage.common.Constant; +import com.welab.wefe.common.data.storage.model.DataItemModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/result/preview", name = "结果预览", desc = "结果预览") +public class ResultPreviewApi extends AbstractApi { + + @Autowired + FusionTaskService fusionTaskService; + + + @Autowired + FusionResultStorageService fusionResultStorageService; + + @Override + protected ApiResult handle(Input input) throws Exception { + FusionTaskMySqlModel model = fusionTaskService.findByBusinessId(input.getBusinessId()); + if (model == null) { + return success(); + } + + if (model.getFusionCount() == 0) { + return success(); + } + + + DataItemModel headerModel = fusionResultStorageService.getByKey( + Constant.DBName.WEFE_DATA, + fusionResultStorageService.createRawDataSetTableName(input.getBusinessId()) + ".meta", + "header" + ); + List columns = StringUtil.splitWithoutEmptyItem(headerModel.getV().toString().replace("\"", ""), ","); + List> rows = fusionResultStorageService.previewDataSet( + fusionResultStorageService.createRawDataSetTableName(input.getBusinessId()), + 10 + ); + + List list = new ArrayList<>(); + for (List row : rows) { + + JSONObject 
item = new JSONObject(); + for (int i = 0; i < columns.size(); i++) { + if (row.size() > i) { + item.put(columns.get(i), row.get(i)); + } + } + + list.add(item); + } + + return success(new ResultPreviewApi.Output(columns, list)); + } + + public static class Input extends AbstractApiInput { + @Check(name = "指定操作的businessId", require = true) + private String businessId; + + //region + + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + //endregion + } + + + public static class Output { + + private List header; + private List list; + + public Output(List header, List list) { + this.header = header; + this.list = list; + } + + //region getter/setter + + public List getHeader() { + return header; + } + + public void setHeader(List header) { + this.header = header; + } + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + + + //endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/TestDBConnectApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/TestDBConnectApi.java new file mode 100644 index 000000000..2bf489569 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/result/TestDBConnectApi.java @@ -0,0 +1,142 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.project.fusion.result; + + +import com.welab.wefe.board.service.util.JdbcManager; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.AbstractApiOutput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DatabaseType; + +import java.sql.Connection; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/test_db_connect", name = "测试数据库连接", desc = "测试数据库连接") +public class TestDBConnectApi extends AbstractApi { + + @Override + protected ApiResult handle(Input input) throws Exception { + Connection conn = JdbcManager.getConnection(input.getDatabaseType(), input.getHost(), input.getPort(), input.getUserName(), input.getPassword(), input.getDatabaseName()); + if (conn != null) { + boolean success = JdbcManager.testQuery(conn); + if (!success) { + throw new StatusCodeWithException(StatusCode.DATABASE_LOST, "测试连接数据库失败,请检查数据库是否正常或者账号密码是否填写错误"); + } + } + + TestDBConnectApi.Output output = new TestDBConnectApi.Output(); + output.setResult(true); + return success(output); + } + + + public static class Input extends AbstractApiInput { + + @Check(messageOnEmpty = "数据库类型不能为空", require = true) + private DatabaseType databaseType; + + @Check(messageOnEmpty = "IP不能为空", require = true) + private String host; + + @Check(messageOnEmpty = "端口不能为空", require = true) + private Integer port; + + @Check(messageOnEmpty = "数据库名称不能为空", require = true) + private String databaseName; + + @Check(name = "用户名") + private String userName; + + @Check(name = "密码") + private String password; + + public 
DatabaseType getDatabaseType() { + return databaseType; + } + + public void setDatabaseType(DatabaseType databaseType) { + this.databaseType = databaseType; + } + + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public Integer getPort() { + return port; + } + + public void setPort(Integer port) { + this.port = port; + } + + public String getDatabaseName() { + return databaseName; + } + + public void setDatabaseName(String databaseName) { + this.databaseName = databaseName; + } + + public String getUserName() { + return userName; + } + + public void setUserName(String userName) { + this.userName = userName; + } + + public String getPassword() { + return password; + } + + public void setPassword(String password) { + this.password = password; + } + } + + public static class Output extends AbstractApiOutput { + private Boolean result; + + public Output() { + + } + + public Output(Boolean result) { + this.result = result; + } + + public Boolean getResult() { + return result; + } + + public void setResult(Boolean result) { + this.result = result; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/AddApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/AddApi.java new file mode 100644 index 000000000..8da2156b7 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/AddApi.java @@ -0,0 +1,233 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.project.fusion.task; + +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.board.service.util.primarykey.FieldInfo; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.fusion.core.enums.AlgorithmType; +import org.apache.commons.collections4.CollectionUtils; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.List; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/task/add", name = "添加对齐任务", desc = "添加对齐任务") +public class AddApi extends AbstractNoneOutputApi { + + @Autowired + FusionTaskService fusionTaskService; + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + fusionTaskService.add(input); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(name = "项目id", require = true) + private String projectId; + + @Check(name = "任务名称", require = true, regex = "^.{4,40}$", messageOnInvalid = "任务名称长度不能少于4,不能大于40") + private String name; + + @Check(name = "描述", regex = "^.{0,1024}$", messageOnInvalid 
= "你写的描述太多了~") + private String description; + + @Check(name = "合作方id", require = true) + private String dstMemberId; + + @Check(name = "数据资源id", require = true) + private String dataResourceId; + + @Check(name = "数据资源类型", require = true) + private DataResourceType dataResourceType; + + @Check(name = "样本量", require = true) + private Long rowCount; + + @Check(name = "对方数据资源id", require = true) + private String partnerDataResourceId; + + @Check(name = "对方数据资源类型", require = true) + private DataResourceType partnerDataResourceType; + + @Check(name = "对方样本量", require = true) + private Long partnerRowCount; + + @Check(name = "算法", require = true) + private AlgorithmType algorithm; + + @Check(name = "主键处理") + private List fieldInfoList; + + @Check(name = "是否追溯", require = true) + private Boolean isTrace; + + @Check(name = "追溯字段") + private String traceColumn; + + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + + if (DataResourceType.TableDataSet.equals(dataResourceType) + && fieldInfoList.isEmpty()) { + throw new StatusCodeWithException("请设置主键", StatusCode.PARAMETER_VALUE_INVALID); + } + + if (isTrace && StringUtil.isEmpty(traceColumn)) { + throw new StatusCodeWithException("追溯字段不能为空", StatusCode.PARAMETER_VALUE_INVALID); + } + + if (AlgorithmType.RSA_PSI.equals(algorithm) && partnerDataResourceType.equals(dataResourceType)) { + throw new StatusCodeWithException(" RSA-PSI 算法要求至少一方需要选择布隆过滤器资源, 另一方则必须为数据资源资源!", StatusCode.PARAMETER_VALUE_INVALID); + } + + if (isTrace && CollectionUtils.isNotEmpty(fieldInfoList)) { + for (int i = 0; i < fieldInfoList.size(); i++) { + if (fieldInfoList.get(i).getColumnList().contains(traceColumn)) { + throw new StatusCodeWithException("追溯字段不能为融合主键组成字段", StatusCode.PARAMETER_VALUE_INVALID); + } + } + } + + } + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public String getName() { + 
return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getDstMemberId() { + return dstMemberId; + } + + public void setDstMemberId(String dstMemberId) { + this.dstMemberId = dstMemberId; + } + + public String getDataResourceId() { + return dataResourceId; + } + + public void setDataResourceId(String dataResourceId) { + this.dataResourceId = dataResourceId; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public AlgorithmType getAlgorithm() { + return algorithm; + } + + public void setAlgorithm(AlgorithmType algorithm) { + this.algorithm = algorithm; + } + + public Long getRowCount() { + return rowCount; + } + + public void setRowCount(Long rowCount) { + this.rowCount = rowCount; + } + + public List getFieldInfoList() { + return fieldInfoList; + } + + public void setFieldInfoList(List fieldInfoList) { + this.fieldInfoList = fieldInfoList; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public Boolean getTrace() { + return isTrace; + } + + public void setTrace(Boolean trace) { + isTrace = trace; + } + + public String getTraceColumn() { + return traceColumn; + } + + public void setTraceColumn(String traceColumn) { + this.traceColumn = traceColumn; + } + + public String getPartnerDataResourceId() { + return partnerDataResourceId; + } + + public void setPartnerDataResourceId(String partnerDataResourceId) { + this.partnerDataResourceId = partnerDataResourceId; + } + + public DataResourceType getPartnerDataResourceType() { + return partnerDataResourceType; + } + + public void setPartnerDataResourceType(DataResourceType partnerDataResourceType) { + this.partnerDataResourceType = partnerDataResourceType; + } + + public Long getPartnerRowCount() { + return 
partnerRowCount; + } + + public void setPartnerRowCount(Long partnerRowCount) { + this.partnerRowCount = partnerRowCount; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/AuditApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/AuditApi.java new file mode 100644 index 000000000..57258989f --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/AuditApi.java @@ -0,0 +1,132 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.project.fusion.task; + +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.board.service.util.primarykey.FieldInfo; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.List; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/task/audit", name = "任务处理", desc = "任务处理") +public class AuditApi extends AbstractNoneOutputApi { + + @Autowired + FusionTaskService fusionTaskService; + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + fusionTaskService.handle(input); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(name = "businessId", require = true) + private String businessId; + + @Check(name = "主键处理") + private List fieldInfoList; + + @Check(name = "是否追溯") + private Boolean isTrace = false; + + @Check(name = "追溯字段") + private String traceColumn; + + @Check(name = "审核字段", require = true) + private AuditStatus auditStatus; + + @Check(name = "审核评论") + private String auditComment; + + + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + +// if (DataResourceType.DataSet.equals(dataResourceType) && fieldInfoList.isEmpty()) { +// throw new StatusCodeWithException("请设置主键", StatusCode.PARAMETER_VALUE_INVALID); +// } + + if (isTrace && StringUtil.isEmpty(traceColumn)) { + throw new StatusCodeWithException("追溯字段不能为空", 
StatusCode.PARAMETER_VALUE_INVALID); + } + + } + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public List getFieldInfoList() { + return fieldInfoList; + } + + public void setFieldInfoList(List fieldInfoList) { + this.fieldInfoList = fieldInfoList; + } + + public Boolean getTrace() { + return isTrace; + } + + public void setTrace(Boolean trace) { + isTrace = trace; + } + + public String getTraceColumn() { + return traceColumn; + } + + public void setTraceColumn(String traceColumn) { + this.traceColumn = traceColumn; + } + + public AuditStatus getAuditStatus() { + return auditStatus; + } + + public void setAuditStatus(AuditStatus auditStatus) { + this.auditStatus = auditStatus; + } + + public String getAuditComment() { + return auditComment; + } + + public void setAuditComment(String auditComment) { + this.auditComment = auditComment; + } + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/AuditCallbackApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/AuditCallbackApi.java new file mode 100644 index 000000000..bf2f33c27 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/AuditCallbackApi.java @@ -0,0 +1,90 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.project.fusion.task; + +import com.welab.wefe.board.service.service.fusion.CallbackService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/audit/callback", name = "接收消息接口", login = false, rsaVerify = true) +public class AuditCallbackApi extends AbstractNoneOutputApi { + @Autowired + CallbackService callbackService; + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + callbackService.audit(input); + return success(); + } + + + public static class Input extends AbstractApiInput { + + @Check(name = "指定操作的businessId", require = true) + private String businessId; + + @Check(name = "审核字段", require = true) + private AuditStatus auditStatus; + + @Check(name = "审核评论") + private String auditComment; + + @Check(name = "审核评论") + private String partnerHashFunction; + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public AuditStatus getAuditStatus() { + return auditStatus; + } + + public void setAuditStatus(AuditStatus auditStatus) { + this.auditStatus = auditStatus; + } + + public String getAuditComment() { + return auditComment; + } + + public void setAuditComment(String auditComment) { + this.auditComment = auditComment; + } + + public String getPartnerHashFunction() { + return partnerHashFunction; + } + + public void 
setPartnerHashFunction(String partnerHashFunction) { + this.partnerHashFunction = partnerHashFunction; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/DeleteApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/DeleteApi.java new file mode 100644 index 000000000..bcac81e98 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/DeleteApi.java @@ -0,0 +1,54 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.project.fusion.task; + +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/task/delete", name = "删除任务", desc = "删除任务") +public class DeleteApi extends AbstractNoneOutputApi { + @Autowired + FusionTaskService fusionTaskService; + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + fusionTaskService.delete(input.id); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(name = "id", require = true) + String id; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/DetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/DetailApi.java new file mode 100644 index 000000000..2a89fafa5 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/DetailApi.java @@ -0,0 +1,100 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.project.fusion.task; + +import com.welab.wefe.board.service.dto.fusion.FusionTaskOutput; +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.fusion.core.utils.CryptoUtils; +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.params.RSAKeyParameters; +import org.bouncycastle.crypto.params.RSAPrivateCrtKeyParameters; +import org.springframework.beans.factory.annotation.Autowired; + +import java.math.BigInteger; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/task/detail", name = "任务列表", desc = "任务列表") +public class DetailApi extends AbstractApi { + @Autowired + FusionTaskService fusionTaskService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + return success(fusionTaskService.detail(input.id)); + } + + public static void main(String[] args) { + AsymmetricCipherKeyPair keyPair = CryptoUtils.generateKeys(1024); + + RSAKeyParameters pk = (RSAKeyParameters) keyPair.getPublic(); + RSAKeyParameters sk = (RSAPrivateCrtKeyParameters) keyPair.getPrivate(); + BigInteger e = pk.getExponent(); + BigInteger N = pk.getModulus(); + BigInteger d = 
sk.getExponent(); + BigInteger p = ((RSAPrivateCrtKeyParameters) sk).getP(); + BigInteger q = ((RSAPrivateCrtKeyParameters) sk).getQ(); + + + long s1 = System.currentTimeMillis(); + BigInteger tq = p.modInverse(q); + BigInteger tp = q.modInverse(p); + BigInteger cp = tp.multiply(q); + BigInteger cq = tq.multiply(p); + + + + for(int i=1; i <= 200000;i++) { + BigInteger x = BigInteger.valueOf(4328423048302L * i); + + BigInteger rp = x.modPow(d.remainder(p.subtract(BigInteger.valueOf(1))), p); + BigInteger rq = x.modPow(d.remainder(q.subtract(BigInteger.valueOf(1))), q); + BigInteger r = (rp.multiply(cp).add(rq.multiply(cq))).remainder(N); + } + long s2 = System.currentTimeMillis(); + System.out.println(s2-s1 + "ms"); + long s3 = System.currentTimeMillis(); + for(int i=1; i <= 200000;i++) { + BigInteger x = BigInteger.valueOf(4328423048302L * i); + BigInteger r1 = x.modPow(d, N); + } + long s4 = System.currentTimeMillis(); + System.out.println(s4-s3 + "ms"); +// System.out.println(r1.equals(r)); + + } + + public static class Input extends AbstractApiInput { + @Check(name = "指定操作的taskId", require = true) + private String id; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/InfoApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/InfoApi.java new file mode 100644 index 000000000..0830d946e --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/InfoApi.java @@ -0,0 +1,57 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.project.fusion.task; + +import com.welab.wefe.board.service.fusion.manager.ActuatorManager; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/task/info", name = "查询任务进度", desc = "查询任务进度") +public class InfoApi extends AbstractApi { + + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + return success(ActuatorManager.getTaskInfo(input.getBusinessId())); + } + + public static class Input extends AbstractApiInput { + @Check(name = "指定操作的taskId", require = true) + private String businessId; + + //region + + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + //endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/PagingApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/PagingApi.java new file mode 100644 index 000000000..5732d3aed --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/PagingApi.java @@ -0,0 +1,84 @@ +/* + * 
Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.project.fusion.task; + +import com.welab.wefe.board.service.dto.base.PagingInput; +import com.welab.wefe.board.service.dto.base.PagingOutput; +import com.welab.wefe.board.service.dto.fusion.FusionTaskOutput; +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.fusion.core.enums.FusionTaskStatus; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/task/paging", name = "任务列表", desc = "任务列表") +public class PagingApi extends AbstractApi> { + @Autowired + FusionTaskService fusionTaskService; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException { + return success(fusionTaskService.paging(input)); + } + + + public static class Input extends PagingInput { + @Check(name = "projectId") + private String projectId; + + @Check(name = "businessId") + private String businessId; + + @Check(name = "任务状态") + private FusionTaskStatus status; + + //region + + + public String getProjectId() { + return projectId; + } 
+ + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public FusionTaskStatus getStatus() { + return status; + } + + public void setStatus(FusionTaskStatus status) { + this.status = status; + } + + //endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/ReceiveApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/ReceiveApi.java new file mode 100644 index 000000000..f464ab60f --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/ReceiveApi.java @@ -0,0 +1,203 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.project.fusion.task; + +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.fusion.core.enums.AlgorithmType; +import com.welab.wefe.fusion.core.enums.PSIActuatorRole; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author hunter.zhao + */ +@Api(path = "task/receive", name = "接收对齐请求", desc = "接收对齐请求", login = false, rsaVerify = true) +public class ReceiveApi extends AbstractNoneOutputApi { + @Autowired + FusionTaskService fusionTaskService; + + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + fusionTaskService.alignByPartner(input); + return success(); + } + + public static class Input extends AbstractApiInput { + + @Check(name = "指定操作的projectId", require = true) + private String projectId; + + @Check(name = "指定操作的businessId", require = true) + private String businessId; + + @Check(name = "任务名称", require = true) + private String name; + + @Check(name = "合作方id", require = true) + private String dstMemberId; + + @Check(name = "数据资源id", require = true) + private String dataResourceId; + + @Check(name = "数据资源类型", require = true) + private DataResourceType dataResourceType; + + @Check(name = "对方数据资源id", require = true) + private String partnerDataResourceId; + + @Check(name = "数据资源类型", require = true) + private DataResourceType partnerDataResourceType; + + @Check(name = "对方数据融合公式", require = true) + private String partnerHashFunction; + + @Check(name = "数据资源的数据量") + private Long rowCount; + + @Check(name = 
"合作方的数据资源的数据量") + private Long partnerRowCount; + + @Check(name = "对齐角色", require = true) + private PSIActuatorRole psiActuatorRole; + + @Check(name = "算法", require = true) + private AlgorithmType algorithm; + + @Check(name = "描述", regex = "^.{0,1024}$", messageOnInvalid = "你写的描述太多了~") + private String description; + + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public String getDstMemberId() { + return dstMemberId; + } + + public void setDstMemberId(String dstMemberId) { + this.dstMemberId = dstMemberId; + } + + public Long getRowCount() { + return rowCount; + } + + public void setRowCount(Long rowCount) { + this.rowCount = rowCount; + } + + public PSIActuatorRole getPsiActuatorRole() { + return psiActuatorRole; + } + + public void setPsiActuatorRole(PSIActuatorRole psiActuatorRole) { + this.psiActuatorRole = psiActuatorRole; + } + + public AlgorithmType getAlgorithm() { + return algorithm; + } + + public void setAlgorithm(AlgorithmType algorithm) { + this.algorithm = algorithm; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public String getDataResourceId() { + return dataResourceId; + } + + public void setDataResourceId(String dataResourceId) { + this.dataResourceId = dataResourceId; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public String getPartnerDataResourceId() { + return partnerDataResourceId; + } + + public void 
setPartnerDataResourceId(String partnerDataResourceId) { + this.partnerDataResourceId = partnerDataResourceId; + } + + public DataResourceType getPartnerDataResourceType() { + return partnerDataResourceType; + } + + public void setPartnerDataResourceType(DataResourceType partnerDataResourceType) { + this.partnerDataResourceType = partnerDataResourceType; + } + + public Long getPartnerRowCount() { + return partnerRowCount; + } + + public void setPartnerRowCount(Long partnerRowCount) { + this.partnerRowCount = partnerRowCount; + } + + public String getPartnerHashFunction() { + return partnerHashFunction; + } + + public void setPartnerHashFunction(String partnerHashFunction) { + this.partnerHashFunction = partnerHashFunction; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/RestartApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/RestartApi.java new file mode 100644 index 000000000..d9724f2a5 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/RestartApi.java @@ -0,0 +1,41 @@ +package com.welab.wefe.board.service.api.project.fusion.task; + +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author hunter.zhao + */ +@Api(path = "fusion/task/restart", name = "任务重跑对齐任务", desc = "任务重跑对齐任务") +public class RestartApi extends AbstractNoneOutputApi { + + @Autowired + FusionTaskService fusionTaskService; + + @Override + protected ApiResult handler(AddApi.Input input) throws StatusCodeWithException { + fusionTaskService.add(input); + return success(); + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/TaskStatusApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/TaskStatusApi.java new file mode 100644 index 000000000..16c7331b9 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/fusion/task/TaskStatusApi.java @@ -0,0 +1,43 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/*
 * Copyright 2021 Tianmian Tech. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.welab.wefe.board.service.api.project.fusion.task;

import com.welab.wefe.common.exception.StatusCodeWithException;
import com.welab.wefe.common.web.api.base.AbstractApi;
import com.welab.wefe.common.web.api.base.Api;
import com.welab.wefe.common.web.dto.AbstractApiInput;
import com.welab.wefe.common.web.dto.ApiResult;
import com.welab.wefe.fusion.core.enums.FusionTaskStatus;

import java.util.EnumSet;

/**
 * Lists every possible fusion-task status.
 *
 * <p>Stateless: simply returns all constants of
 * {@link FusionTaskStatus} for the front end to populate filters.</p>
 *
 * @author hunter.zhao
 */
@Api(path = "fusion/task/status", name = "任务状态", desc = "任务状态")
public class TaskStatusApi extends AbstractApi<TaskStatusApi.Input, EnumSet<FusionTaskStatus>> {

    @Override
    protected ApiResult<EnumSet<FusionTaskStatus>> handle(Input input) throws StatusCodeWithException {
        EnumSet<FusionTaskStatus> statuses = EnumSet.allOf(FusionTaskStatus.class);
        return success(statuses);
    }

    // No parameters: the endpoint takes an empty input.
    public static class Input extends AbstractApiInput {

    }
}
+ */ + +package com.welab.wefe.board.service.api.project.fusion.task; + +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.fusion.core.enums.AlgorithmType; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author hunter.zhao + */ +@Api(path = "task/update", name = "修改对齐任务", desc = "修改对齐任务") +public class UpdateApi extends AbstractNoneOutputApi { + + @Autowired + FusionTaskService taskService; + + @Override + protected ApiResult handler(Input input) throws StatusCodeWithException { + taskService.update(input); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(name = "任务Id", require = true) + private String id; + + @Check(name = "任务名称", require = true) + private String name; + + @Check(name = "合作方id", require = true) + private String dstMemberId; + + @Check(name = "数据资源id", require = true) + private String dataResourceId; + + @Check(name = "布隆过滤器id", require = true) + private DataResourceType dataResourceType; + + @Check(name = "算法") + private AlgorithmType algorithm = AlgorithmType.RSA_PSI; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getDstMemberId() { + return dstMemberId; + } + + public void setDstMemberId(String dstMemberId) { + this.dstMemberId = dstMemberId; + } + + public String getDataResourceId() { + return dataResourceId; + } + + public void setDataResourceId(String 
dataResourceId) { + this.dataResourceId = dataResourceId; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public AlgorithmType getAlgorithm() { + return algorithm; + } + + public void setAlgorithm(AlgorithmType algorithm) { + this.algorithm = algorithm; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/DetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/DetailApi.java index 537be8755..644f64ed3 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/DetailApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/DetailApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -26,7 +26,6 @@ import com.welab.wefe.board.service.service.ProjectFlowNodeService; import com.welab.wefe.board.service.service.TaskService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.StringUtil; @@ -34,6 +33,7 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; import java.util.List; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/DownloadLogApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/DownloadLogApi.java index e291d4876..7180da268 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/DownloadLogApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/DownloadLogApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -33,7 +33,7 @@ /** * @author zane.luo */ -@Api(path = "/job/log/download", name = "download job log") +@Api(path = "job/log/download", name = "download job log") public class DownloadLogApi extends AbstractApi { @Autowired diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/GetJobProgressApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/GetJobProgressApi.java index ec3bfd2ed..216813c7f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/GetJobProgressApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/GetJobProgressApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,7 +16,6 @@ package com.welab.wefe.board.service.api.project.job; -import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.api.gateway.GetMemberJobProgressApi; import com.welab.wefe.board.service.dto.entity.job.JobMemberOutputModel; import com.welab.wefe.board.service.dto.vo.JobProgressOutput; @@ -65,15 +64,12 @@ protected ApiResult> handle(Input input) throws StatusCo else { try { - ApiResult apiResult = gatewayService.callOtherMemberBoard( + progress = gatewayService.callOtherMemberBoard( member.getMemberId(), GetMemberJobProgressApi.class, - new GetMemberJobProgressApi.Input(input.jobId, member.getJobRole()) + new GetMemberJobProgressApi.Input(input.jobId, member.getJobRole()), + JobProgressOutput.class ); - - if (apiResult.data != null) { - progress = ((JSONObject) apiResult.data).toJavaObject(JobProgressOutput.class); - } } catch (Exception e) { progress = JobProgressOutput.fail(member, e); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/OnJobFinishedApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/OnJobFinishedApi.java index da171c16b..053c31f35 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/OnJobFinishedApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/OnJobFinishedApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/PreviewJobApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/PreviewJobApi.java index e0932353c..7e4b46402 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/PreviewJobApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/PreviewJobApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -20,13 +20,13 @@ import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.service.JobService; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; import org.springframework.beans.factory.annotation.Autowired; import java.util.List; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/QueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/QueryApi.java index e4f280c2d..325027338 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/QueryApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/QueryApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -20,12 +20,12 @@ import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.dto.entity.job.JobListOutputModel; import com.welab.wefe.board.service.service.FlowJobService; -import com.welab.wefe.common.enums.JobStatus; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobStatus; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/ResumeJobApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/ResumeJobApi.java index bf2ba9bbc..ca44a8508 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/ResumeJobApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/ResumeJobApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/StopJobApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/StopJobApi.java index 3e1292a53..6215256bb 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/StopJobApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/StopJobApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/UpdateJobStatusApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/UpdateJobStatusApi.java index 69e446a86..a7de39246 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/UpdateJobStatusApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/UpdateJobStatusApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,13 +17,13 @@ package com.welab.wefe.board.service.api.project.job; import com.welab.wefe.board.service.service.JobService; -import com.welab.wefe.common.enums.JobStatus; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobStatus; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/ViewDataSetApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/ViewDataSetApi.java index 24c7f4638..a77eca61f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/ViewDataSetApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/ViewDataSetApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,7 +19,6 @@ import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; import com.welab.wefe.board.service.service.TaskResultService; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; @@ -28,6 +27,7 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpMethod; import org.springframework.http.RequestEntity; @@ -39,7 +39,7 @@ /** * @author zane.luo */ -@Api(path = "/job/data_set/view", name = "view data set data rows", login = false) +@Api(path = "job/data_set/view", name = "view data set data rows", login = false) public class ViewDataSetApi extends AbstractApi { @Autowired diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/DetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/DetailApi.java index 9e5c5e4af..95ed98a08 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/DetailApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/DetailApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. 
All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,6 +17,7 @@ package com.welab.wefe.board.service.api.project.job.task; import com.welab.wefe.board.service.component.Components; +import com.welab.wefe.board.service.component.base.AbstractComponent; import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.dto.entity.job.JobOutputModel; import com.welab.wefe.board.service.dto.entity.job.TaskOutputView; @@ -32,6 +33,7 @@ import com.welab.wefe.common.web.dto.ApiResult; import org.springframework.beans.factory.annotation.Autowired; +import java.util.Arrays; import java.util.List; /** @@ -54,10 +56,12 @@ protected ApiResult handle(Input input) throws StatusCodeWithException { List results = null; if (input.needResult) { - results = Components - .get(task.getTaskType()) - .getTaskAllResult(task.getTaskId()); - + AbstractComponent component = Components.get(task.getTaskType()); + if (StringUtil.isEmpty(input.resultType)) { + results = component.getTaskAllResult(task.getTaskId()); + } else { + results = Arrays.asList(component.getTaskResult(task.getTaskId(), input.resultType)); + } } return success( @@ -117,6 +121,9 @@ public static class Input extends AbstractApiInput { @Check(name = "是否需要返回 task 执行结果", require = true, desc = "task 的执行结果体积较大,在不需要时,请指定为 false") private boolean needResult; + @Check(name = "task result 的 type") + public String resultType; + @Override public void checkAndStandardize() throws StatusCodeWithException { 
super.checkAndStandardize(); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetFeatureApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetFeatureApi.java index 1898a6de2..b5e88c5f0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetFeatureApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetFeatureApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -31,7 +31,7 @@ /** * @author lonnie */ -@Api(path = "/flow/job/task/feature", name = "get feature list", desc = "Get the feature column in the output result of feature calculation in the parent node") +@Api(path = "flow/job/task/feature", name = "get feature list", desc = "Get the feature column in the output result of feature calculation in the parent node") public class GetFeatureApi extends AbstractApi { @Autowired @@ -84,6 +84,12 @@ public static class Output { private boolean hasFeatureStatistic; private boolean hasFeatureCalculation; + + private boolean hasCV; + + private boolean hasIV; + + private boolean hasLossRate; List members; @@ -110,5 +116,29 @@ public List getMembers() { public void setMembers(List members) { this.members = members; } + + public boolean isHasCV() { + return hasCV; + } + + public void setHasCV(boolean hasCV) { + this.hasCV = hasCV; + } + + public boolean isHasIV() { + return hasIV; + } + + public void setHasIV(boolean hasIV) { + this.hasIV = hasIV; + } + + public boolean isHasLossRate() { + return hasLossRate; + } + + public void setHasLossRate(boolean hasLossRate) { + this.hasLossRate = hasLossRate; + } } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetResultApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetResultApi.java index a2b76385c..00b07bc1f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetResultApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetResultApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,25 +16,24 @@ package com.welab.wefe.board.service.api.project.job.task; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import org.springframework.beans.factory.annotation.Autowired; - import com.welab.wefe.board.service.component.Components; import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.dto.entity.job.TaskResultOutputModel; import com.welab.wefe.board.service.service.TaskService; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; /** * @author zane.luo @@ -42,64 +41,74 @@ @Api(path = "flow/job/task/result", name = "get task result", desc = "Use taskId or flowId + nodeId to get the node execution result.") public class GetResultApi extends AbstractApi> { - @Autowired - private TaskService taskService; - - @Override - protected ApiResult> handle(Input input) throws StatusCodeWithException { - - List tasks = taskService.findAll(input); - if (tasks 
== null || tasks.isEmpty()) { - return success(); - } - List results = new ArrayList<>(); - Set temp = new HashSet<>(); - for (TaskMySqlModel task : tasks) { - String taskConf = task.getTaskConf(); - JObject taskConfigJson = JObject.create(taskConf); - TaskResultOutputModel result = Components.get(task.getTaskType()).getTaskResult(task.getTaskId(), - input.type); - if (result == null) { - result = new TaskResultOutputModel(); - } - // put task info to TaskResultOutputModel - result.setStatus(task.getStatus()); - result.setStartTime(task.getStartTime()); - result.setFinishTime(task.getFinishTime()); - result.setMessage(task.getMessage()); - result.setErrorCause(task.getErrorCause()); - result.setPosition(task.getPosition()); - result.setSpend(task.getSpend()); - result.setMembers(taskConfigJson.getJObject("task").getJSONList("members")); - if (result.getResult() != null && !temp.add(result.getResult().toJSONString()) && task.getRole() == JobMemberRole.provider - && (task.getTaskType() == ComponentType.MixStatistic - || task.getTaskType() == ComponentType.MixBinning - || task.getTaskType() == ComponentType.FillMissingValue - || task.getTaskType() == ComponentType.MixLR)) { - continue; - } - results.add(result); - } - - return success(results); - } - - public static class Input extends DetailApi.Input { - - @Check(name = "结果类型") - private String type; - - // region getter/setter - - public String getType() { - return type; - } - - public void setType(String type) { - this.type = type; - } - - // endregion - } + @Autowired + private TaskService taskService; + + @Override + protected ApiResult> handle(Input input) throws StatusCodeWithException { + + List tasks = taskService.findAll(input); + if (tasks == null || tasks.isEmpty()) { + return success(); + } + List results = new ArrayList<>(); + Set temp = new HashSet<>(); + for (TaskMySqlModel task : tasks) { + String taskConf = task.getTaskConf(); + JObject taskConfigJson = JObject.create(taskConf); + 
TaskResultOutputModel result = Components.get(task.getTaskType()) + .getTaskResult(task.getTaskId(), input.type); + if (result == null) { + result = new TaskResultOutputModel(); + } + // put task info to TaskResultOutputModel + result.setStatus(task.getStatus()); + result.setStartTime(task.getStartTime()); + result.setFinishTime(task.getFinishTime()); + result.setMessage(task.getMessage()); + result.setErrorCause(task.getErrorCause()); + result.setPosition(task.getPosition()); + result.setSpend(task.getSpend()); + result.setJobId(task.getJobId()); + result.setFlowId(task.getFlowId()); + result.setFlowNodeId(task.getFlowNodeId()); + result.setTaskId(task.getTaskId()); + + JObject taskInfo = taskConfigJson.getJObject("task"); + if (taskInfo != null) { + result.setMembers(taskConfigJson.getJObject("task").getJSONList("members")); + } + + // 防止返回两个相同的结果 + if (result.getResult() != null && !temp.add(result.getResult().toJSONString()) && task.getRole() == JobMemberRole.provider + && (task.getTaskType() == ComponentType.MixStatistic + || task.getTaskType() == ComponentType.MixBinning + || task.getTaskType() == ComponentType.FillMissingValue + || task.getTaskType() == ComponentType.MixLR)) { + continue; + } + results.add(result); + } + + return success(results); + } + + public static class Input extends DetailApi.Input { + + @Check(name = "结果类型") + private String type; + + // region getter/setter + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + // endregion + } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetResultHistoryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetResultHistoryApi.java index 810b0a847..07c95f1de 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetResultHistoryApi.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/GetResultHistoryApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,14 +20,14 @@ import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.dto.entity.job.TaskResultOutputModel; import com.welab.wefe.board.service.service.TaskService; -import com.welab.wefe.board.service.util.ModelMapper; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; import java.util.ArrayList; @@ -68,7 +68,10 @@ protected ApiResult handle(Input input) throws StatusCodeWithException { result.setErrorCause(task.getErrorCause()); result.setPosition(task.getPosition()); result.setSpend(task.getSpend()); - + result.setJobId(task.getJobId()); + result.setFlowId(task.getFlowId()); + result.setFlowNodeId(task.getFlowNodeId()); + result.setTaskId(task.getTaskId()); list.add(result); } diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/SelectFeatureApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/SelectFeatureApi.java index 375002384..3cc5776da 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/SelectFeatureApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/SelectFeatureApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -18,7 +18,6 @@ import com.welab.wefe.board.service.service.TaskResultService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; @@ -28,6 +27,7 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; import java.util.List; @@ -36,7 +36,7 @@ * @author lonnie */ @Api( - path = "/flow/job/task/select", + path = "flow/job/task/select", name = "filter features", desc = "Through the passed cv/iv value and feature rate, select the features that meet the conditions" ) diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/TaskProgressDetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/TaskProgressDetailApi.java index c9652df22..7f0aa53d1 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/TaskProgressDetailApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/TaskProgressDetailApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,14 +19,14 @@ import com.welab.wefe.board.service.database.entity.job.TaskProgressMysqlModel; import com.welab.wefe.board.service.dto.entity.job.TaskProgressOuputModel; import com.welab.wefe.board.service.service.TaskProgressService; -import com.welab.wefe.board.service.util.ModelMapper; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/detail.http index 67ae103c9..49b2213e4 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/detail.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/detail.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/flow/job/task/detail +POST http://localhost:8080/board-service/flow/job/task/detail Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/getFeature.http 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/getFeature.http index b623df026..7bd6c4450 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/getFeature.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/getFeature.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/flow/job/task/feature +POST http://localhost:8080/board-service/flow/job/task/feature Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/result.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/result.http index 23f28f0b3..29a206232 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/result.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/result.http @@ -1,11 +1,11 @@ -POST {{baseUrl}}/flow/job/task/result +POST http://localhost:8080/board-service/flow/job/task/result Content-Type: application/json +token:{{token}} { - "jobId": "69e6df7d606c470a83dadf4ddd41d8d9", - "flowId": "a2e07682b5d74eaab1e0227743e6dbd2 ", - "flowNodeId": "16167496011639956", - "type": "ks" + "jobId": "e8bacfd9e1924e2ca9de019fc9efc039", + "flowNodeId": "16403319532745323", + "type": "loss" } ### \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/selectFeature.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/selectFeature.http index 7f7bbac63..0eaf5601c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/selectFeature.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/selectFeature.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/flow/job/task/select +POST 
http://localhost:8080/board-service/flow/job/task/select Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/task_progress_detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/task_progress_detail.http index 1ba048ba6..7eb4b355f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/task_progress_detail.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/task/test/task_progress_detail.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/task/progress/detail +POST http://localhost:8080/board-service/task/progress/detail Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/detail.http index 4fd794bc0..b15c1340d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/detail.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/detail.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/flow/job/detail +POST http://localhost:8080/board-service/flow/job/detail Content-Type: application/json { @@ -11,7 +11,7 @@ Content-Type: application/json ### -POST {{baseUrl}}/flow/job/detail +POST http://localhost:8080/board-service/flow/job/detail Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/getl-progress.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/getl-progress.http index 2db5b05f0..ffe59ca5b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/getl-progress.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/getl-progress.http 
@@ -1,4 +1,4 @@ -POST {{baseUrl}}/flow/job/get_progress +POST http://localhost:8080/board-service/flow/job/get_progress Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/log-download.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/log-download.http index acd3b48a4..24a5d9416 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/log-download.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/log-download.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/job/log/download +POST http://localhost:8080/board-service/job/log/download Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/preview.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/preview.http index 38e019501..1dbe9dfdb 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/preview.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/preview.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/flow/job/preview +POST http://localhost:8080/board-service/project/flow/job/preview Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/query.http index 348860855..0dd57ffcc 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/query.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/query.http @@ -1,15 +1,16 @@ -POST {{baseUrl}}/flow/job/query +POST http://localhost:8080/board-service/flow/job/query Content-Type: application/json +token:{{token}} { - "flow_id": 
"015d34e5192e4831a4d6ea321a3584a5", - "job_id": "10d421d323974914ad5a934f1c159a21" + "flow_id": "36ab34740ace48558f887e5ff77f99c1", + "project_id": "6f4caef3760a4a3da785f4f46bdd41d7" } ### -POST {{baseUrl}}/flow/job/query +POST http://localhost:8080/board-service/flow/job/query Content-Type: application/json { @@ -19,7 +20,7 @@ Content-Type: application/json ### -POST {{baseUrl}}/flow/job/query +POST http://localhost:8080/board-service/flow/job/query Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/stop.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/stop.http index 97fe815e0..6a9368d65 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/stop.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/stop.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/flow/job/stop +POST http://localhost:8080/board-service/flow/job/stop Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/view-dataset.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/view-dataset.http index f4a082650..76cfe08cf 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/view-dataset.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/job/test/view-dataset.http @@ -1,8 +1,8 @@ -POST {{baseUrl}}/job/data_set/view +POST http://localhost:8080/board-service/job/data_set/view Content-Type: application/json { - "job_id":"18bc3d0dde244035ba9a8dc16fd58fae", + "job_id": "18bc3d0dde244035ba9a8dc16fd58fae", "node_id": "16196070742596633", "member_role": "promoter" } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/AddApi.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/AddApi.java index 1cc51bbc9..a785f83c9 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/AddApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/AddApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ExitProjectApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ExitProjectApi.java index 33b0df2c4..fd21e8e44 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ExitProjectApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ExitProjectApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListAllApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListAllApi.java deleted file mode 100644 index 37a488870..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListAllApi.java +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.api.project.member; - -import com.welab.wefe.board.service.database.repository.ProjectMemberRepository; -import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.util.StringUtil; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -import java.util.Comparator; -import java.util.List; -import java.util.stream.Collectors; - -/** - * @author zane.luo - */ -@Api(path = "project/member/all", name = "Get a list of all the members who work with me") -public class ListAllApi extends AbstractApi { - - @Autowired - private ProjectMemberRepository projectMemberRepository; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - List list = projectMemberRepository.listAllMemberId(); - - List output = list - .parallelStream() - .map(x -> new Member(x, CacheObjects.getMemberName(x))) - .filter(x -> StringUtil.isNotEmpty(x.memberName)) - .sorted(Comparator.comparing(x -> x.memberName == null ? 
"" : x.memberName)) - .collect(Collectors.toList()); - - return success(new Output(output)); - } - - public static class Input extends AbstractApiInput { - - } - - public static class Output { - private List list; - - public Output(List list) { - this.list = list; - } - - public List getList() { - return list; - } - - public void setList(List list) { - this.list = list; - } - } - - public static class Member { - public String memberId; - public String memberName; - - public Member() { - } - - public Member(String memberId, String memberName) { - this.memberId = memberId; - this.memberName = memberName; - } - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListApi.java deleted file mode 100644 index f5282f835..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListApi.java +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.api.project.member; - -import com.welab.wefe.board.service.database.entity.job.ProjectMemberMySqlModel; -import com.welab.wefe.board.service.dto.entity.project.ProjectMemberOutputModel; -import com.welab.wefe.board.service.service.ProjectMemberService; -import com.welab.wefe.board.service.util.ModelMapper; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -import java.util.List; -import java.util.stream.Collectors; - -/** - * @author zane.luo - */ -@Api(path = "project/member/list", name = "Get the list of members in the project") -public class ListApi extends AbstractApi { - - @Autowired - private ProjectMemberService projectMemberService; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - List list = projectMemberService.findList(input); - - List output = list - .parallelStream() - .map(x -> ModelMapper.map(x, ProjectMemberOutputModel.class)) - .collect(Collectors.toList()); - - return success(new Output(output)); - } - - public static class Input extends AbstractApiInput { - - public Input() { - - } - - public Input(String projectId) { - this.projectId = projectId; - } - - @Check(name = "项目Id", require = true) - private String projectId; - - private String ootJobId; - - //region getter/setter - - public String getProjectId() { - return projectId; - } - - public void setProjectId(String projectId) { - this.projectId = projectId; - } - - public String getOotJobId() { - return ootJobId; - } - - public void setOotJobId(String ootJobId) { - this.ootJobId = ootJobId; - } - //endregion - } - - public static class Output { - private List 
list; - - public Output(List list) { - this.list = list; - } - - public List getList() { - return list; - } - - public void setList(List list) { - this.list = list; - } - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListInAllProjectApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListInAllProjectApi.java new file mode 100644 index 000000000..26ac498d9 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListInAllProjectApi.java @@ -0,0 +1,88 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.project.member; + +import com.welab.wefe.board.service.database.repository.ProjectMemberRepository; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +/** + * @author zane.luo + */ +@Api(path = "project/member/all", name = "Get a list of all the members who work with me") +public class ListInAllProjectApi extends AbstractApi { + + @Autowired + private ProjectMemberRepository projectMemberRepository; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + List list = projectMemberRepository.listAllMemberId(); + + List output = list + .parallelStream() + .map(x -> new Member(x, CacheObjects.getMemberName(x))) + .filter(x -> StringUtil.isNotEmpty(x.memberName)) + .sorted(Comparator.comparing(x -> x.memberName == null ? 
"" : x.memberName)) + .collect(Collectors.toList()); + + return success(new Output(output)); + } + + public static class Input extends AbstractApiInput { + + } + + public static class Output { + private List list; + + public Output(List list) { + this.list = list; + } + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + } + + public static class Member { + public String memberId; + public String memberName; + + public Member() { + } + + public Member(String memberId, String memberName) { + this.memberId = memberId; + this.memberName = memberName; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListInProjectApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListInProjectApi.java new file mode 100644 index 000000000..f7089213e --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/ListInProjectApi.java @@ -0,0 +1,111 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.project.member; + +import com.welab.wefe.board.service.database.entity.job.ProjectMemberMySqlModel; +import com.welab.wefe.board.service.dto.entity.project.ProjectMemberOutputModel; +import com.welab.wefe.board.service.service.ProjectMemberService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * @author zane.luo + */ +@Api(path = "project/member/list", name = "Get the list of members in the project") +public class ListInProjectApi extends AbstractApi { + + @Autowired + private ProjectMemberService projectMemberService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + List list = projectMemberService.findList(input); + + List output = list + .parallelStream() + .map(x -> ModelMapper.map(x, ProjectMemberOutputModel.class)) + .collect(Collectors.toList()); + + return success(new Output(output)); + } + + public static class Input extends AbstractApiInput { + + @Check(name = "项目Id", require = true) + private String projectId; + + private String ootJobId; + + public Input() { + } + + public Input(String projectId) { + this.projectId = projectId; + } + + //region getter/setter + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public String getOotJobId() { + return ootJobId; + } + + public void setOotJobId(String ootJobId) { + this.ootJobId = ootJobId; + } + //endregion + } + + public static class Output { + private List 
list; + + public Output() { + } + + public Output(List list) { + this.list = list; + } + + // region getter/setter + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + + // endregion + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/OnlineCheckApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/OnlineCheckApi.java index 4c987f367..9c19df33e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/OnlineCheckApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/OnlineCheckApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/RemoveApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/RemoveApi.java index be91746fb..a4c038188 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/RemoveApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/RemoveApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,13 +17,13 @@ package com.welab.wefe.board.service.api.project.member; import com.welab.wefe.board.service.service.ProjectService; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractNoneOutputApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/AuditApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/AuditApi.java index be07f0f54..0d9b6941f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/AuditApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/AuditApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,7 +18,6 @@ import com.welab.wefe.board.service.service.ProjectMemberAuditService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.AuditStatus; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.StringUtil; @@ -26,6 +25,7 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.AuditStatus; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/ListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/ListApi.java deleted file mode 100644 index 7feae7c44..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/ListApi.java +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.project.member.audit; - -import com.welab.wefe.board.service.database.entity.job.ProjectMemberAuditMySqlModel; -import com.welab.wefe.board.service.dto.entity.ProjectMemberAuditOutput; -import com.welab.wefe.board.service.service.ProjectMemberAuditService; -import com.welab.wefe.board.service.util.ModelMapper; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -import java.util.List; -import java.util.stream.Collectors; - -/** - * @author zane.luo - */ -@Api(path = "project/member/add/audit/list", name = "Get the review status of new members in the project") -public class ListApi extends AbstractApi { - - @Autowired - private ProjectMemberAuditService projectMemberAuditService; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - List list = projectMemberAuditService.listAll(input.projectId, input.memberId); - - List output = list - .parallelStream() - .map(x -> ModelMapper.map(x, ProjectMemberAuditOutput.class)) - .collect(Collectors.toList()); - - return success(new Output(output)); - } - - public static class Input extends AbstractApiInput { - @Check(name = "项目Id", require = true) - private String projectId; - - @Check(name = "成员Id", desc = "当成员 Id 为空时查所有成员") - private String memberId; - - - //region getter/setter - - public String getMemberId() { - return memberId; - } - - public void setMemberId(String memberId) { - this.memberId = memberId; - } - - public String getProjectId() { - return projectId; - } - - public void 
setProjectId(String projectId) { - this.projectId = projectId; - } - - - //endregion - } - - public static class Output { - private List list; - - public Output(List list) { - this.list = list; - } - - public List getList() { - return list; - } - - public void setList(List list) { - this.list = list; - } - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/ProjectMemberAuditListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/ProjectMemberAuditListApi.java new file mode 100644 index 000000000..386dbee77 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/ProjectMemberAuditListApi.java @@ -0,0 +1,100 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.project.member.audit; + +import com.welab.wefe.board.service.database.entity.job.ProjectMemberAuditMySqlModel; +import com.welab.wefe.board.service.dto.entity.ProjectMemberAuditOutput; +import com.welab.wefe.board.service.service.ProjectMemberAuditService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * @author zane.luo + */ +@Api(path = "project/member/add/audit/list", name = "Get the review status of new members in the project") +public class ProjectMemberAuditListApi extends AbstractApi { + + @Autowired + private ProjectMemberAuditService projectMemberAuditService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + List list = projectMemberAuditService.listAll(input.projectId, input.memberId); + + List output = list + .parallelStream() + .map(x -> ModelMapper.map(x, ProjectMemberAuditOutput.class)) + .collect(Collectors.toList()); + + return success(new Output(output)); + } + + public static class Input extends AbstractApiInput { + @Check(name = "项目Id", require = true) + private String projectId; + + @Check(name = "成员Id", desc = "当成员 Id 为空时查所有成员") + private String memberId; + + + //region getter/setter + + public String getMemberId() { + return memberId; + } + + public void setMemberId(String memberId) { + this.memberId = memberId; + } + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + + //endregion + } + + 
public static class Output { + private List list; + + public Output(List list) { + this.list = list; + } + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/test/list.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/test/list.http index b3d7828e9..f80d27733 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/test/list.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/audit/test/list.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/member/add/audit/list +POST http://localhost:8080/board-service/project/member/add/audit/list Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/add.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/add.http index dfa5994a3..c9503e6a6 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/add.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/add.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/member/add +POST http://localhost:8080/board-service/project/member/add Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/all.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/all.http index 1adbfdb66..3897ad6cd 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/all.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/all.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/member/all +POST 
http://localhost:8080/board-service/project/member/all Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/list.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/list.http index 0babbfef2..5e4f1efa1 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/list.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/list.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/member/list +POST http://localhost:8080/board-service/project/member/list Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/onlineCheck.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/onlineCheck.http index ccf0c1a05..f825642d5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/onlineCheck.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/onlineCheck.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/member/online_check +POST http://localhost:8080/board-service/project/member/online_check Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/remove.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/remove.http index 4b0a82796..603f4fe94 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/remove.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/member/test/remove.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/member/remove +POST http://localhost:8080/board-service/project/member/remove Content-Type: application/json { diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/DetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/DetailApi.java index e47600051..990365c70 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/DetailApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/DetailApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/QueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/QueryApi.java index 1d3f76106..34fc53d4b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/QueryApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/QueryApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,12 +20,12 @@ import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.dto.entity.modeling_config.ModelingInfoOutputModel; import com.welab.wefe.board.service.service.ProjectFlowService; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.ComponentType; import org.springframework.beans.factory.annotation.Autowired; /** @@ -52,7 +52,7 @@ public Input(String projectId) { this.projectId = projectId; } - @Check(name = "项目id", require = true) + @Check(name = "项目id") private String projectId; @Check(name = "任务id") diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/test/detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/test/detail.http index a82b26d07..86a0328b5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/test/detail.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/test/detail.http @@ -1,10 +1,10 @@ -POST {{baseUrl}}/project/modeling/detail +POST http://localhost:8080/board-service/project/modeling/detail Content-Type: application/json { "jobId": "ac44d921de1146ac947ccfa862cdc632", - "flowNodeId":"16158855548416030", - 
"taskId":"ac44d921de1146ac947ccfa862cdc632_HorzLR_16158855548416030", + "flowNodeId": "16158855548416030", + "taskId": "ac44d921de1146ac947ccfa862cdc632_HorzLR_16158855548416030", "type": "ks" } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/test/query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/test/query.http index ba8da4e67..417006eb8 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/test/query.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/modeling/test/query.http @@ -1,9 +1,17 @@ -POST {{baseUrl}}/project/modeling/query +POST http://localhost:8080/board-service/project/modeling/query Content-Type: application/json { "projectId": "ac6c60a15875487f8b32864512920653", - "componentType":"HorzLR" + "componentType": "HorzLR" } ### +POST http://localhost:8080/board-service/project/modeling/query +Content-Type: application/json +token: {{token}} + +{ + "projectId": "35ef81d771e24bd09c77a432652061b3", + "job_id": "eded9797545a4cfa93faa5e2af3b30fc" +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/CheckExistEvaluationComponentApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/CheckExistEvaluationComponentApi.java index 621e32a97..6a234040e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/CheckExistEvaluationComponentApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/CheckExistEvaluationComponentApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,6 +18,7 @@ import com.welab.wefe.board.service.service.ProjectFlowNodeService; import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; @@ -52,22 +53,14 @@ protected ApiResult handle(Input input) throws StatusCodeWithException { } public static class Input extends AbstractApiInput { - /** - * This parameter is used in non OOT mode - */ + @Check(desc = "This parameter is used in non OOT mode") private String flowId; - /** - * The OOT component ID on the canvas (mainly used to find the front node and the OOT node on the canvas. This parameter is used in non OOT mode) - */ + @Check(desc = "The OOT component ID on the canvas (mainly used to find the front node and the OOT node on the canvas. 
This parameter is used in non OOT mode)") private String nodeId; - /** - * Original model job ID (this parameter is used in OOT mode) - */ + @Check(desc = "Original model job ID (this parameter is used in OOT mode)") private String jobId; - /** - * Original model node ID (this parameter is used in OOT mode) - */ + @Check(desc = "Original model node ID (this parameter is used in OOT mode)") private String modelNodeId; public String getFlowId() { @@ -104,9 +97,7 @@ public void setModelNodeId(String modelNodeId) { } public static class Output extends AbstractApiOutput { - /** - * check result - */ + @Check(name = "check result") private boolean checkResult; public boolean isCheckResult() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/DetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/DetailApi.java index aaad178b3..b39b5940e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/DetailApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/DetailApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,16 +16,19 @@ package com.welab.wefe.board.service.api.project.node; +import com.welab.wefe.board.service.component.deep_learning.ImageDataIOComponent; import com.welab.wefe.board.service.database.entity.job.ProjectFlowNodeMySqlModel; import com.welab.wefe.board.service.dto.entity.job.ProjectFlowNodeOutputModel; import com.welab.wefe.board.service.service.ProjectFlowNodeService; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.ComponentType; import org.springframework.beans.factory.annotation.Autowired; /** @@ -44,7 +47,20 @@ protected ApiResult handle(Input input) throws Statu if (one == null) { return success(); } - return success(ModelMapper.map(one, ProjectFlowNodeOutputModel.class)); + + ProjectFlowNodeOutputModel output = ModelMapper.map(one, ProjectFlowNodeOutputModel.class); + output.setParams(one.getParams()); + + // ImageDataIO 节点顺带输出数据集信息。 + if (one.getComponentType() == ComponentType.ImageDataIO) { + if (output.getParams() != null) { + ImageDataIOComponent.Params params = output.getParams().toJavaObject(ImageDataIOComponent.Params.class); + params.fillDataSetDetail(); + output.setParams(JObject.create(params).toString()); + } + } + + return success(output); } public static class Input extends AbstractApiInput { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/UpdateApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/UpdateApi.java index fd3cd489c..58e82e3ea 100644 --- 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/UpdateApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/UpdateApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,15 +16,16 @@ package com.welab.wefe.board.service.api.project.node; +import com.welab.wefe.board.service.component.Components; import com.welab.wefe.board.service.dto.entity.job.ProjectFlowNodeOutputModel; import com.welab.wefe.board.service.service.ProjectFlowNodeService; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.ComponentType; import org.springframework.beans.factory.annotation.Autowired; import java.util.List; @@ -57,6 +58,15 @@ public static class Input extends AbstractApiInput { @Check(name = "组件参数", require = true, blockXss = false) private String params; + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + + // 对表单有效性进行检查 + Components + .get(componentType) + .deserializationParam(params); + } //region getter/setter diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/test/detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/test/detail.http index f7135f445..fcf978516 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/test/detail.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/test/detail.http @@ -1,13 +1,14 @@ -POST {{baseUrl}}/project/flow/node/detail +POST http://localhost:8080/board-service/project/flow/node/detail Content-Type: application/json +token: {{token}} { - "flow_id": "a9632fcc436f4400bba1726532475f7b", - "nodeId": "1608695123123294703624" + "flow_id": "27f72edb3e7e4ac698dea3fe238b735d", + "nodeId": "16196070742596633" } ### -GET {{baseUrl}}/project/flow/node/detail?nodeId=16088884188236603&flow_id=a9632fcc436f4400bba1726532475f7b \ No newline at end of file +GET http://localhost:8080/board-service/project/flow/node/detail?nodeId=16088884188236603&flow_id=a9632fcc436f4400bba1726532475f7b \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/test/update.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/test/update.http index 3051b82d6..cad7b9ad9 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/test/update.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/node/test/update.http @@ -1,14 +1,80 @@ -POST {{baseUrl}}/project/flow/node/update +POST http://localhost:8080/board-service/project/flow/node/update Content-Type: application/json - +token: {{token}} { - "nodeId": "16100034352598215", - "componentType": "Intersection", - "flow_id": "6b25864fa5c84be0af83ddf4a12e78bb", + "nodeId": "16196070742596633", + "componentType": "DataIO", + "flowId": "27f72edb3e7e4ac698dea3fe238b735d", "params": { - "intersect_method": "dh", - 
"save_dataset": true + "dataset_list": [ + { + "member_id": "81051c7b7f4f4beaa581ba046f8a6048", + "member_role": "promoter", + "data_set_id": "88047fb892a44e3d8496e4c21a1f9432", + "features": [ + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7", + "x8", + "x9", + "x10", + "x11", + "x12", + "x13", + "x14", + "x15", + "x16", + "x17", + "x18", + "x19", + "x20" + ], + "feature_name_list": "x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15,x16,x17,x18,x19,x20", + "feature_count": 20, + "contains_y": true, + "derived_from": null, + "row_count": 500, + "name": "500-y03" + }, + { + "member_id": "a94d7a71189c44b5a298ad4ed2668f67", + "member_role": "provider", + "data_set_id": "f884fe77643a4d5995c0f29eed71ff9a", + "features": [ + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7", + "x8", + "x9", + "x10", + "x11", + "x12", + "x13", + "x14", + "x15", + "x16", + "x17", + "x18", + "x19", + "x20" + ], + "feature_name_list": "x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15,x16,x17,x18,x19,x20", + "feature_count": 20, + "contains_y": false, + "derived_from": null, + "row_count": 500, + "name": "500-无-y" + } + ] } } @@ -16,51 +82,93 @@ Content-Type: application/json ### 自动创建 -GET {{baseUrl}}/project/flow/node/update?nodeId=16088878120465932&componentType=FeatureStatistic&flow_id=a9632fcc436f4400bba1726532475f7b +GET http://localhost:8080/board-service/project/flow/node/update?nodeId=16088878120465932&componentType=FeatureStatistic&flow_id=a9632fcc436f4400bba1726532475f7b ### debug -POST {{baseUrl}}/project/flow/node/update +POST http://localhost:8080/board-service/project/flow/node/update Content-Type: application/json - +token:{{token}} { - "nodeId": "16088893089992860", - "componentType": "FeatureStatistic", - "flow_id": "a9632fcc436f4400bba1726532475f7b", + "nodeId": "16195987735668056", + "componentType": "DataIO", + "flowId": "fd4606478e734212955ea8dc3f8bcf70", "params": { - "featureMethods": [ - { - "name": "Max", - "value": "" - }, - { - "name": "Min", - "value": 
"" - }, - { - "name": "Percentile", - "value": 50 - } - ], - "members": [ + "dataset_list": [ { - "memberId": "682d087e965a4c5d8a2807a772a571c1", + "member_id": "087973c99d26410683944bf3f46c8635", + "member_role": "promoter", + "data_set_id": "037aa3c2d6134de4868129a0397a2ed2", "features": [ + "x0", + "x1", + "x2", + "x3", "x4", "x5", "x6", - "x7" - ] + "x7", + "x8", + "x9", + "x10", + "x11", + "x12", + "x13", + "x14", + "x15", + "x16", + "x17", + "x18", + "x19", + "x20", + "x21", + "x22", + "x23", + "x24", + "x25", + "x26", + "x27", + "x28" + ], + "feature_name_list": "x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15,x16,x17,x18,x19,x20,x21,x22,x23,x24,x25,x26,x27,x28", + "feature_count": 29, + "contains_y": true, + "row_count": 11851, + "name": "阳光保险" }, { - "memberId": "269330b21e1342b48124303e55c75ce0", + "member_id": "727c393686b74eb680a0ed8ba539f58a", + "member_role": "provider", + "data_set_id": "79a33d5dfc2e4dc7a452da928d5aa259", "features": [ - "id", - "y", - "x1" - ] + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7", + "x8", + "x9", + "x10", + "x11", + "x12", + "x13", + "x14", + "x15", + "x16", + "x17", + "x18", + "x19" + ], + "feature_name_list": "x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15,x16,x17,x18,x19", + "feature_count": 20, + "contains_y": false, + "row_count": 569, + "name": "569-noy" } ] } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/AddApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/AddApi.java index 70f71f5f4..ddbe33d8a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/AddApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/AddApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -23,7 +23,6 @@ import com.welab.wefe.board.service.service.CacheObjects; import com.welab.wefe.board.service.service.ProjectService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.Launcher; @@ -31,6 +30,8 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectType; import org.apache.commons.collections4.CollectionUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -70,14 +71,17 @@ public void setProjectId(String projectId) { public static class Input extends AbstractApiInput { - @Check(name = "业务层面的项目ID", hiddenForFrontEnd = true) + @Check(name = "业务层面的项目ID", donotShow = true) private String projectId; - @Check(name = "所有成员列表", hiddenForFrontEnd = true) + @Check(name = "所有成员列表", donotShow = true) private List members; @Check(name = "项目名称", require = true) private String name; + @Check(name = "项目类型", require = true) + private ProjectType projectType; + @Check(name = "项目描述", require = true) private String desc; @@ -93,14 +97,13 @@ public static class Input extends AbstractApiInput { @Check(name = "角色") private JobMemberRole role; - @Override public 
void checkAndStandardize() throws StatusCodeWithException { super.checkAndStandardize(); // Project name cannot be repeated if (!super.fromGateway()) { - List allByName = Launcher.CONTEXT.getBean(ProjectRepository.class).findAllByName(name); + List allByName = Launcher.getBean(ProjectRepository.class).findAllByName(name); if (!allByName.isEmpty()) { StatusCode.PARAMETER_VALUE_INVALID.throwException( "这个项目名称已经被用过了哟~ 再想一个吧~" @@ -200,6 +203,14 @@ public void setRole(JobMemberRole role) { this.role = role; } + public ProjectType getProjectType() { + return projectType; + } + + public void setProjectType(ProjectType projectType) { + this.projectType = projectType; + } + //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/AuditApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/AuditApi.java index c64f3ffae..178ead0f1 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/AuditApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/AuditApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -18,7 +18,6 @@ import com.welab.wefe.board.service.service.ProjectService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.AuditStatus; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.StringUtil; @@ -26,6 +25,7 @@ import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.AuditStatus; import org.springframework.beans.factory.annotation.Autowired; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/CloseProjectApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/CloseProjectApi.java index 22439ecb0..98f31cb9b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/CloseProjectApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/CloseProjectApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/CountStatisticsApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/CountStatisticsApi.java index bc76687cf..2baf38bff 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/CountStatisticsApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/CountStatisticsApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,12 +17,12 @@ package com.welab.wefe.board.service.api.project.project; import com.welab.wefe.board.service.service.ProjectService; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; import java.util.Map; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/DataInfoApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/DataInfoApi.java index bf767976f..9b409aea5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/DataInfoApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/DataInfoApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/DetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/DetailApi.java index f5973f04e..816b25b50 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/DetailApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/DetailApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/QueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/QueryApi.java index 0de24ee93..3e600dea3 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/QueryApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/QueryApi.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -20,13 +20,14 @@ import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.dto.entity.project.ProjectQueryOutputModel; import com.welab.wefe.board.service.service.ProjectService; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectType; import org.springframework.beans.factory.annotation.Autowired; /** @@ -72,6 +73,9 @@ public static class Input extends PagingInput { @Check(name = "是否已关闭") private Boolean closed; + @Check(name = "项目类型") + private ProjectType projectType; + public String getName() { return name; } @@ -143,5 +147,13 @@ public Boolean getClosed() { public void setClosed(Boolean closed) { this.closed = closed; } + + public ProjectType getProjectType() { + return projectType; + } + + public void setProjectType(ProjectType projectType) { + this.projectType = projectType; + } } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/UpdateProjectApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/UpdateProjectApi.java index b5429a353..18d9a6687 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/UpdateProjectApi.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/UpdateProjectApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -62,7 +62,7 @@ public void checkAndStandardize() throws StatusCodeWithException { // Project name cannot be repeated if (!super.fromGateway()) { - List allByName = Launcher.CONTEXT.getBean(ProjectRepository.class).findAllByName(name); + List allByName = Launcher.getBean(ProjectRepository.class).findAllByName(name); if (!allByName.isEmpty()) { if (allByName.size() > 1 || !allByName.get(0).getProjectId().equals(projectId)) { StatusCode.PARAMETER_VALUE_INVALID.throwException( diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/add.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/add.http index a9e2d23b9..b561a279d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/add.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/add.http @@ -1,43 +1,49 @@ ### -POST {{baseUrl}}/project/add +POST http://localhost:8080/board-service/project/add Content-Type: application/json - +token: {{token}} { - "name": "重构项目", - "desc": "重构项目", - "dataSetList": [ + "name": "添加datatype", + "desc": "bug重现", + "projectType": "MachineLearning", + "promoterDataSetList": [ { - "member_id": "269330b21e1342b48124303e55c75ce0", - 
"data_set_keys": ",bbb,", - "data_set_name": "kkk-b", - "member_name": "弟依偎零三", "member_role": "promoter", - "data_set_id": "a4b9810d7cb34f6997381482e48c17a2", - "contains_y": true - }, - { - "member_id": "682d087e965a4c5d8a2807a772a571c1", - "data_set_keys": ",test,", - "data_set_name": "wingo-noy", - "data_set_column_num": "20", - "member_name": "弟依偎零四", - "member_role": "provider", - "data_set_rows": "568", - "data_set_id": "d9ea082b855141378ed257111a1a07e4", - "contains_y": false + "member_id": "290007c2a71d470ba00f486b18875d31", + "data_set_id": "8a675a8f7a2e48f6bcbaece3cdfca3d6", + "data_set_type": "TableDataSet" } ], - "memberList": [ + "providerList": [ { - "member_id": "269330b21e1342b48124303e55c75ce0", - "member_role": "promoter", - "member_name": "弟依偎零三" + "member_id": "d0f47307804844898ecfc65b875abe87", + "dataSetList": [ + { + "member_role": "provider", + "member_id": "d0f47307804844898ecfc65b875abe87", + "data_set_id": "c81ad8dbe793410db43594bfc36afb11", + "data_set_type": "TableDataSet" + } + ] }, { - "member_id": "682d087e965a4c5d8a2807a772a571c1", - "member_role": "provider", - "member_name": "弟依偎零四" + "member_id": "8896e74890a5459386287ec817e8b4f3", + "dataSetList": [ + { + "member_role": "provider", + "member_id": "8896e74890a5459386287ec817e8b4f3", + "data_set_id": "eea81ef35a9e45bbb18cad5c251b8280", + "data_set_type": "TableDataSet" + }, + { + "member_role": "provider", + "member_id": "8896e74890a5459386287ec817e8b4f3", + "data_set_id": "8c7e7065785844fcaf2cbcda4d44fba8", + "data_set_type": "TableDataSet" + } + ] } - ] + ], + "promoterList": [] } \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/audit.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/audit.http index 4a3c85cf4..c67f5e632 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/audit.http +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/audit.http @@ -1,5 +1,5 @@ ### -POST {{baseUrl}}/project/add/audit +POST http://localhost:8080/board-service/project/add/audit Content-Type: application/json diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/dataInfo.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/dataInfo.http index cc618f410..a34548e98 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/dataInfo.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/dataInfo.http @@ -1,4 +1,4 @@ -POST {{baseUrl}}/project/data/info +POST http://localhost:8080/board-service/project/data/info Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/detail.http index 2931697e4..824d7da98 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/detail.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/detail.http @@ -1,7 +1,8 @@ ### -POST {{baseUrl}}/project/detail +POST http://localhost:8080/board-service/project/detail Content-Type: application/json +token: {{token}} { - "projectId": "fd1d4a0df470442bb7ffea1c81846385" + "projectId": "340a01d6107247bd9bdf00b54274effe" } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/query.http index 95c5bc4cc..1d1bd2db5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/query.http +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/query.http @@ -1,5 +1,5 @@ ### -POST {{baseUrl}}/project/query +POST http://localhost:8080/board-service/project/query Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/statistics.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/statistics.http index a920e3387..190230774 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/statistics.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/project/project/test/statistics.http @@ -1,9 +1,9 @@ ### -POST {{baseUrl}}/project/count_statistics +POST http://localhost:8080/board-service/project/count_statistics Content-Type: application/json { } ### -GET {{baseUrl}}/project/count_statistics?name=&member_id=1b750c545b494242997e0d0a054fd853&member_name=%E5%85%8B%E5%8A%B3%E5%BE%B7002&member_role=&audit_status=&start_create_time=&end_create_time=&page_index=1&page_size=20&activeTab=myProjects \ No newline at end of file +GET http://localhost:8080/board-service/project/count_statistics?name=&member_id=1b750c545b494242997e0d0a054fd853&member_name=%E5%85%8B%E5%8A%B3%E5%BE%B7002&member_role=&audit_status=&start_create_time=&end_create_time=&page_index=1&page_size=20&activeTab=myProjects \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/server/AvailableApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/server/AvailableApi.java deleted file mode 100644 index bfa9f4321..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/server/AvailableApi.java +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.server; - -import com.welab.wefe.board.service.dto.vo.ServerCheckPointOutput; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.web.api.base.AbstractNoneInputApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; - -import java.util.List; - -/** - * @author zane - */ -@Api(path = "server/available", name = "list all checkpoint in board service to show its availability.") -public class AvailableApi extends AbstractNoneInputApi { - @Override - protected ApiResult handle() throws StatusCodeWithException { - return null; - } - - public static class Output { - public boolean success; - public List list; - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/service/AliveApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/service/AliveApi.java new file mode 100644 index 000000000..386283f48 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/service/AliveApi.java @@ -0,0 +1,35 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.service; + +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractNoneInputApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.ApiResult; + +/** + * @author zane + */ +@Api(path = "service/alive", name = "", login = false) +public class AliveApi extends AbstractNoneInputApi { + + @Override + protected ApiResult handle() throws StatusCodeWithException { + return success(); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/service/ServiceAvailableApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/service/ServiceAvailableApi.java new file mode 100644 index 000000000..75d7275c6 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/service/ServiceAvailableApi.java @@ -0,0 +1,54 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.service; + +import com.welab.wefe.board.service.service.ServiceCheckService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.checkpoint.dto.ServiceAvailableCheckOutput; +import com.welab.wefe.common.wefe.enums.ServiceType; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +/** + * @author zane + */ +@Api(path = "service/available", name = "list all checkpoint in board service to show its availability.") +public class ServiceAvailableApi extends AbstractApi { + + @Autowired + private ServiceCheckService serviceCheckService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { + ServiceAvailableCheckOutput output = serviceCheckService.getServiceAvailableInfo(input.serviceType); + if (input.fromGateway()) { + output.cleanValues(); + } + return success(output); + } + + public static class Input extends AbstractApiInput { + @Check(name = "服务类型", require = true) + public ServiceType serviceType; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/service/test/available.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/service/test/available.http new file mode 100644 index 000000000..958ff51eb --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/service/test/available.http @@ -0,0 +1,9 @@ + +### +POST http://localhost:8080/board-service/service/available +Content-Type: application/json +token: {{token}} + +{ + "serviceType": "GatewayService" +} \ No newline at end of file diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/storage/PreviewDataSetApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/storage/PreviewDataSetApi.java index 6c6dd4a03..4dcde160b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/storage/PreviewDataSetApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/storage/PreviewDataSetApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,8 +17,8 @@ package com.welab.wefe.board.service.api.storage; import com.alibaba.fastjson.JSONObject; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; -import com.welab.wefe.board.service.database.repository.DataSetRepository; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; +import com.welab.wefe.board.service.database.repository.data_resource.TableDataSetRepository; import com.welab.wefe.board.service.service.DataSetStorageService; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; import com.welab.wefe.common.StatusCode; @@ -39,26 +39,26 @@ /** * @author Zane */ -@Api(path = "storage/data_set/preview", name = "View data sets in storage") +@Api(path = "storage/table_data_set/preview", name = "View data sets in storage") public class PreviewDataSetApi extends AbstractApi { @Autowired DataSetStorageService dataSetStorageService; @Autowired - DataSetRepository dataSetRepository; + TableDataSetRepository dataSetRepository; @Autowired private GlobalConfigService globalConfigService; @Override protected ApiResult handle(Input input) throws StatusCodeWithException { - DataSetMysqlModel model = dataSetRepository.findById(input.getId()).orElse(null); + TableDataSetMysqlModel model = dataSetRepository.findById(input.getId()).orElse(null); if (model == null) { return success(); } List columns = StringUtil.splitWithoutEmptyItem(model.getColumnNameList(), ","); List> rows; - if (model.getSourceType() == null) { - rows = dataSetStorageService.previewDataSet(model.getNamespace(), model.getTableName(), 100); + if (!model.isDerivedResource()) { + rows = dataSetStorageService.previewDataSet(model.getStorageNamespace(), model.getStorageResourceName(), 100); } else { rows = getRowsFromFlow(model); } @@ -82,8 +82,13 @@ protected ApiResult handle(Input input) throws StatusCodeWithException { /** * View the data of the derived data set from flow service */ - private List> 
getRowsFromFlow(DataSetMysqlModel model) throws StatusCodeWithException { - String url = globalConfigService.getFlowConfig().intranetBaseUri + String.format("/data_set/view?table_name=%s&table_namespace=%s", model.getTableName(), model.getNamespace()); + private List> getRowsFromFlow(TableDataSetMysqlModel model) throws StatusCodeWithException { + String url = globalConfigService.getFlowConfig().intranetBaseUri + + String.format( + "/data_set/view?table_name=%s&table_namespace=%s", + model.getStorageResourceName(), + model.getStorageNamespace() + ); HttpResponse response = HttpRequest .create(url) diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/storage/test/storage-data_set-priview-.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/storage/test/storage-data_set-priview-.http index c398c3514..1c12c2c9d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/storage/test/storage-data_set-priview-.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/storage/test/storage-data_set-priview-.http @@ -1,6 +1,6 @@ ### 预览数据集中的数据 -POST {{baseUrl}}/storage/data_set/preview +POST http://localhost:8080/board-service/storage/data_set/preview Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/AbstractThroughUnionApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/AbstractThroughUnionApi.java new file mode 100644 index 000000000..128b209ff --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/AbstractThroughUnionApi.java @@ -0,0 +1,46 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.union; + + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.sdk.union.UnionService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.dto.NoneApiInput; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * 继承此类的接口不声明具体的入参和响应参数类型,尽数透传到 union,board 中不再做任何逻辑。 + * + * @author Zane + */ +public abstract class AbstractThroughUnionApi extends AbstractApi { + + @Autowired + private UnionService unionService; + + protected abstract String api(); + + @Override + protected ApiResult handle(NoneApiInput input) throws StatusCodeWithException { + JSONObject response = unionService.request(api(), input.rawRequestParams); + return super.unionApiResultToBoardApiResult(response); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/DataSetDetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/DataSetDetailApi.java deleted file mode 100644 index c30f1170f..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/DataSetDetailApi.java +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.union; - -import com.welab.wefe.board.service.dto.entity.data_set.DataSetOutputModel; -import com.welab.wefe.board.service.sdk.UnionService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author Zane - */ -@Api(path = "union/data_set/detail", name = "Get data set details from union") -public class DataSetDetailApi extends AbstractApi { - - @Autowired - UnionService unionService; - - @Override - protected ApiResult handle(DataSetDetailApi.Input input) throws StatusCodeWithException { - - return success(unionService.queryDataSetDetail(input.getId())); - } - - public static class Input extends AbstractApiInput { - @Check(name = "数据集 Id", require = true) - private String id; - - //region getter/setter - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - - //endregion - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/DataSetTagListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/DataSetTagListApi.java deleted file mode 100644 index 4007c91c7..000000000 --- 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/DataSetTagListApi.java +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.union; - -import com.alibaba.fastjson.JSONObject; -import com.welab.wefe.board.service.dto.base.PagingInput; -import com.welab.wefe.board.service.sdk.UnionService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author Zane - */ -@Api(path = "union/data_set/tag/query", name = "Query the tags of the data set from the union") -public class DataSetTagListApi extends AbstractApi { - - @Autowired - UnionService unionService; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - JSONObject result = unionService.queryDataSetTags(input); - return unionApiResultToBoardApiResult(result); - } - - public static class Input extends PagingInput { - @Check(name = "tag 名称") - private String tag; - - //region getter/setter - - public String getTag() { - return tag; - } - - public void setTag(String tag) { - this.tag = tag; - } - - - //endregion - - } -} 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/GetMemberMapApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/GetMemberMapApi.java new file mode 100644 index 000000000..1e094b0d4 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/GetMemberMapApi.java @@ -0,0 +1,30 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.union; + +import com.welab.wefe.common.web.api.base.Api; + +/** + * @author zane + * @date 2021/12/27 + */ +@Api(path = "union/member/map", name = "获取全量的成员信息列表") +public class GetMemberMapApi extends AbstractThroughUnionApi { + @Override + protected String api() { + return "member/map"; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/MemberListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/MemberListApi.java index a99e48532..14139f5df 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/MemberListApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/MemberListApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,7 +18,7 @@ import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.dto.base.PagingInput; -import com.welab.wefe.board.service.sdk.UnionService; +import com.welab.wefe.board.service.sdk.union.UnionService; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.api.base.AbstractApi; @@ -30,13 +30,13 @@ * @author Zane */ @Api(path = "union/member/query", name = "Query members from union") -public class MemberListApi extends AbstractApi { +public class MemberListApi extends AbstractApi { @Autowired - UnionService unionService; + private UnionService unionService; @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { + protected ApiResult handle(Input input) throws StatusCodeWithException { JSONObject result = unionService.queryMembers(input); return unionApiResultToBoardApiResult(result); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/QueryDataSetApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/QueryDataSetApi.java deleted file mode 100644 index ca18be823..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/QueryDataSetApi.java +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.union; - -import com.alibaba.fastjson.JSONArray; -import com.alibaba.fastjson.JSONObject; -import com.welab.wefe.board.service.dto.base.PagingInput; -import com.welab.wefe.board.service.sdk.UnionService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -import java.util.ArrayList; - -/** - * @author Zane - */ -@Api(path = "union/data_set/query", name = "Query data set from union") -public class QueryDataSetApi extends AbstractApi { - - @Autowired - UnionService unionService; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - JSONObject result = unionService.queryDataSets(input); - - JSONObject data = result.getJSONObject("data"); - - // Rename the fields of the data from the union so that the front end can access it uniformly. 
- JSONArray list = data.getJSONArray("list"); - if (list == null) { - data.put("list", new ArrayList<>()); - } - - return unionApiResultToBoardApiResult(result); - } - - public static class Input extends PagingInput { - - private String id; - - @Check(name = "数据集名称") - private String name; - - @Check(name = "标签名称") - private String tag; - - private String memberId; - - private Boolean containsY; - - - //region getter/setter - - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public String getTag() { - return tag; - } - - public void setTag(String tag) { - this.tag = tag; - } - - public String getMemberId() { - return memberId; - } - - public void setMemberId(String memberId) { - this.memberId = memberId; - } - - public Boolean getContainsY() { - return containsY; - } - - public void setContainsY(Boolean containsY) { - this.containsY = containsY; - } - - - //endregion - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/SendForgetPasswordSmsCodeApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/SendForgetPasswordSmsCodeApi.java deleted file mode 100644 index 756d14f4f..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/SendForgetPasswordSmsCodeApi.java +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.api.union; - -import com.welab.wefe.board.service.sdk.UnionService; -import com.welab.wefe.common.enums.SmsBusinessType; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; -import com.welab.wefe.common.web.dto.NoneApiOutput; -import org.springframework.beans.factory.annotation.Autowired; - -import java.io.IOException; - -/** - * @author aaron.li - * @date 2021/11/11 09:45 - **/ -@Api(path = "union/send_forget_password_sms_code", name = "send sms verification code", login = false) -public class SendForgetPasswordSmsCodeApi extends AbstractApi { - @Autowired - private UnionService unionService; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { - unionService.sendVerificationCode(input.phoneNumber, SmsBusinessType.AccountForgetPasswordVerificationCode); - return success(); - } - - public static class Input extends AbstractApiInput { - @Check(require = true) - private String phoneNumber; - - public String getPhoneNumber() { - return phoneNumber; - } - - public void setPhoneNumber(String phoneNumber) { - this.phoneNumber = phoneNumber; - } - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/SendForgetPasswordVerificationCodeApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/SendForgetPasswordVerificationCodeApi.java new file mode 100644 index 000000000..75c185d4c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/SendForgetPasswordVerificationCodeApi.java @@ -0,0 +1,66 @@ +/** 
+ * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.union; + +import com.welab.wefe.board.service.service.verificationcode.VerificationCodeService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.dto.NoneApiOutput; +import com.welab.wefe.common.wefe.enums.VerificationCodeBusinessType; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +/** + * Send forget password verification code + * + *

+ * Due to problems left over by history, the interface for sending verification code is temporarily placed here + *

+ * + * @author aaron.li + * @date 2021/11/11 09:45 + **/ +@Api(path = "union/send_forget_password_sms_code", name = "send sms verification code", login = false) +public class SendForgetPasswordVerificationCodeApi extends AbstractApi { + + @Autowired + private VerificationCodeService verificationCodeService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { + verificationCodeService.send(input.phoneNumber, VerificationCodeBusinessType.accountForgetPassword); + return success(); + } + + public static class Input extends AbstractApiInput { + @Check(require = true) + private String phoneNumber; + + public String getPhoneNumber() { + return phoneNumber; + } + + public void setPhoneNumber(String phoneNumber) { + this.phoneNumber = phoneNumber; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/TagListApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/TagListApi.java deleted file mode 100644 index 68f4e977b..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/TagListApi.java +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.api.union; - -import com.alibaba.fastjson.JSONObject; -import com.welab.wefe.board.service.dto.base.PagingInput; -import com.welab.wefe.board.service.sdk.UnionService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author lonnie - */ -@Api(path = "union/tag/query", name = "Get the default dataset tags from union") -public class TagListApi extends AbstractApi { - - @Autowired - UnionService unionService; - - @Override - protected ApiResult handle(Input input) throws StatusCodeWithException { - JSONObject result = unionService.queryTags(input); - return unionApiResultToBoardApiResult(result); - } - - public static class Input extends PagingInput { - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/UnionOnlineCheckApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/UnionOnlineCheckApi.java deleted file mode 100644 index 5333da09e..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/UnionOnlineCheckApi.java +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.api.union; - -import com.alibaba.fastjson.JSONObject; -import com.welab.wefe.board.service.constant.Config; -import com.welab.wefe.board.service.sdk.UnionService; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.web.api.base.AbstractNoneInputApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import org.springframework.beans.factory.annotation.Autowired; - -/** - * @author lonnie - */ -@Api(path = "union/online/check", name = "Check the access status of the union") -public class UnionOnlineCheckApi extends AbstractNoneInputApi { - - @Autowired - private UnionService unionService; - - @Autowired - private Config config; - - @Override - protected ApiResult handle() throws StatusCodeWithException { - - OutPut outPut = new OutPut(); - outPut.setUnionUrl(config.getUNION_BASE_URL()); - - try { - - JSONObject unionResult = unionService.queryMember(0, 10); - int code = unionResult.getInteger("code"); - - outPut.setStatus(code == 0); - } catch (StatusCodeWithException e) { - outPut.setStatus(false); - return success(outPut); - } - - return success(outPut); - } - - public static class OutPut { - - private String unionUrl; - - private boolean status; - - public String getUnionUrl() { - return unionUrl; - } - - public void setUnionUrl(String unionUrl) { - this.unionUrl = unionUrl; - } - - public boolean isStatus() { - return status; - } - - public void setStatus(boolean status) { - this.status = status; - } - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/DataResourceDetailApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/DataResourceDetailApi.java new file mode 100644 index 000000000..8e36b63a2 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/DataResourceDetailApi.java @@ 
-0,0 +1,54 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.union.data_resource; + + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.sdk.union.UnionService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +/** + * @author Zane + */ +@Api(path = "union/data_resource/detail", name = "") +public class DataResourceDetailApi extends AbstractApi { + + @Autowired + private UnionService unionService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { + JSONObject output = unionService.getDataResourceDetail(input.dataResourceId, input.dataResourceType, JSONObject.class); + return success(output); + } + + public static class Input extends AbstractApiInput { + @Check(name = "资源id", require = true) + public String dataResourceId; + @Check(name = "资源类型", require = true) + public DataResourceType dataResourceType; + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/DataResourceQueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/DataResourceQueryApi.java new file mode 100644 index 000000000..70ab2b43b --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/DataResourceQueryApi.java @@ -0,0 +1,68 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.api.union.data_resource; + +import com.alibaba.fastjson.JSONArray; +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.sdk.union.UnionService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +/** + * @author zane + * @date 2021/12/17 + */ +@Api(path = "union/data_resource/query", name = "query data resource from union service") +public class DataResourceQueryApi extends AbstractApi { + + @Autowired + private UnionService unionService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException, IOException { + JSONObject result = unionService.request( + "data_resource/query", + input.rawRequestParams + ); + + JSONObject data = result.getJSONObject("data"); + if (data != null) { + JSONArray list = data.getJSONArray("list"); + if (list != null) { + for (int i = 0; i < list.size(); i++) { + JSONObject item = list.getJSONObject(i); + JSONObject extraData = item.getJSONObject("extra_data"); + if (extraData != null) { + item.putAll(extraData); + item.remove("extra_data"); + } + } + } + } + + return super.unionApiResultToBoardApiResult(result); + } + + public static class Input extends AbstractApiInput { + + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/QueryDefaultTagApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/QueryDefaultTagApi.java new file mode 100644 index 000000000..afb17db99 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/QueryDefaultTagApi.java @@ -0,0 +1,33 @@ +/* + * 
Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.union.data_resource.tag; + +import com.welab.wefe.board.service.api.union.AbstractThroughUnionApi; +import com.welab.wefe.common.web.api.base.Api; + +/** + * @author zane + * @date 2021/12/17 + */ +@Api(path = "union/data_resource/default_tag/query", name = "") +public class QueryDefaultTagApi extends AbstractThroughUnionApi { + private static final String API = "data_resource/default_tag/query"; + + @Override + protected String api() { + return API; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/QueryTagApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/QueryTagApi.java new file mode 100644 index 000000000..3e741a791 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/QueryTagApi.java @@ -0,0 +1,33 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.union.data_resource.tag; + +import com.welab.wefe.board.service.api.union.AbstractThroughUnionApi; +import com.welab.wefe.common.web.api.base.Api; + +/** + * @author zane + * @date 2021/12/17 + */ +@Api(path = "union/data_resource/tags/query", name = "") +public class QueryTagApi extends AbstractThroughUnionApi { + private static final String API = "data_resource/tags/query"; + + @Override + protected String api() { + return API; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/test/query-default-tag.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/test/query-default-tag.http new file mode 100644 index 000000000..89f44c126 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/test/query-default-tag.http @@ -0,0 +1,10 @@ + +### 查询全部 tag +POST http://localhost:8080/board-service/union/data_resource/default_tag/query +Content-Type: application/json +token:{{token}} + +{ + "dataResourceType": "TableDataSet" +} + diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/test/query-tag.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/test/query-tag.http new file mode 100644 index 000000000..ba6a28299 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/tag/test/query-tag.http @@ -0,0 +1,8 @@ + 
+### 查询全部 tag +POST http://localhost:8080/board-service/union/data_resource/tags/query +Content-Type: application/json +token:{{token}} + +{} + diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/test/detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/test/detail.http new file mode 100644 index 000000000..bb41b345c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/test/detail.http @@ -0,0 +1,20 @@ + +### +POST http://localhost:8080/board-service/union/data_resource/detail +Content-Type: application/json +token:{{token}} + +{ + "data_resource_type": "ImageDataSet", + "data_resource_id": "406bbcecbea64f22afe0407acc1cddf2" +} + +### +POST http://localhost:8080/board-service/union/data_resource/detail +Content-Type: application/json +token:{{token}} + +{ + "dataResourceId": "ce9961ae677d4cc68d902003c9d68c97", + "dataResourceType": "ImageDataSet" +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/test/query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/test/query.http new file mode 100644 index 000000000..ce0eb42e6 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/data_resource/test/query.http @@ -0,0 +1,16 @@ +### +POST http://localhost:8080/board-service/union/data_resource/query +Content-Type: application/json +token:{{token}} + +{} + +### +POST http://localhost:8080/board-service/union/data_resource/query +Content-Type: application/json +token:{{token}} + +{ + "data_resource_type": "ImageDataSet", + "data_resource_id": "ce9961ae677d4cc68d902003c9d68c97" +} \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/DownloadFileApi.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/DownloadFileApi.java new file mode 100644 index 000000000..9f9ad14b0 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/DownloadFileApi.java @@ -0,0 +1,60 @@ +package com.welab.wefe.board.service.api.union.member_auth; + +import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.util.RSAUtil; +import com.welab.wefe.common.util.UrlUtil; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpMethod; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import java.io.IOException; + +/** + * @Description: + * @author: yuxin.zhang + * @date: 2021/11/3 + */ +@Api(path = "union/download/file", name = "upload file") +public class DownloadFileApi extends AbstractApi> { + @Autowired + private Config config; + + @Override + protected ApiResult> handle(DownloadFileApi.Input input) throws StatusCodeWithException, IOException { + String url = config.getUnionBaseUrl() + "/download/file"; + + JObject params = JObject.create(); + String data = JObject.create("file_id", input.fileId).toJSONString(); + String sign; + try { + sign = RSAUtil.sign(data, CacheObjects.getRsaPrivateKey(), "UTF-8"); + } catch (Exception e) { + e.printStackTrace(); + throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); + } + params.put("member_id", 
CacheObjects.getMemberId()); + params.put("sign", sign); + params.put("data", data); + + url = UrlUtil.appendQueryParameters(url, params); + RequestEntity requestEntity = new RequestEntity<>(null, null, HttpMethod.GET, UrlUtil.createUri(url)); + + RestTemplate restTemplate = new RestTemplate(); + ResponseEntity response = restTemplate.exchange(requestEntity, byte[].class); + return success(response); + } + + public static class Input extends AbstractApiInput { + public String fileId; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberAuthTypeQueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberAuthTypeQueryApi.java new file mode 100644 index 000000000..fe3944066 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberAuthTypeQueryApi.java @@ -0,0 +1,44 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.api.union.member_auth; + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.sdk.union.UnionService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author Zane + */ +@Api(path = "union/member/authtype/query", name = "Query member authtype from union") +public class MemberAuthTypeQueryApi extends AbstractApi { + + @Autowired + UnionService unionService; + + @Override + protected ApiResult handle(AbstractApiInput input) throws StatusCodeWithException { + JSONObject result = unionService.queryMemberAuthTypeList(); + return unionApiResultToBoardApiResult(result); + } + + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberFileUploadApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberFileUploadApi.java new file mode 100644 index 000000000..f9ab4ffc9 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberFileUploadApi.java @@ -0,0 +1,69 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.api.union.member_auth; + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.sdk.union.UnionService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractWithFilesApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author zane.luo + */ +@Api(path = "union/member/file/upload", name = "upload file") +public class MemberFileUploadApi extends AbstractApi { + @Autowired + private UnionService unionService; + + @Override + protected ApiResult handle(Input input) throws StatusCodeWithException { + JSONObject result = unionService.uploadFile( + input.files, + JObject.create("filename", input.filename).append("purpose", input.purpose) + ); + + return super.unionApiResultToBoardApiResult(result); + + } + + public static class Input extends AbstractWithFilesApiInput { + private String filename; + private String purpose; + + public String getFilename() { + return filename; + } + + public void setFilename(String filename) { + this.filename = filename; + } + + public String getPurpose() { + return purpose; + } + + public void setPurpose(String purpose) { + this.purpose = purpose; + } + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberRealnameAuthApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberRealnameAuthApi.java new file mode 100644 index 000000000..60a36f76a --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberRealnameAuthApi.java @@ -0,0 +1,84 @@ 
+/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.api.union.member_auth; + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.sdk.union.UnionService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; +import java.util.List; + +/** + * @author zane + * @date 2021/11/2 + */ +@Api(path = "union/member/realname/auth", name = "apply realname auth") +public class MemberRealnameAuthApi extends AbstractApi { + + @Autowired + private UnionService unionService; + + @Override + protected ApiResult handle(MemberRealnameAuthApi.Input input) throws StatusCodeWithException, IOException { + JSONObject result = unionService.realnameAuth(input); + return super.unionApiResultToBoardApiResult(result); + } + + public static class Input extends AbstractApiInput { + private String principalName; + private String authType; + private String description; + private List fileIdList; + + public String getPrincipalName() { + return principalName; + } + + public void setPrincipalName(String principalName) { + this.principalName = principalName; + } + 
+ public String getAuthType() { + return authType; + } + + public void setAuthType(String authType) { + this.authType = authType; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public List getFileIdList() { + return fileIdList; + } + + public void setFileIdList(List fileIdList) { + this.fileIdList = fileIdList; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberRealnameAuthInfoQueryApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberRealnameAuthInfoQueryApi.java new file mode 100644 index 000000000..b3307492c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/MemberRealnameAuthInfoQueryApi.java @@ -0,0 +1,44 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.api.union.member_auth; + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.sdk.union.UnionService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +/** + * @Description: + * @author: yuxin.zhang + * @date: 2021/11/3 + */ +@Api(path = "union/member/realname/authInfo/query", name = "realname auth agreement template query") +public class MemberRealnameAuthInfoQueryApi extends AbstractApi { + @Autowired + private UnionService unionService; + + @Override + protected ApiResult handle(AbstractApiInput input) throws StatusCodeWithException, IOException { + JSONObject result = unionService.realnameAuthInfoQuery(); + return unionApiResultToBoardApiResult(result); + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/QueryRealnameAuthAgreementTemplateApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/QueryRealnameAuthAgreementTemplateApi.java new file mode 100644 index 000000000..aaed4f788 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/member_auth/QueryRealnameAuthAgreementTemplateApi.java @@ -0,0 +1,24 @@ +package com.welab.wefe.board.service.api.union.member_auth; + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.sdk.union.UnionService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.web.dto.ApiResult; +import 
org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; + +@Api(path = "union/realname/auth/agreement/template/query", name = "realname auth agreement template query") +public class QueryRealnameAuthAgreementTemplateApi extends AbstractApi { + @Autowired + private UnionService unionService; + + @Override + protected ApiResult handle(AbstractApiInput input) throws StatusCodeWithException, IOException { + JSONObject result = unionService.realnameAuthAgreementTemplateQuery(); + return unionApiResultToBoardApiResult(result); + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/data_set-detail.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/data_set-detail.http deleted file mode 100644 index 01f56977c..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/data_set-detail.http +++ /dev/null @@ -1,9 +0,0 @@ - -### 查询全部数据集 -POST {{baseUrl}}/union/data_set/detail -Content-Type: application/json - -{ - "id": "766fb48a78ed4b80b684fd06301c8a84" -} - diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/data_set-query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/data_set-query.http deleted file mode 100644 index c799ef7c9..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/data_set-query.http +++ /dev/null @@ -1,7 +0,0 @@ - -### 查询全部数据集 -POST {{baseUrl}}/union/data_set/query -Content-Type: application/json - -{} - diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/member-map.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/member-map.http new file mode 100644 index 000000000..c3716308d --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/member-map.http @@ -0,0 +1,7 @@ + +### +POST 
http://localhost:8080/board-service/union/member/map +Content-Type: application/json +token:{{token}} + +{} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/member-query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/member-query.http index f45c186af..a339c9da0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/member-query.http +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/member-query.http @@ -1,7 +1,8 @@ ### 查询全部 member -POST {{baseUrl}}/union/member/query +POST http://localhost:8080/board-service/union/member/query Content-Type: application/json +token:{{token}} {} @@ -16,7 +17,7 @@ client.test("Request executed successfully", function() { ### 按名字查 -POST {{baseUrl}}/union/member/query +POST http://localhost:8080/board-service/union/member/query Content-Type: application/json { @@ -24,7 +25,7 @@ Content-Type: application/json } ### 按id查 -POST {{baseUrl}}/union/member/query +POST http://localhost:8080/board-service/union/member/query Content-Type: application/json { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/tag-query.http b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/tag-query.http deleted file mode 100644 index 3aa2dcd1a..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/tag-query.http +++ /dev/null @@ -1,24 +0,0 @@ - -### 查询全部 tag -POST {{baseUrl}}/union/tag/query -Content-Type: application/json - -{} - -> {% - -client.test("Request executed successfully", function() { - client.assert(response.body.code === 0, "Response code is not 0"); -}); - -%} - - - -### 按名字查 -POST {{baseUrl}}/union/tag/query -Content-Type: application/json - -{ - "tag": "xlsx" -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/union-check.http 
b/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/union-check.http deleted file mode 100644 index 5eab14cd6..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/api/union/test/union-check.http +++ /dev/null @@ -1,7 +0,0 @@ -### 检查union访问状态 -POST {{baseUrl}}/union/online/check -Content-Type: application/json - -{} - -### \ No newline at end of file diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/base/OnlineDemoApi.java b/board/board-service/src/main/java/com/welab/wefe/board/service/base/OnlineDemoApi.java index 42ef9cbf8..31dc580e8 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/base/OnlineDemoApi.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/base/OnlineDemoApi.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/base/file_system/WeFeFileSystem.java b/board/board-service/src/main/java/com/welab/wefe/board/service/base/file_system/WeFeFileSystem.java new file mode 100644 index 000000000..cc3366178 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/base/file_system/WeFeFileSystem.java @@ -0,0 +1,268 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.base.file_system; + +import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.file.compression.impl.Zip; +import com.welab.wefe.common.file.decompression.SuperDecompressor; +import com.welab.wefe.common.file.decompression.dto.DecompressionResult; +import com.welab.wefe.common.util.FileUtil; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * 统一管理所有的文件,合理分配各种文件的目录,及时清理不需要的过期文件。 + * + * @author zane + * @date 2022/2/15 + */ +public class WeFeFileSystem { + private static final Logger LOG = LoggerFactory.getLogger(WeFeFileSystem.class); + private static final Path ROOT_DIR = Paths.get(Launcher.getBean(Config.class).getFileUploadDir()); + + static { + File dir = ROOT_DIR.toFile(); + if (!dir.exists()) { + dir.mkdirs(); + } + } + + public static Path getRootDir() { + return ROOT_DIR; + } + + /** + * 文件用途 + */ + public enum UseType { + /** + * 临时目录,文件不会长时间保存。 + * todo:zane 此目录的文件会被自动回收。 + */ + Temp, + /** + * 添加数据资源 + */ + AddTableDataSet, + AddImageDataSet, + 
AddBloomFilter, + /** + * 调用深度学习模型 + */ + CallDeepLearningModel, + /** + * 下载深度学习模型 + */ + DownloadDeepLearningModel + } + + /** + * 获取文件上传路径 + */ + public static Path getBaseDir(UseType type) { + String childDir = StringUtil.stringToUnderLineLowerCase(type.name()); + return WeFeFileSystem.getRootDir().resolve(childDir); + } + + public static Path getFilePath(UseType type, String filename) { + return getBaseDir(type).resolve(filename); + } + + /** + * 获取资源上传完成后的完整路径 + */ + public static Path getFilePath(DataResourceType dataResourceType, String filename) { + return getFileDir(dataResourceType).resolve(filename); + } + + /** + * 获取资源上传完成后的所在目录 + */ + public static Path getFileDir(DataResourceType dataResourceType) { + switch (dataResourceType) { + case TableDataSet: + return getBaseDir(UseType.AddTableDataSet); + case ImageDataSet: + return getBaseDir(UseType.AddImageDataSet); + case BloomFilter: + return getBaseDir(UseType.AddBloomFilter); + default: + return WeFeFileSystem.getRootDir(); + } + + } + + public static class DownloadDeepLearningModel { + /** + * 获取下载中的模型文件 + */ + public static File getDownloadingModelFile(String taskId) { + return getBaseDir(UseType.DownloadDeepLearningModel).resolve(taskId + ".downloading").toFile(); + } + + /** + * 当模型下载完毕后,执行此操作。 + */ + public static File modelFileDownloadCompleted(String taskId) { + File downloadingModelFile = getDownloadingModelFile(taskId); + File modelFile = getModelFile(taskId); + + downloadingModelFile.renameTo(modelFile); + return modelFile; + } + + /** + * 获取下载完毕的模型文件 + */ + public static File getModelFile(String taskId) { + return getBaseDir(UseType.DownloadDeepLearningModel).resolve(taskId + ".model").toFile(); + } + } + + public static class CallDeepLearningModel { + + /** + * 获取上传的原始文件 + */ + public static File getRawFile(String filename) { + return getBaseDir(UseType.CallDeepLearningModel).resolve(filename).toFile(); + } + + /** + * 包含模型的zip文件 + */ + public static File getModelFile(String taskId) { + 
return getBaseDir(UseType.CallDeepLearningModel).resolve("model").resolve(taskId + ".zip").toFile(); + } + + /** + * 包含图片的zip文件 + */ + public static File getZipFile(String taskId, String sessionId) { + return getBaseDir(UseType.CallDeepLearningModel).resolve(taskId).resolve(sessionId + ".zip").toFile(); + } + + + /** + * 图片样本所在的目录: /CallDeepLearningModel/{taskId}/{sessionId} + */ + public static Path getImageSimpleDir(String taskId, String sessionId) { + return getBaseDir(UseType.CallDeepLearningModel).resolve(taskId).resolve(sessionId); + + } + + /** + * 将图片所在的文件夹压缩为 zip,供VisualFL下载。 + */ + public static File zipImageSimpleDir(String taskId, String sessionId) throws StatusCodeWithException { + File zipFile = getZipFile(taskId, sessionId); + if (zipFile.exists()) { + zipFile.delete(); + } + + // 压缩文件夹 + try { + new Zip().compression(getImageSimpleDir(taskId, sessionId), zipFile); + } catch (IOException e) { + LOG.error(e.getClass().getSimpleName() + " " + e.getMessage(), e); + StatusCode.FILE_IO_ERROR.throwException(e); + } + return zipFile; + } + + /** + * 将单张图片移到预定目录 + */ + public static void moveSingleImageToSessionDir(File rawFile, String taskId, String sessionId) throws StatusCodeWithException { + // 检查文件是否是图片 + if (!FileUtil.isImage(rawFile)) { + if (rawFile.exists()) { + rawFile.delete(); + } + StatusCode.PARAMETER_VALUE_INVALID.throwException("文件不是图片"); + } + + Path distDir = getImageSimpleDir(taskId, sessionId); + FileUtil.moveFile(rawFile, distDir.toString()); + + } + + /** + * 将上传的文件解压后移动到预定目录 + */ + public static int moveZipFileToSessionDir(File zipFile, String taskId, String sessionId) throws StatusCodeWithException { + if (!zipFile.exists()) { + StatusCode.PARAMETER_VALUE_INVALID.throwException("未找到文件:" + zipFile.getAbsolutePath()); + } + + String suffix = FileUtil.getFileSuffix(zipFile); + if (!"zip".equalsIgnoreCase(suffix)) { + FileUtil.deleteFileOrDir(zipFile); + StatusCode.PARAMETER_VALUE_INVALID.throwException("不支持的文件类型:" + suffix); + } + + Path 
distDir = getImageSimpleDir(taskId, sessionId); + DecompressionResult result = null; + try { + result = SuperDecompressor.decompression(zipFile, distDir.toString(), true); + } catch (Exception e) { + LOG.error(e.getClass().getSimpleName() + " " + e.getMessage(), e); + StatusCode.FILE_IO_ERROR.throwException(e); + } + + // 安全起见,把非图片文件删除掉。 + int imageCount = 0; + for (File file : result.files) { + // 删除隐藏文件 + if (file.isHidden()) { + FileUtil.deleteFileOrDir(file); + } + // 删除不是图片的文件 + else if (!FileUtil.isImage(file)) { + FileUtil.deleteFileOrDir(file); + } else { + // 将文件移动到解压目录的根目录,避免zip包内有子文件导致路径不好管理。 + FileUtil.moveFile(file, distDir); + imageCount++; + } + } + + // 移除解压后的子目录 + for (File file : result.dirs) { + FileUtil.deleteFileOrDir(file); + } + + // 移除原始文件 + zipFile.delete(); + + if (imageCount == 0) { + FileUtil.deleteFileOrDir(distDir.toFile()); + StatusCode.PARAMETER_VALUE_INVALID.throwException("压缩包中没有图片文件!"); + } + return imageCount; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/Components.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/Components.java index 3f5f61b9e..2ee4bf888 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/Components.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/Components.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,10 +17,13 @@ package com.welab.wefe.board.service.component; import com.welab.wefe.board.service.component.base.AbstractComponent; +import com.welab.wefe.board.service.component.deep_learning.ImageDataIOComponent; +import com.welab.wefe.board.service.component.deep_learning.PaddleClassifyComponent; +import com.welab.wefe.board.service.component.deep_learning.PaddleDetectionComponent; import com.welab.wefe.board.service.component.feature.*; import com.welab.wefe.board.service.component.modeling.*; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.wefe.enums.ComponentType; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -88,10 +91,19 @@ public class Components { private HorzNNComponent horzNNComponent; @Autowired private MixBinningComponent mixBinningComponent; + public static AbstractComponent getDataIOComponent() { return get(ComponentType.DataIO); } + @Autowired + private ImageDataIOComponent imageDataIOComponent; + @Autowired + private PaddleClassifyComponent paddleClassifyComponent; + @Autowired + private PaddleDetectionComponent paddleDetectionComponent; + + public static AbstractComponent get(ComponentType componentType) { switch (componentType) { @@ -100,41 +112,41 @@ public static AbstractComponent get(ComponentType componentType) { case HorzXGBoostValidationDataSetLoader: case VertXGBoostValidationDataSetLoader: case DataIO: - return Launcher.CONTEXT.getBean(Components.class).dataIOComponent; + return Launcher.getBean(Components.class).dataIOComponent; case Intersection: - return Launcher.CONTEXT.getBean(Components.class).intersectionComponent; + return Launcher.getBean(Components.class).intersectionComponent; case Evaluation: - return Launcher.CONTEXT.getBean(Components.class).evaluationComponent; + return Launcher.getBean(Components.class).evaluationComponent; case HorzLR: - return 
Launcher.CONTEXT.getBean(Components.class).horzLRComponent; + return Launcher.getBean(Components.class).horzLRComponent; case VertLR: - return Launcher.CONTEXT.getBean(Components.class).vertLRComponent; + return Launcher.getBean(Components.class).vertLRComponent; case Binning: - return Launcher.CONTEXT.getBean(Components.class).binningComponent; + return Launcher.getBean(Components.class).binningComponent; case HorzSecureBoost: - return Launcher.CONTEXT.getBean(Components.class).horzSecureBoostComponent; + return Launcher.getBean(Components.class).horzSecureBoostComponent; case VertSecureBoost: - return Launcher.CONTEXT.getBean(Components.class).vertSecureBoostComponent; + return Launcher.getBean(Components.class).vertSecureBoostComponent; case FeatureSelection: - return Launcher.CONTEXT.getBean(Components.class).featureSelectionComponent; + return Launcher.getBean(Components.class).featureSelectionComponent; case Segment: - return Launcher.CONTEXT.getBean(Components.class).segmentComponent; + return Launcher.getBean(Components.class).segmentComponent; case FeatureStatistic: - return Launcher.CONTEXT.getBean(Components.class).featureStatisticsComponent; + return Launcher.getBean(Components.class).featureStatisticsComponent; case FeatureCalculation: - return Launcher.CONTEXT.getBean(Components.class).featureCalculationComponent; + return Launcher.getBean(Components.class).featureCalculationComponent; case FillMissingValue: - return Launcher.CONTEXT.getBean(Components.class).fillMissingValueComponent; + return Launcher.getBean(Components.class).fillMissingValueComponent; case FeatureStandardized: - return Launcher.CONTEXT.getBean(Components.class).featureStandardizedComponent; + return Launcher.getBean(Components.class).featureStandardizedComponent; case VertPearson: - return Launcher.CONTEXT.getBean(Components.class).vertPearsonComponent; + return Launcher.getBean(Components.class).vertPearsonComponent; case MixLR: - return 
Launcher.CONTEXT.getBean(Components.class).mixLrComponent; + return Launcher.getBean(Components.class).mixLrComponent; case MixSecureBoost: - return Launcher.CONTEXT.getBean(Components.class).mixSecureBoostComponent; + return Launcher.getBean(Components.class).mixSecureBoostComponent; case MixStatistic: - return Launcher.CONTEXT.getBean(Components.class).mixStatisticComponent; + return Launcher.getBean(Components.class).mixStatisticComponent; case Oot: return Launcher.CONTEXT.getBean(Components.class).ootComponent; case VertFilter: @@ -144,19 +156,25 @@ public static AbstractComponent get(ComponentType componentType) { case HorzOneHot: return Launcher.CONTEXT.getBean(Components.class).horzOneHotComponent; case VertOneHot: - return Launcher.CONTEXT.getBean(Components.class).vertOneHotComponent; + return Launcher.getBean(Components.class).vertOneHotComponent; case VertPCA: - return Launcher.CONTEXT.getBean(Components.class).vertPCAComponent; + return Launcher.getBean(Components.class).vertPCAComponent; case HorzFeatureBinning: - return Launcher.CONTEXT.getBean(Components.class).horzFeatureBinningComponent; + return Launcher.getBean(Components.class).horzFeatureBinningComponent; case HorzStatistic: - return Launcher.CONTEXT.getBean(Components.class).horzStatisticComponent; + return Launcher.getBean(Components.class).horzStatisticComponent; case HorzNN: - return Launcher.CONTEXT.getBean(Components.class).horzNNComponent; + return Launcher.getBean(Components.class).horzNNComponent; case VertNN: - return Launcher.CONTEXT.getBean(Components.class).vertNNComponent; + return Launcher.getBean(Components.class).vertNNComponent; case MixBinning: - return Launcher.CONTEXT.getBean(Components.class).mixBinningComponent; + return Launcher.getBean(Components.class).mixBinningComponent; + case ImageDataIO: + return Launcher.getBean(Components.class).imageDataIOComponent; + case PaddleClassify: + return Launcher.getBean(Components.class).paddleClassifyComponent; + case 
PaddleDetection: + return Launcher.getBean(Components.class).paddleDetectionComponent; default: return null; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/DataIOComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/DataIOComponent.java index 513cac34a..3abe9a2c7 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/DataIOComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/DataIOComponent.java @@ -1,4 +1,4 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,11 +19,13 @@ import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.component.base.AbstractComponent; +import com.welab.wefe.board.service.component.base.dto.AbstractDataIOParam; +import com.welab.wefe.board.service.component.base.dto.AbstractDataSetItem; import com.welab.wefe.board.service.component.base.io.IODataType; import com.welab.wefe.board.service.component.base.io.InputMatcher; import com.welab.wefe.board.service.component.base.io.Names; import com.welab.wefe.board.service.component.base.io.OutputItem; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; import com.welab.wefe.board.service.database.entity.job.JobMemberMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; @@ -31,15 +33,14 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.board.service.service.DataSetService; -import com.welab.wefe.common.enums.ComponentType; -import 
com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskResultType; -import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskResultType; import org.apache.commons.collections4.CollectionUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -56,7 +57,7 @@ public class DataIOComponent extends AbstractComponent { @Autowired - private DataSetService dataSetService; + private TableDataSetService tableDataSetService; @Override public ComponentType taskType() { @@ -94,8 +95,8 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas } DataSetItem promoterProjectDataSet = params.getDataSetList().stream().filter(x -> x.memberId.equals(promoter.getMemberId())).findFirst().orElse(null); - DataSetMysqlModel promoterDataSet = dataSetService.findOne(promoterProjectDataSet.dataSetId); - if (!promoterDataSet.getContainsY()) { + TableDataSetMysqlModel promoterDataSet = tableDataSetService.findOneById(promoterProjectDataSet.dataSetId); + if (!promoterDataSet.isContainsY()) { throw new FlowNodeException(node, "promoter 的数据集必须包含 y 值"); } @@ -105,7 +106,7 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas continue; } - DataSetMysqlModel one = dataSetService.findOne(dataSet.getDataSetId()); + TableDataSetMysqlModel one = tableDataSetService.findOneById(dataSet.getDataSetId()); if (one == null) { throw new FlowNodeException(node, "成员 " + 
CacheObjects.getMemberName(dataSet.memberId) + " 的数据集 " + dataSet.getDataSetId() + " 不存在,请检查是否已删除。"); } @@ -144,8 +145,6 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT } // Create the input parameters of the components in the kernel according to the component parameter settings in the interface - JSONObject taskParam = new JSONObject(); - DataSetItem myDataSetConfig = params.getDataSetList() .stream() .filter(x -> x.getMemberId().equals(CacheObjects.getMemberId()) && x.getMemberRole() == graph.getJob().getMyRole()) @@ -156,29 +155,25 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT throw new FlowNodeException(node, "请保存自己的数据集信息。"); } - DataSetMysqlModel myDataSet = dataSetService.findOne(myDataSetConfig.dataSetId); + TableDataSetMysqlModel myDataSet = tableDataSetService.findOneById(myDataSetConfig.dataSetId); if (myDataSet == null) { throw new FlowNodeException(node, "找不到自己的数据集。"); } - JObject dataIoParam = JObject + JObject output = JObject .create() .append("data_set_id", myDataSet.getId()) - .append("with_label", myDataSet.getContainsY()) + .append("with_label", myDataSet.isContainsY()) .append("label_name", "y") - .append("namespace", myDataSet.getNamespace()) - .append("name", myDataSet.getTableName()) + .append("namespace", myDataSet.getStorageNamespace()) + .append("name", myDataSet.getStorageResourceName()) .append("need_features", myDataSetConfig.features); - // DataIOParam - taskParam.put("params", dataIoParam); - - return taskParam; + return output; } @Override protected List getAllResult(String taskId) { - return taskResultService.listAllResult(taskId); } @@ -227,7 +222,7 @@ public List outputs(FlowGraph graph, FlowGraphNode node) throws Flow return new ArrayList<>(); } - DataSetMysqlModel myDataSet = params.getMyDataSet(); + TableDataSetMysqlModel myDataSet = params.getMyDataSet(); if (myDataSet == null) { throw new FlowNodeException(node, CacheObjects.getMemberName() + " 的数据集已被删除,不能加载已删除的数据集。"); } @@ -235,13 
+230,12 @@ public List outputs(FlowGraph graph, FlowGraphNode node) throws Flow return Arrays.asList(OutputItem.of(Names.Data.NORMAL_DATA_SET, IODataType.DataSetInstance)); } - public static class Params extends AbstractCheckModel { - private List dataSetList; + public static class Params extends AbstractDataIOParam { /** * Find my data set object information from the configuration list */ - public DataSetMysqlModel getMyDataSet() { + public TableDataSetMysqlModel getMyDataSet() { DataSetItem myDataSetConfig = getMyDataSetConfig(); @@ -249,10 +243,10 @@ public DataSetMysqlModel getMyDataSet() { return null; } - DataSetMysqlModel myDataSet = Launcher + TableDataSetMysqlModel myDataSet = Launcher .CONTEXT - .getBean(DataSetService.class) - .findOne(myDataSetConfig.getDataSetId()); + .getBean(TableDataSetService.class) + .findOneById(myDataSetConfig.getDataSetId()); return myDataSet; } @@ -269,56 +263,14 @@ public DataSetItem getMyDataSetConfig() { } - //region getter/setter - - public List getDataSetList() { - return dataSetList; - } - - public void setDataSetList(List dataSetList) { - this.dataSetList = dataSetList; - } - - //endregion - } - public static class DataSetItem extends AbstractCheckModel { - @Check(name = "成员Id", require = true) - private String memberId; - @Check(name = "成员角色", require = true) - private JobMemberRole memberRole; - @Check(name = "数据集 Id", require = true) - private String dataSetId; + public static class DataSetItem extends AbstractDataSetItem { @Check(name = "选择的特征列") private List features; //region getter/setter - public String getMemberId() { - return memberId; - } - - public void setMemberId(String memberId) { - this.memberId = memberId; - } - - public JobMemberRole getMemberRole() { - return memberRole; - } - - public void setMemberRole(JobMemberRole memberRole) { - this.memberRole = memberRole; - } - - public String getDataSetId() { - return dataSetId; - } - - public void setDataSetId(String dataSetId) { - this.dataSetId = 
dataSetId; - } - public List getFeatures() { return features; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/EvaluationComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/EvaluationComponent.java index cec425eca..f31f18188 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/EvaluationComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/EvaluationComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -25,13 +25,13 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.TaskService; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskResultType; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -73,16 +73,12 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT return null; } - JSONObject taskParam = new JSONObject(); - // Reassembly parameters - JObject evaluationParam = JObject.create(); - evaluationParam.append("eval_type", params.getEvalType()) + JObject output = JObject.create(); + output.append("eval_type", params.getEvalType()) .append("pos_label", params.getPosLabel()); - taskParam.put("params", evaluationParam); - - return taskParam; + return output; } @Override @@ -91,7 +87,7 @@ protected List getAllResult(String taskId) { } @Override - protected TaskResultMySqlModel getResult(String taskId, String type) { + protected TaskResultMySqlModel getResult(String taskId, String type) throws StatusCodeWithException { TaskResultMySqlModel trainTaskResult = taskResultService.findByTaskIdAndType(taskId, TaskResultType.metric_train.name()); TaskResultMySqlModel validateTaskResult = taskResultService.findByTaskIdAndType(taskId, TaskResultType.metric_validate.name()); @@ -114,68 +110,65 @@ protected TaskResultMySqlModel getResult(String taskId, String type) { 
JObject result = JObject.create(); - try { - // Find out all the same branch nodes with the evaluation node - // and find the modeling node from them - // (this method solves the problem of null pointer when the evaluation node is deleted in the original editing process again) - List homologousBranchTaskList = taskService.findHomologousBranchByJobId(taskResultMySqlModel.getJobId(), trainTaskResult.getRole(), taskResultMySqlModel.getTaskId()); - TaskMySqlModel modelingTask = homologousBranchTaskList.stream().filter(x -> MODEL_COMPONENT_TYPE_LIST.contains(x.getTaskType())).findFirst().orElse(null); - - String modelComponentType = modelingTask.getTaskType().toString(); - String modelNodeId = modelingTask.getFlowNodeId(); - String suffix = ""; - if (!taskId.endsWith(taskResultMySqlModel.getFlowNodeId())) { - suffix = "_" + taskId.split("_")[taskId.split("_").length - 1]; - } - // Start parsing the required result data - String normalName = modelComponentType + "_" + modelNodeId + suffix; - String preValidateName = "validate_" + modelComponentType + "_" + modelNodeId + suffix; - String preTrainName = "train_" + modelComponentType + "_" + modelNodeId + suffix; - - JObject validate = validateObj.getJObject(preValidateName); - JObject train = trainObj.getJObject(preTrainName); - - result.append("validate", validate) - .append("train", train); - - switch (type) { - case "ks": - result.putAll(parserTrainCurveData(trainObj, "ks_fpr", normalName)); - result.putAll(parserValidateCurveData(validateObj, "ks_fpr", normalName)); - result.putAll(parserTrainCurveData(trainObj, "ks_tpr", normalName)); - result.putAll(parserValidateCurveData(validateObj, "ks_tpr", normalName)); - break; - case "lift": - result.putAll(parserTrainCurveData(trainObj, "lift", normalName)); - result.putAll(parserValidateCurveData(validateObj, "lift", normalName)); - break; - case "gain": - result.putAll(parserTrainCurveData(trainObj, "gain", normalName)); - result.putAll(parserValidateCurveData(validateObj, 
"gain", normalName)); - break; - case "accuracy": - result.putAll(parserTrainCurveData(trainObj, "accuracy", normalName)); - result.putAll(parserValidateCurveData(validateObj, "accuracy", normalName)); - break; - case "precision_recall": - result.putAll(parserTrainCurveData(trainObj, "precision", normalName)); - result.putAll(parserValidateCurveData(validateObj, "precision", normalName)); - result.putAll(parserTrainCurveData(trainObj, "recall", normalName)); - result.putAll(parserValidateCurveData(validateObj, "recall", normalName)); - break; - case "roc": - result.putAll(parserTrainCurveData(trainObj, "roc", normalName)); - result.putAll(parserValidateCurveData(validateObj, "roc", normalName)); - break; - case "topn": - result.putAll(parserTopN(trainObj, normalName, "train")); - result.putAll(parserTopN(validateObj, normalName, "validate")); - default: - break; - - } - } catch (StatusCodeWithException e) { - e.printStackTrace(); + + // Find out all the same branch nodes with the evaluation node + // and find the modeling node from them + // (this method solves the problem of null pointer when the evaluation node is deleted in the original editing process again) + List homologousBranchTaskList = taskService.findHomologousBranchByJobId(taskResultMySqlModel.getJobId(), trainTaskResult.getRole(), taskResultMySqlModel.getTaskId()); + TaskMySqlModel modelingTask = homologousBranchTaskList.stream().filter(x -> MODEL_COMPONENT_TYPE_LIST.contains(x.getTaskType())).findFirst().orElse(null); + + String modelComponentType = modelingTask.getTaskType().toString(); + String modelNodeId = modelingTask.getFlowNodeId(); + String suffix = ""; + if (!taskId.endsWith(taskResultMySqlModel.getFlowNodeId())) { + suffix = "_" + taskId.split("_")[taskId.split("_").length - 1]; + } + // Start parsing the required result data + String normalName = modelComponentType + "_" + modelNodeId + suffix; + String preValidateName = "validate_" + modelComponentType + "_" + modelNodeId + suffix; + 
String preTrainName = "train_" + modelComponentType + "_" + modelNodeId + suffix; + + JObject validate = validateObj.getJObject(preValidateName); + JObject train = trainObj.getJObject(preTrainName); + + result.append("validate", validate) + .append("train", train); + + switch (type) { + case "ks": + result.putAll(parserTrainCurveData(trainObj, "ks_fpr", normalName)); + result.putAll(parserValidateCurveData(validateObj, "ks_fpr", normalName)); + result.putAll(parserTrainCurveData(trainObj, "ks_tpr", normalName)); + result.putAll(parserValidateCurveData(validateObj, "ks_tpr", normalName)); + break; + case "lift": + result.putAll(parserTrainCurveData(trainObj, "lift", normalName)); + result.putAll(parserValidateCurveData(validateObj, "lift", normalName)); + break; + case "gain": + result.putAll(parserTrainCurveData(trainObj, "gain", normalName)); + result.putAll(parserValidateCurveData(validateObj, "gain", normalName)); + break; + case "accuracy": + result.putAll(parserTrainCurveData(trainObj, "accuracy", normalName)); + result.putAll(parserValidateCurveData(validateObj, "accuracy", normalName)); + break; + case "precision_recall": + result.putAll(parserTrainCurveData(trainObj, "precision", normalName)); + result.putAll(parserValidateCurveData(validateObj, "precision", normalName)); + result.putAll(parserTrainCurveData(trainObj, "recall", normalName)); + result.putAll(parserValidateCurveData(validateObj, "recall", normalName)); + break; + case "roc": + result.putAll(parserTrainCurveData(trainObj, "roc", normalName)); + result.putAll(parserValidateCurveData(validateObj, "roc", normalName)); + break; + case "topn": + result.putAll(parserTopN(trainObj, normalName, "train")); + result.putAll(parserTopN(validateObj, normalName, "validate")); + default: + break; + } taskResultMySqlModel.setResult(result.toJSONString()); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/IntersectionComponent.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/component/IntersectionComponent.java index e8bc797d1..64a4216b2 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/IntersectionComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/IntersectionComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -27,11 +27,11 @@ import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.TaskResultType; import org.springframework.beans.BeanUtils; import org.springframework.stereotype.Service; @@ -58,17 +58,13 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - - JSONObject taskParam = new JSONObject(); - // Reassemble front-end parameters - JObject intersectionParam = JObject.create(); - intersectionParam.append("intersect_method", 
params.getIntersectMethod()) + JObject output = JObject.create(); + output + .append("intersect_method", params.getIntersectMethod()) .append("save_dataset", params.isSaveDataSet()); - taskParam.put("params", intersectionParam); - - return taskParam; + return output; } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/OotComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/OotComponent.java index bbf94e1ef..152615a98 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/OotComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/OotComponent.java @@ -1,11 +1,11 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,28 +19,28 @@ import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.api.project.flow.QueryDataIoTaskConfigApi; -import com.welab.wefe.board.service.api.project.member.ListApi; +import com.welab.wefe.board.service.api.project.member.ListInProjectApi; import com.welab.wefe.board.service.component.base.AbstractComponent; +import com.welab.wefe.board.service.component.base.dto.AbstractDataIOParam; import com.welab.wefe.board.service.component.base.io.IODataType; import com.welab.wefe.board.service.component.base.io.InputMatcher; import com.welab.wefe.board.service.component.base.io.Names; import com.welab.wefe.board.service.component.base.io.OutputItem; import com.welab.wefe.board.service.constant.Config; -import 
com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; import com.welab.wefe.board.service.database.entity.job.*; -import com.welab.wefe.board.service.dto.kernel.Env; -import com.welab.wefe.board.service.dto.kernel.KernelJob; -import com.welab.wefe.board.service.dto.kernel.TaskConfig; +import com.welab.wefe.board.service.dto.kernel.machine_learning.Env; +import com.welab.wefe.board.service.dto.kernel.machine_learning.KernelJob; +import com.welab.wefe.board.service.dto.kernel.machine_learning.TaskConfig; import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.*; -import com.welab.wefe.common.enums.*; +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; -import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.*; import org.apache.commons.collections4.CollectionUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -75,7 +75,7 @@ public class OotComponent extends AbstractComponent { private Config config; @Autowired - private DataSetService dataSetService; + private TableDataSetService tableDataSetService; @Autowired private TaskService taskService; @Autowired @@ -101,13 +101,13 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas throw new FlowNodeException(node, "请保存成员[" + CacheObjects.getMemberName() + "]的数据集信息。"); } - DataSetMysqlModel dataSetMysqlModel = dataSetService.findOne(myDataSetConfig.getDataSetId()); - if (null == 
dataSetMysqlModel) { + TableDataSetMysqlModel TableDataSetMysqlModel = tableDataSetService.findOneById(myDataSetConfig.getDataSetId()); + if (null == TableDataSetMysqlModel) { throw new FlowNodeException(node, "成员[" + CacheObjects.getMemberName() + "]选择的数据集信息不存在。"); } // All characteristic columns of the dataset I selected - List myFeatureNameList = Arrays.asList(dataSetMysqlModel.getFeatureNameList().split(",")); + List myFeatureNameList = Arrays.asList(TableDataSetMysqlModel.getFeatureNameList().split(",")); List taskMySqlModelList = preTasks; // Dataio task component @@ -156,7 +156,7 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas @Override - protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { + protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws StatusCodeWithException { if (graph.getJob().getMyRole() == JobMemberRole.arbiter) { return null; } @@ -166,14 +166,14 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT DataIOComponent.DataSetItem myDataSetConfig = getMyDataSetConfig(graph, params); boolean isSelectedMyself = (null != myDataSetConfig); - DataSetMysqlModel myDataSet = null; + TableDataSetMysqlModel myDataSet = null; if (FederatedLearningType.vertical.equals(graph.getFederatedLearningType()) || isOotMode) { if (!isSelectedMyself) { throw new FlowNodeException(node, "请保存成员[" + CacheObjects.getMemberName() + "]的数据集信息。"); } } if (isSelectedMyself) { - myDataSet = dataSetService.findOne(myDataSetConfig.getDataSetId()); + myDataSet = tableDataSetService.findOneById(myDataSetConfig.getDataSetId()); if (myDataSet == null) { throw new FlowNodeException(node, "找不到成员[" + CacheObjects.getMemberName() + "]的数据集。"); } @@ -229,10 +229,10 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT JObject taskConfigObj = JObject.create(JObject.toJSONString(taskConfig)); // If it is a 
dataio component, replace it with a new dataset if (DATA_IO_COMPONENT_TYPE_LIST.contains(taskType)) { - newDataIoParam.append("with_label", isSelectedMyself ? myDataSet.getContainsY() : false) + newDataIoParam.append("with_label", isSelectedMyself ? myDataSet.isContainsY() : false) .append("label_name", "y") - .append("namespace", isSelectedMyself ? myDataSet.getNamespace() : taskConfigObj.getStringByPath("params.namespace")) - .append("name", isSelectedMyself ? myDataSet.getTableName() : taskConfigObj.getStringByPath("params.name")) + .append("namespace", isSelectedMyself ? myDataSet.getStorageNamespace() : taskConfigObj.getStringByPath("params.namespace")) + .append("name", isSelectedMyself ? myDataSet.getStorageResourceName() : taskConfigObj.getStringByPath("params.name")) .append("need_features", JObject.parseArray(taskConfigObj.getStringByPath("params.need_features")).toJavaList(String.class)); taskConfigObj.put("params", newDataIoParam); } else if (MODEL_COMPONENT_TYPE_LIST.contains(taskType)) { @@ -254,14 +254,14 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT } // Create input parameters for OOT components - JObject ootParam = JObject.create(newDataIoParam) + JObject output = JObject.create(newDataIoParam) .append("flow_node_id", node.getNodeId()) .append("task_id", node.createTaskId(graph.getJob())) .append("sub_component_name_list", subTaskNameList) .append("sub_component_task_config_dick", subTaskConfigMap); // OotParam - return JObject.create().append("params", ootParam); + return output; } @Override @@ -382,11 +382,14 @@ private void checkSelectedFeatures(FlowGraph graph, FlowGraphNode node, Params p input.setJobId(params.jobId); input.setRole(jobMemberRole); try { - ApiResult apiResult = gatewayService.sendToBoardRedirectApi(memberId, JobMemberRole.promoter, input, QueryDataIoTaskConfigApi.class); - if (0 != apiResult.code) { - throw new FlowNodeException(node, "获取成员[" + memberName + "]的原入模特征列失败,原因:" + apiResult.message); - } - 
JObject data = JObject.create(apiResult.data); + + JObject data = gatewayService.callOtherMemberBoard( + memberId, + JobMemberRole.promoter, + QueryDataIoTaskConfigApi.class, + input, + JObject.class + ); if (null == data || data.isEmpty()) { throw new FlowNodeException(node, "获取成员[" + memberName + "]的原入模特征列为空。"); } @@ -448,7 +451,7 @@ private void checkSelectedMembersValid(FlowGraph graph, FlowGraphNode node, Para if (null == jobMySqlModel) { throw new FlowNodeException(node, "找不到原流程任务信息"); } - ListApi.Input input = new ListApi.Input(); + ListInProjectApi.Input input = new ListInProjectApi.Input(); input.setProjectId(jobMySqlModel.getProjectId()); input.setOotJobId(params.jobId); try { @@ -642,24 +645,12 @@ private String findComponentTaskId(JObject evaluationObj) { } - private void updateKernelJob(TaskConfig taskConfig, Params params) { + private void updateKernelJob(TaskConfig taskConfig, Params params) throws StatusCodeWithException { KernelJob kernelJob = taskConfig.getJob(); - kernelJob.setEnv(createEvn()); + kernelJob.setEnv(Env.get()); kernelJob.setFederatedLearningMode(FederatedLearningModel.oot); } - /** - * Create sub component running environment - */ - private Env createEvn() { - Env env = new Env(); - env.setBackend(config.getBackend()); - env.setDbType(config.getDbType()); - env.setWorkMode(config.getWorkMode()); - env.setName(config.getEnvName()); - return env; - } - /** * Is it OOT mode @@ -672,8 +663,7 @@ private boolean isOotMode(Params params) { return StringUtil.isNotEmpty(params.getJobId()); } - public static class Params extends AbstractCheckModel { - private List dataSetList; + public static class Params extends AbstractDataIOParam { /** * Specify jobid to create OOT component (used in OOT mode) */ @@ -691,14 +681,6 @@ public static class Params extends AbstractCheckModel { */ private Integer posLabel; - public List getDataSetList() { - return dataSetList; - } - - public void setDataSetList(List dataSetList) { - this.dataSetList = 
dataSetList; - } - public String getJobId() { return jobId; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/SegmentComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/SegmentComponent.java index b7a60530e..a5217f904 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/SegmentComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/SegmentComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -27,12 +27,12 @@ import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.TaskResultType; import org.springframework.stereotype.Service; import java.util.Arrays; @@ -64,16 +64,14 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); - // Reassemble front-end parameters - JObject segmentParam = JObject.create(); + JObject output = JObject.create(); FederatedLearningType federatedLearningType = graph.getJob().getFederatedLearningType(); if (federatedLearningType == FederatedLearningType.vertical) { - segmentParam.append("mode", "vert"); + output.append("mode", "vert"); } else if (federatedLearningType == FederatedLearningType.horizontal) { - segmentParam.append("mode", "horz"); + output.append("mode", "horz"); } // Take with_label from the task parameter of the dataIO component @@ -87,15 +85,13 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT JObject taskConfig = JObject.create(dataIOTask.getTaskConf()); boolean withLabel = taskConfig.getJObject("params").getBooleanValue("with_label"); - segmentParam.append("random_num", params.getSplitDataRandomNum()) + output.append("random_num", params.getSplitDataRandomNum()) .append("train_ratio", params.getTrainingRatio() / 
(params.getTrainingRatio() + params.getVerificationRatio())) .append("with_label", withLabel) .append("label_name", "y") .append("label_type", "int"); - taskParam.put("params", segmentParam); - - return taskParam; + return output; } @Override @@ -116,25 +112,38 @@ protected TaskResultMySqlModel getResult(String taskId, String type) { int trainCount = resultObj.getIntegerByPath("train_eval_segment.data.train_count.value", 0); // Number of positive training examples int trainyPositiveExampleCount = resultObj.getIntegerByPath("train_eval_segment.data.train_y_positive_example_count.value", 0); - + // Number of negative training examples + int trainyNegativeExampleCount = trainCount - trainyPositiveExampleCount; // Proportion of training positive examples double trainyPositiveExampleRatio = resultObj.getDoubleByPath("train_eval_segment.data.train_y_positive_example_ratio.value", 0d); - int evalCount = resultObj.getIntegerByPath("train_eval_segment.data.eval_count.value", 0); + // Proportion of training negative examples + double trainyNegativeExampleRatio = 1 - trainyPositiveExampleRatio; + int evalCount = resultObj.getIntegerByPath("train_eval_segment.data.eval_count.value", 0); // Verify the number of positive examples int evalyPositiveExampleCount = resultObj.getIntegerByPath("train_eval_segment.data.eval_y_positive_example_count.value", 0); - + // Verify the number of negative examples + int evalyNegativeExampleCount = evalCount - evalyPositiveExampleCount; // Verify the proportion of positive cases double evalyPositiveExampleVatio = resultObj.getDoubleByPath("train_eval_segment.data.eval_y_positive_example_ratio.value", 0d); + // Verify the proportion of negative cases + double evalyNegativeExampleVatio = 1 - evalyPositiveExampleVatio; + resultModel.setResult(JObject.create() .append("contains_y", withLabel) .append("train_count", trainCount) .append("train_y_positive_example_count", trainyPositiveExampleCount) .append("train_y_positive_example_ratio", 
trainyPositiveExampleRatio) + .append("train_y_negative_example_count", trainyNegativeExampleCount) + .append("train_y_negative_example_ratio", trainyNegativeExampleRatio) .append("eval_count", evalCount) .append("eval_y_positive_example_count", evalyPositiveExampleCount) - .append("eval_y_positive_example_ratio", evalyPositiveExampleVatio).toJSONString()); + .append("eval_y_positive_example_ratio", evalyPositiveExampleVatio) + .append("eval_y_negative_example_count", evalyNegativeExampleCount) + .append("eval_y_negative_example_ratio", evalyNegativeExampleVatio).toJSONString() + + ); return resultModel; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/AbstractComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/AbstractComponent.java index 31ea1b7a1..bd239c2d3 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/AbstractComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/AbstractComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -18,11 +18,19 @@ import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -39,28 +47,31 @@ import com.welab.wefe.board.service.component.base.io.InputMatcher; import com.welab.wefe.board.service.component.base.io.NodeOutputItem; import com.welab.wefe.board.service.component.base.io.OutputItem; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.job.ProjectMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; import com.welab.wefe.board.service.database.repository.TaskRepository; import com.welab.wefe.board.service.dto.entity.job.TaskResultOutputModel; -import com.welab.wefe.board.service.dto.kernel.KernelJob; -import com.welab.wefe.board.service.dto.kernel.KernelTask; import com.welab.wefe.board.service.dto.kernel.Member; -import com.welab.wefe.board.service.dto.kernel.TaskConfig; +import com.welab.wefe.board.service.dto.kernel.machine_learning.KernelJob; +import com.welab.wefe.board.service.dto.kernel.machine_learning.KernelTask; +import com.welab.wefe.board.service.dto.kernel.machine_learning.TaskConfig; import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import 
com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; import com.welab.wefe.board.service.service.JobService; import com.welab.wefe.board.service.service.TaskResultService; -import com.welab.wefe.board.service.util.ModelMapper; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskStatus; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectType; +import com.welab.wefe.common.wefe.enums.TaskStatus; /** * @author zane.luo @@ -101,6 +112,8 @@ public abstract class AbstractComponent { protected TaskResultService taskResultService; @Autowired protected TaskRepository taskRepository; + @Autowired + protected GlobalConfigService globalConfigService; /** * create mix flow task @@ -109,7 +122,7 @@ public abstract class AbstractComponent { * @param preTasks pre task list * @param node node */ - public List buildMixTask(FlowGraph graph, List preTasks, KernelJob jobInfo, FlowGraphNode node) throws StatusCodeWithException { + public List buildMixTask(FlowGraph graph, List preTasks, KernelJob jobInfo, FlowGraphNode node) throws Exception { T params = (T) node.getParamsModel(); @@ -163,20 +176,19 @@ public List buildMixTask(FlowGraph graph, List p if (graph.getJob().getMyRole() == JobMemberRole.provider) { if (node.getComponentType() == ComponentType.MixLR || node.getComponentType() == ComponentType.MixSecureBoost) { - JSONObject p = taskParam.getJSONObject("params"); - p.put("random_cipher_seed", randomCipherSeed); - 
taskParam.put("params", p); + taskParam.put("random_cipher_seed", randomCipherSeed); } } taskConfig.setJob(jobInfo); taskConfig.setModule(taskType()); - taskConfig.setParams(taskParam.getJSONObject("params")); + taskConfig.setParams(taskParam); taskConfig.setInput(generateInput(graph, node, count)); taskConfig.setOutput(getOutputs(graph, node)); taskConfig.setTask(kernelTask); task.setTaskConf(JSON.toJSONString(taskConfig)); task.setRole(graph.getJob().getMyRole()); - task.setStatus(TaskStatus.wait_run); + // ImageDataIO 组件不用执行,直接设置为成功。 + task.setStatus(taskType() == ComponentType.ImageDataIO ? TaskStatus.success : TaskStatus.wait_run); task.setTaskId(node.createTaskId(graph.getJob(), count)); task.setParentTaskIdList( node.createParentTaskIds(graph.getJob(), getCount(preTasks, node.getDeep() - 1, count))); @@ -184,7 +196,7 @@ public List buildMixTask(FlowGraph graph, List p taskRepository.save(task); subTasks.add(task); count++; - + // rollback node.setTaskName(FlowGraphNode.createTaskName(node.getComponentType(), node.getNodeId())); if (parentNode != null) { @@ -195,15 +207,15 @@ public List buildMixTask(FlowGraph graph, List p return subTasks; } - private Map generateInput(FlowGraph graph, FlowGraphNode node, int count) throws FlowNodeException{ + private Map generateInput(FlowGraph graph, FlowGraphNode node, int count) throws FlowNodeException { Map inputs = getInputs(graph, node); try { - + JSONObject json = JSON.parseObject(JSON.toJSONString(inputs)); JSONObject data = json.getJSONObject("data"); Set> entrySet = data.entrySet(); - for(Map.Entry entry : entrySet){ + for (Map.Entry entry : entrySet) { List dataSetList = data.getObject(entry.getKey(), TypeReference.LIST_STRING); String end = "_" + count; List newDataSet = new ArrayList<>(); @@ -228,6 +240,7 @@ private Map generateInput(FlowGraph graph, FlowGraphNode node, i } return inputs; } + private int getCount(List preTasks, int parentDeep, int currentCount) { if (parentDeep < 0 || preTasks == null || 
preTasks.isEmpty()) { return currentCount; @@ -245,7 +258,7 @@ private int getCount(List preTasks, int parentDeep, int currentC * @param preTasks A collection of created tasks * @param node the node of flow */ - public TaskMySqlModel buildTask(FlowGraph graph, List preTasks, KernelJob jobInfo, FlowGraphNode node) throws StatusCodeWithException { + public TaskMySqlModel buildTask(ProjectMySqlModel project, FlowGraph graph, List preTasks, KernelJob jobInfo, FlowGraphNode node) throws Exception { T params = (T) node.getParamsModel(); @@ -275,20 +288,25 @@ public TaskMySqlModel buildTask(FlowGraph graph, List preTasks, task.setTaskType(taskType()); task.setName(node.getTaskName()); - TaskConfig taskConfig = new TaskConfig(); - taskConfig.setJob(jobInfo); - taskConfig.setModule(taskType()); - taskConfig.setParams(taskParam.getJSONObject("params")); - taskConfig.setInput(getInputs(graph, node)); - taskConfig.setOutput(getOutputs(graph, node)); - taskConfig.setTask(getTaskMembers(graph, node)); + if (project.getProjectType() == ProjectType.MachineLearning) { + TaskConfig taskConfig = new TaskConfig(); + taskConfig.setJob(jobInfo); + taskConfig.setModule(taskType()); + taskConfig.setParams(taskParam); + taskConfig.setInput(getInputs(graph, node)); + taskConfig.setOutput(getOutputs(graph, node)); + taskConfig.setTask(getTaskMembers(graph, node)); - task.setTaskConf( - JSON.toJSONString(taskConfig) - ); + task.setTaskConf( + JSON.toJSONString(taskConfig) + ); + } else if (project.getProjectType() == ProjectType.DeepLearning) { + task.setTaskConf(taskParam.toJSONString()); + } task.setRole(graph.getJob().getMyRole()); - task.setStatus(TaskStatus.wait_run); + // ImageDataIO 组件不用执行,直接设置为成功。 + task.setStatus(taskType() == ComponentType.ImageDataIO ? 
TaskStatus.success : TaskStatus.wait_run); task.setTaskId(node.createTaskId(graph.getJob())); task.setParentTaskIdList(node.createParentTaskIds(graph.getJob())); task.setProjectId(node.getProjectId()); @@ -332,7 +350,7 @@ public List getTaskAllResult(String taskId) { /** * Show the specified execution result */ - public TaskResultOutputModel getTaskResult(String taskId, String type) { + public TaskResultOutputModel getTaskResult(String taskId, String type) throws StatusCodeWithException { TaskResultMySqlModel result = getResult(taskId, type); if (result == null) { return null; @@ -442,7 +460,7 @@ public T findMyData(Collection list, Function getMemberIdFunc) /** * Deserialize form parameters into Param objects */ - public T deserializationParam(FlowGraphNode node, String json) throws FlowNodeException { + public T deserializationParam(String json) throws StatusCodeWithException { if (json == null) { json = "{}"; } @@ -451,12 +469,7 @@ public T deserializationParam(FlowGraphNode node, String json) throws FlowNodeEx .create(json) .toJavaObject(paramsClass); - // Basic check of entry (non-empty, regular check) - try { - params.checkAndStandardize(); - } catch (StatusCodeWithException e) { - throw new FlowNodeException(node, e.getMessage()); - } + params.checkAndStandardize(); return params; } @@ -481,7 +494,7 @@ private Class getParamsClass(Class clazz) { public List getMixTaskMembers(FlowGraph graph, FlowGraphNode node) { List kernelTasks = new ArrayList<>(); - List allMembers = graph.getMembers().stream().map(Member::new).collect(Collectors.toList()); + List allMembers = Member.forMachineLearning(graph.getMembers()); List promoters = allMembers.stream().filter(s -> s.getMemberRole() == JobMemberRole.promoter) .collect(Collectors.toList()); List providers = allMembers.stream().filter(s -> s.getMemberRole() == JobMemberRole.provider) @@ -496,13 +509,13 @@ public List getMixTaskMembers(FlowGraph graph, FlowGraphNode node) { || node.getComponentType() == 
ComponentType.MixSecureBoost || node.getComponentType() == ComponentType.MixStatistic || node.getComponentType() == ComponentType.MixBinning) { - Member promoter = allMembers.stream().filter(x -> x.getMemberRole() == JobMemberRole.promoter) + Member promoter = allMembers.stream() + .filter(x -> x.getMemberRole() == JobMemberRole.promoter + && (StringUtils.isBlank(graph.getCreatorMemberId()) + || x.getMemberId().equalsIgnoreCase(graph.getCreatorMemberId()))) .findFirst().orElse(null); if (promoter != null) { - arbiter = new Member(); - arbiter.setMemberId(promoter.getMemberId()); - arbiter.setMemberRole(JobMemberRole.arbiter); - arbiter.setMemberName(promoter.getMemberName()); + arbiter = Member.forMachineLearning(promoter.getMemberId(), JobMemberRole.arbiter); allMembers.add(arbiter); } } @@ -611,34 +624,31 @@ public KernelTask getTaskMembers(FlowGraph graph, FlowGraphNode node) { List members = new ArrayList<>(); KernelTask task = new KernelTask(); dataSetItems.forEach(x -> { - Member member = new Member(); - member.setMemberId(x.getMemberId()); - member.setMemberName(CacheObjects.getMemberName(x.getMemberId())); - member.setMemberRole(x.getMemberRole()); - members.add(member); + members.add(Member.forMachineLearning(x)); // Horizontal modeling component, and the current member is a promoter, need to // increase arbiter. 
if (Components.needArbiterTask(node.getComponentType())) { if (x.getMemberRole() == JobMemberRole.promoter && CacheObjects.getMemberId().equals(x.getMemberId())) { - Member arbiterMember = new Member(); - arbiterMember.setMemberId(x.getMemberId()); - arbiterMember.setMemberName(CacheObjects.getMemberName(x.getMemberId())); - arbiterMember.setMemberRole(JobMemberRole.arbiter); + Member arbiterMember = Member.forMachineLearning(x.getMemberId(), JobMemberRole.arbiter); members.add(arbiterMember); } } }); - Member promoter = graph.getMembers().stream().map(x -> new Member(x)) - .filter(s -> s.getMemberRole() == JobMemberRole.promoter).findFirst().orElse(null); + Member promoter = Member.forMachineLearning(graph.getMembers()) + .stream() + .filter(s -> s.getMemberRole() == JobMemberRole.promoter) + .findFirst() + .orElse(null); if (Components.needArbiterTask(node.getComponentType())) { if (graph.getJob().getMyRole() == JobMemberRole.provider && promoter != null) { - Member arbiterMember = new Member(); - arbiterMember.setMemberId(promoter.getMemberId()); - arbiterMember.setMemberName(CacheObjects.getMemberName(promoter.getMemberId())); - arbiterMember.setMemberRole(JobMemberRole.arbiter); - members.add(arbiterMember); + members.add( + Member.forMachineLearning( + promoter.getMemberId(), + JobMemberRole.arbiter + ) + ); } } @@ -677,13 +687,13 @@ public boolean parentHasIntersectedDataSet(FlowGraph graph, FlowGraphNode node) if (x.getComponentType() == ComponentType.DataIO) { DataIOComponent.Params dataIOParams = (DataIOComponent.Params) x.getParamsModel(); - DataSetMysqlModel myDataSet = dataIOParams.getMyDataSet(); + TableDataSetMysqlModel myDataSet = dataIOParams.getMyDataSet(); // If it is not a derived data set, it must have been misaligned. 
- if (myDataSet != null && myDataSet.getSourceType() != null) { + if (myDataSet != null && myDataSet.isDerivedResource()) { // If the derived data set comes from alignment - if (myDataSet.getSourceType() == ComponentType.Intersection) { + if (myDataSet.getDerivedFrom() == ComponentType.Intersection) { return true; } @@ -715,7 +725,7 @@ public boolean parentHasIntersectedDataSet(FlowGraph graph, FlowGraphNode node) /** * Assemble the input parameters of the task according to the component configuration */ - protected abstract JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, T params) throws FlowNodeException; + protected abstract JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, T params) throws Exception; public abstract ComponentType taskType(); @@ -727,7 +737,7 @@ public boolean parentHasIntersectedDataSet(FlowGraph graph, FlowGraphNode node) /** * Show the specified execution result */ - protected abstract TaskResultMySqlModel getResult(String taskId, String type); + protected abstract TaskResultMySqlModel getResult(String taskId, String type) throws StatusCodeWithException; /** * Declare the input parameter type diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/dto/AbstractDataIOParam.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/dto/AbstractDataIOParam.java new file mode 100644 index 000000000..1c39da3a9 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/dto/AbstractDataIOParam.java @@ -0,0 +1,88 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.component.base.dto; + +import com.welab.wefe.board.service.component.DataIOComponent; +import com.welab.wefe.board.service.component.deep_learning.ImageDataIOComponent; +import com.welab.wefe.board.service.dto.entity.data_resource.output.DataResourceOutputModel; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetService; +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import org.apache.commons.collections4.CollectionUtils; + +import java.util.List; + +/** + * @author zane + * @date 2021/11/24 + */ +public abstract class AbstractDataIOParam extends AbstractCheckModel { + public List dataSetList; + + public T getMyJobDataSetItem(JobMemberRole role) { + if (CollectionUtils.isEmpty(dataSetList)) { + return null; + } + + return dataSetList.stream() + .filter(x -> CacheObjects.getMemberId().equals(x.memberId) && role == x.getMemberRole()) + .findFirst() + .orElse(null); + } + + public DataResourceOutputModel getMyJobDataSet(JobMemberRole role) throws StatusCodeWithException { + T dataSetItem = getMyJobDataSetItem(role); + if (dataSetItem == null) { + return null; + } + + if (dataSetItem instanceof ImageDataIOComponent.DataSetItem) { + return 
Launcher.CONTEXT + .getBean(ImageDataSetService.class) + .findDataSetFromLocalOrUnion( + dataSetItem.getMemberId(), + dataSetItem.getDataSetId() + ); + + } else if (dataSetItem instanceof DataIOComponent.DataSetItem) { + return Launcher.CONTEXT + .getBean(TableDataSetService.class) + .findDataSetFromLocalOrUnion( + dataSetItem.getMemberId(), + dataSetItem.getDataSetId() + ); + } + + return null; + } + + // region getter/setter + + public List getDataSetList() { + return dataSetList; + } + + public void setDataSetList(List dataSetList) { + this.dataSetList = dataSetList; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/dto/AbstractDataSetItem.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/dto/AbstractDataSetItem.java new file mode 100644 index 000000000..2e8647d2e --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/dto/AbstractDataSetItem.java @@ -0,0 +1,63 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.component.base.dto; + + +import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.JobMemberRole; + +/** + * @author zane + * @date 2021/11/24 + */ +public abstract class AbstractDataSetItem extends AbstractCheckModel { + @Check(name = "成员Id", require = true) + public String memberId; + @Check(name = "成员角色", require = true) + public JobMemberRole memberRole; + @Check(name = "数据集 Id", require = true) + public String dataSetId; + + // region getter/setter + + public String getMemberId() { + return memberId; + } + + public void setMemberId(String memberId) { + this.memberId = memberId; + } + + public JobMemberRole getMemberRole() { + return memberRole; + } + + public void setMemberRole(JobMemberRole memberRole) { + this.memberRole = memberRole; + } + + public String getDataSetId() { + return dataSetId; + } + + public void setDataSetId(String dataSetId) { + this.dataSetId = dataSetId; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/BinnedOutputFilter.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/BinnedOutputFilter.java index a66b470df..81fb515e5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/BinnedOutputFilter.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/BinnedOutputFilter.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,10 +19,11 @@ import com.welab.wefe.board.service.component.DataIOComponent; import com.welab.wefe.board.service.component.base.io.IODataType; import com.welab.wefe.board.service.component.base.io.OutputItem; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.ComponentType; + /** * Query conditions: include data after binning @@ -71,9 +72,9 @@ public static boolean binned(FlowGraph graph, FlowGraphNode node, OutputItem out FlowGraphNode dataIONode = graph.findOneNodeFromParent(node, ComponentType.DataIO); if (dataIONode != null) { DataIOComponent.Params params = (DataIOComponent.Params) dataIONode.getParamsModel(); - DataSetMysqlModel myDataSet = params.getMyDataSet(); + TableDataSetMysqlModel myDataSet = params.getMyDataSet(); - return myDataSet != null && myDataSet.getSourceType() == ComponentType.Binning; + return myDataSet != null && myDataSet.getDerivedFrom() == ComponentType.Binning; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/IntersectedOutputFilter.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/IntersectedOutputFilter.java index e753d5e5b..ea36497a7 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/IntersectedOutputFilter.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/IntersectedOutputFilter.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,10 +19,11 @@ import com.welab.wefe.board.service.component.DataIOComponent; import com.welab.wefe.board.service.component.base.io.IODataType; import com.welab.wefe.board.service.component.base.io.OutputItem; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.ComponentType; + /** * Query condition: aligned data @@ -65,16 +66,16 @@ public static boolean intersected(FlowGraph graph, FlowGraphNode node, OutputIte FlowGraphNode dataIONode = graph.findOneNodeFromParent(node, ComponentType.DataIO); if (dataIONode != null) { DataIOComponent.Params params = (DataIOComponent.Params) dataIONode.getParamsModel(); - DataSetMysqlModel myDataSet = params.getMyDataSet(); + TableDataSetMysqlModel myDataSet = params.getMyDataSet(); // If it is not a derived data set, it must have been misaligned. 
- if (myDataSet == null || myDataSet.getSourceType() == null) { + if (myDataSet == null || !myDataSet.isDerivedResource()) { return false; } // If the derived data set comes from alignment - if (myDataSet.getSourceType() == ComponentType.Intersection) { + if (myDataSet.getDerivedFrom() == ComponentType.Intersection) { return true; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/OutputDataTypesOutputFilter.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/OutputDataTypesOutputFilter.java index 8cb33754c..1d586afaf 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/OutputDataTypesOutputFilter.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/OutputDataTypesOutputFilter.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/OutputItemFilterFunction.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/OutputItemFilterFunction.java index 3319758e1..bf8ef6102 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/OutputItemFilterFunction.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/filter/OutputItemFilterFunction.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/DataTypeGroup.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/DataTypeGroup.java index 64367bde7..def82a934 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/DataTypeGroup.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/DataTypeGroup.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/IODataType.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/IODataType.java index bc1c172ce..e64746772 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/IODataType.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/IODataType.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputGroup.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputGroup.java index 21847f70e..7364dc1c6 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputGroup.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputGroup.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputMatcher.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputMatcher.java index 89dcdc767..5f5d6babd 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputMatcher.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputMatcher.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -21,7 +21,8 @@ import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.ComponentType; + /** * Component's input matcher diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputSupplier.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputSupplier.java index 16a4e8065..97a14a21b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputSupplier.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/InputSupplier.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/Names.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/Names.java index 03b4baf05..779faef6f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/Names.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/Names.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/NodeOutputItem.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/NodeOutputItem.java index dd98b1db9..8b7f36547 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/NodeOutputItem.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/NodeOutputItem.java @@ -5,8 +5,8 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,7 +17,8 @@ package com.welab.wefe.board.service.component.base.io; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.ComponentType; + /** * @author zane.luo diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/OutputItem.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/OutputItem.java index 9e97a3e9a..b09f3103d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/OutputItem.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/base/io/OutputItem.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/AbstractDeepLearningComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/AbstractDeepLearningComponent.java new file mode 100644 index 000000000..3e6334cf4 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/AbstractDeepLearningComponent.java @@ -0,0 +1,156 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.component.deep_learning; + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.api.data_resource.image_data_set.ImageDataSetDownloadApi; +import com.welab.wefe.board.service.component.base.AbstractComponent; +import com.welab.wefe.board.service.component.base.io.InputMatcher; +import com.welab.wefe.board.service.component.base.io.OutputItem; +import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; +import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.DataResourceOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.ImageDataSetOutputModel; +import com.welab.wefe.board.service.dto.kernel.Member; +import com.welab.wefe.board.service.dto.kernel.deep_learning.Env; +import com.welab.wefe.board.service.dto.kernel.deep_learning.KernelJob; +import com.welab.wefe.board.service.exception.FlowNodeException; +import com.welab.wefe.board.service.model.FlowGraph; +import com.welab.wefe.board.service.model.FlowGraphNode; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetService; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import 
com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.web.api.base.Api; +import com.welab.wefe.common.wefe.enums.ComponentType; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * @author zane.luo + */ +public abstract class AbstractDeepLearningComponent extends AbstractComponent { + + @Autowired + private ImageDataSetService imageDataSetService; + + @Override + protected void checkBeforeBuildTask(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { + FlowGraphNode imageDataIo = graph.findOneNodeFromParent(node, ComponentType.ImageDataIO); + if (imageDataIo == null) { + throw new FlowNodeException(node, "尚未选择数据集"); + } + } + + + @Override + protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws StatusCodeWithException { + ImageDataIOComponent.Params imageDataIoParam = (ImageDataIOComponent.Params) graph.findOneNodeFromParent(node, ComponentType.ImageDataIO).getParamsModel(); + + Set labelNames = new HashSet<>(); + for (ImageDataIOComponent.DataSetItem dataSetItem : imageDataIoParam.getDataSetList()) { + ImageDataSetOutputModel dataSet = imageDataSetService.findDataSetFromLocalOrUnion(dataSetItem.getMemberId(), dataSetItem.getDataSetId()); + List list = StringUtil.splitWithoutEmptyItem(dataSet.getLabelList(), ","); + labelNames.addAll(list); + } + params.numClasses = labelNames.size(); + + KernelJob job = new KernelJob(); + job.projectId = graph.getJob().getProjectId(); + job.jobId = graph.getJob().getJobId(); + job.taskId = node.createTaskId(graph.getJob()); + job.role = graph.getJob().getMyRole(); + job.memberId = CacheObjects.getMemberId(); + job.env = new Env(imageDataIoParam); + job.members = Member.forDeepLearning(graph.getMembers()); + + DataResourceOutputModel 
myJobDataSet = imageDataIoParam.getMyJobDataSet(job.role); + JObject dataSetInfo = JObject.create(myJobDataSet); + dataSetInfo.put("download_url", buildDataSetDownloadUrl(myJobDataSet.getId(), job.jobId)); + + JObject output = JObject.create(job); + output.put("data_set", dataSetInfo); + output.put("algorithm_config", params); + + return output; + } + + private String buildDataSetDownloadUrl(String dataSetId, String jobId) { + Api annotation = ImageDataSetDownloadApi.class.getAnnotation(Api.class); + return Launcher.getBean(GlobalConfigService.class) + .getBoardConfig() + .intranetBaseUri + + "/" + + annotation.path() + + "?data_set_id=" + dataSetId + + "&job_id=" + jobId; + + } + + @Override + protected List getAllResult(String taskId) { + return taskResultService.listAllResult(taskId); + } + + @Override + protected TaskResultMySqlModel getResult(String taskId, String type) { + return taskResultService.findByTaskIdAndType(taskId, type); + } + + @Override + protected List inputs(FlowGraph graph, FlowGraphNode node) { + return null; + } + + @Override + public List outputs(FlowGraph graph, FlowGraphNode node) throws FlowNodeException { + return null; + } + + public static class Params extends AbstractCheckModel { + @Check( + name = "算法类型", + require = true, + regex = "(paddle_clas|paddle_detection)", + desc = "paddle_clas(分类), paddle_detection(目标检测)" + ) + public String program; + @Check(name = "迭代次数", require = true) + public Integer maxIter; + @Check(name = "聚合步长", require = true) + public Integer innerStep; + @Check(name = "检测模型名称", require = true) + public String architecture; + @Check(name = "类别数") + public Integer numClasses; + @Check(name = "学习率", require = true) + public Double baseLr; + @Check(name = "图像输入尺寸", require = true) + public Integer[] imageShape; + @Check(name = "批量大小", require = true) + public Integer batchSize; + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/ImageDataIOComponent.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/ImageDataIOComponent.java new file mode 100644 index 000000000..87f2abd9c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/ImageDataIOComponent.java @@ -0,0 +1,192 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.component.deep_learning; + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.component.base.AbstractComponent; +import com.welab.wefe.board.service.component.base.dto.AbstractDataIOParam; +import com.welab.wefe.board.service.component.base.dto.AbstractDataSetItem; +import com.welab.wefe.board.service.component.base.io.InputMatcher; +import com.welab.wefe.board.service.component.base.io.OutputItem; +import com.welab.wefe.board.service.database.entity.data_resource.ImageDataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.job.JobMemberMySqlModel; +import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; +import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.ImageDataSetOutputModel; +import com.welab.wefe.board.service.exception.FlowNodeException; +import com.welab.wefe.board.service.model.FlowGraph; +import 
com.welab.wefe.board.service.model.FlowGraphNode; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetSampleService; +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetService; +import com.welab.wefe.board.service.service.data_resource.image_data_set.data_set_parser.AbstractImageDataSetParser; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import org.apache.commons.collections4.CollectionUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import java.util.List; + +/** + * @author zane.luo + */ +@Service +public class ImageDataIOComponent extends AbstractComponent { + + @Autowired + private ImageDataSetService imageDataSetService; + @Autowired + private ImageDataSetSampleService imageDataSetSampleService; + + @Override + public ComponentType taskType() { + return ComponentType.ImageDataIO; + } + + @Override + protected void checkBeforeBuildTask(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { + List jobMembers = graph.getMembers(); + + if (CollectionUtils.isEmpty(jobMembers) || jobMembers.size() < 2) { + throw new FlowNodeException(node, "请至少为两个成员指定数据集"); + } + + if (CollectionUtils.isEmpty(params.getDataSetList()) || params.getDataSetList().size() < 2) { + throw new FlowNodeException(node, "请选择多个数据集用于联邦"); + } + + if (jobMembers.stream().noneMatch(x -> x.getJobRole() == JobMemberRole.promoter)) { + throw new FlowNodeException(node, "请为 promoter 指定数据集"); + } + + if 
(graph.getJob().getMyRole() == JobMemberRole.promoter) { + String labelList = null; + // 检查数据集的有效性 + for (DataSetItem dataSetItem : params.getDataSetList()) { + + ImageDataSetOutputModel one = null; + try { + one = imageDataSetService.findDataSetFromLocalOrUnion(dataSetItem.memberId, dataSetItem.dataSetId); + } catch (StatusCodeWithException e) { + throw new FlowNodeException(node, e.getMessage()); + } + if (one == null) { + throw new FlowNodeException(node, "成员 " + CacheObjects.getMemberName(dataSetItem.memberId) + " 的数据集 " + dataSetItem.getDataSetId() + " 不存在,请检查是否已删除。"); + } + if (one.getLabeledCount() == 0) { + throw new FlowNodeException(node, "成员 " + CacheObjects.getMemberName(dataSetItem.memberId) + " 的数据集【" + one.getName() + "】已标注的样本量为 0,无法使用。"); + } + // 检查各成员的数据集的标签列表是否一致 + if (labelList == null) { + labelList = StringUtil.join(one.getLabelSet(), ","); + } else { + if (!labelList.equals(StringUtil.join(one.getLabelSet(), ","))) { + throw new FlowNodeException(node, "各成员提供的数据集标签列表不一致,无法创建任务。"); + } + } + } + } + } + + + @Override + protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws Exception { + DataSetItem myDataSetConfig = params.getDataSetList() + .stream() + .filter(x -> + x.getMemberId().equals(CacheObjects.getMemberId()) + && x.getMemberRole() == graph.getJob().getMyRole() + ) + .findFirst() + .orElse(null); + + ImageDataSetMysqlModel myDataSet = imageDataSetService.findOneById(myDataSetConfig.dataSetId); + + JObject output = JObject.create(myDataSet); + + + // 生成数据集文件 + AbstractImageDataSetParser + .getParser(myDataSet.getForJobType()) + .parseSamplesToDataSetFile( + graph.getJob().getJobId(), + myDataSet, + imageDataSetSampleService.allLabeled(myDataSetConfig.dataSetId), + params.trainTestSplitRatio + ); + + + return output; + } + + @Override + protected List getAllResult(String taskId) { + return taskResultService.listAllResult(taskId); + } + + @Override + protected TaskResultMySqlModel 
getResult(String taskId, String type) { + return taskResultService.findByTaskIdAndType(taskId, type); + } + + @Override + protected List inputs(FlowGraph graph, FlowGraphNode node) { + return null; + } + + @Override + public List outputs(FlowGraph graph, FlowGraphNode node) throws FlowNodeException { + return null; + } + + public static class Params extends AbstractDataIOParam { + @Check(name = "数据集切割比例", desc = "取值1-99,该值为训练集的百分比。", require = true) + public int trainTestSplitRatio; + + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + + if (trainTestSplitRatio < 1 || trainTestSplitRatio > 99) { + StatusCode.PARAMETER_VALUE_INVALID.throwException("数据集切割比例(训练:测试),取值必须在 1-99 之间,当前取值:" + trainTestSplitRatio); + } + } + + public void fillDataSetDetail() throws StatusCodeWithException { + + ImageDataSetService imageDataSetService = Launcher.getBean(ImageDataSetService.class); + + for (ImageDataIOComponent.DataSetItem dataSetItem : dataSetList) { + dataSetItem.dataResource = imageDataSetService.findDataSetFromLocalOrUnion(dataSetItem.memberId, dataSetItem.dataSetId); + } + } + } + + public static class DataSetItem extends AbstractDataSetItem { + @Check(desc = "非入参,而是当此对象作为返回值时输出的字段。") + public ImageDataSetOutputModel dataResource; + } + + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/PaddleClassifyComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/PaddleClassifyComponent.java new file mode 100644 index 000000000..135c520c7 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/PaddleClassifyComponent.java @@ -0,0 +1,31 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.component.deep_learning; + +import com.welab.wefe.common.wefe.enums.ComponentType; +import org.springframework.stereotype.Service; + +/** + * @author zane + * @date 2022/1/10 + */ +@Service +public class PaddleClassifyComponent extends AbstractDeepLearningComponent { + @Override + public ComponentType taskType() { + return ComponentType.PaddleClassify; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/PaddleDetectionComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/PaddleDetectionComponent.java new file mode 100644 index 000000000..6062eae40 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/deep_learning/PaddleDetectionComponent.java @@ -0,0 +1,31 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.component.deep_learning; + +import com.welab.wefe.common.wefe.enums.ComponentType; +import org.springframework.stereotype.Service; + +/** + * @author zane + * @date 2022/1/10 + */ +@Service +public class PaddleDetectionComponent extends AbstractDeepLearningComponent { + @Override + public ComponentType taskType() { + return ComponentType.PaddleDetection; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/BinningComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/BinningComponent.java index f78c10842..85d6cc528 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/BinningComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/BinningComponent.java @@ -1,11 +1,11 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -44,12 +44,12 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskResultType; /** * @author lonnie @@ -64,7 +64,7 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas if (intersectionNode == null) { throw new FlowNodeException(node, "请在前面添加样本对齐组件。"); } - + if (CollectionUtils.isEmpty(params.getMembers())) { throw new FlowNodeException(node, "请添加分箱策略"); } @@ -97,8 +97,6 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); - // Reassemble front-end parameters JObject transformParam = JObject.create() .append("transform_cols", -1) @@ -221,9 +219,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT .append("optimal_binning_param", optimalBinningParam) .append("modes", modesObj); - taskParam.put("params", binningParam); - - return taskParam; + return binningParam; } @Override @@ -270,32 +266,32 @@ protected TaskResultMySqlModel getResult(String taskId, String type) { } List providerResults = 
modelParam.getJSONList("providerResults"); - Map biningResultMap = new HashMap<>(); - if (CollectionUtils.isNotEmpty(providerResults)) { - for (JObject providerResult : providerResults) { - String memberName = CacheObjects.getMemberName(providerResult.getString("memberId")); - String key = memberName + "_" + providerResult.getString("memberId") + "_" - + providerResult.getString("role"); - if (biningResultMap.containsKey(key)) { - // merge - JObject result = biningResultMap.get(key); - JObject temp = result.getJObject("binningResult"); - temp.putAll(providerResult.getJObject("binningResult")); - result.put("binningResult", temp); - biningResultMap.put(key, result); - } else { - // add - providerResult.append("member_name", memberName) - .append("member_id", providerResult.getString("memberId")) - .append("member_role", providerResult.getString("role")); - biningResultMap.put(key, providerResult); - } - - } - for (Map.Entry entry : biningResultMap.entrySet()) { - resultList.add(entry.getValue()); - } - } + Map biningResultMap = new HashMap<>(); + if (CollectionUtils.isNotEmpty(providerResults)) { + for (JObject providerResult : providerResults) { + String memberName = CacheObjects.getMemberName(providerResult.getString("memberId")); + String key = memberName + "_" + providerResult.getString("memberId") + "_" + + providerResult.getString("role"); + if (biningResultMap.containsKey(key)) { + // merge + JObject result = biningResultMap.get(key); + JObject temp = result.getJObject("binningResult"); + temp.putAll(providerResult.getJObject("binningResult")); + result.put("binningResult", temp); + biningResultMap.put(key, result); + } else { + // add + providerResult.append("member_name", memberName) + .append("member_id", providerResult.getString("memberId")) + .append("member_role", providerResult.getString("role")); + biningResultMap.put(key, providerResult); + } + + } + for (Map.Entry entry : biningResultMap.entrySet()) { + resultList.add(entry.getValue()); + } + } 
taskResultMySqlModel.setResult(JObject.create().append("result", resultList).toJSONString()); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureCalculationComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureCalculationComponent.java index 3bf1fb07b..7f0195b9f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureCalculationComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureCalculationComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -31,11 +31,11 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskResultType; import org.springframework.beans.BeanUtils; import org.springframework.stereotype.Service; @@ -64,10 +64,7 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - - JSONObject taskParam = new JSONObject(); - - return taskParam; + return new JSONObject(); } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureSelectionComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureSelectionComponent.java index 6350f1128..9ad14cc39 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureSelectionComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureSelectionComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -29,11 +29,11 @@ import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.FederatedLearningType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; import org.apache.commons.collections4.CollectionUtils; import org.springframework.stereotype.Service; @@ -117,8 +117,6 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); - // Reassemble front-end parameters List members = params.members; List kernelParam = new ArrayList<>(); @@ -126,15 +124,18 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT JObject obj = JObject.create().append("member_id", member.getMemberId()) .append("role", member.getMemberRole()) - .append("features", member.getFeatures().stream().map(x -> x.getName()).collect(Collectors.toList())); + .append("features", + member.getFeatures() + .stream() + .map(x -> x.getName()) + .collect(Collectors.toList()) + ); kernelParam.add(obj); } - taskParam.put("params", JObject.create().append("members", kernelParam)); - - taskParam.put("env", "test"); + JObject output = 
JObject.create().append("members", kernelParam); - return taskParam; + return output; } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureStandardizedComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureStandardizedComponent.java index 7449cc302..89325d742 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureStandardizedComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureStandardizedComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -29,11 +29,11 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.wefe.enums.ComponentType; import org.springframework.stereotype.Service; import java.util.ArrayList; @@ -54,9 +54,6 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - - JSONObject taskParam = new JSONObject(); - // Need to find DataIO data set FlowGraphNode dataIONode = graph.findOneNodeFromParent(node, ComponentType.DataIO); TaskMySqlModel dataIOTask = findTaskFromPretasks(preTasks, dataIONode); @@ -64,17 +61,16 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT throw new FlowNodeException(node, "请添加DataIO组件!"); } - JObject resultObj = JObject.create(); // Get the withLabel field in the dataIO node JObject taskConfig = JObject.create(dataIOTask.getTaskConf()); - if (taskConfig.getJObject("params") == null) { + if (taskConfig == null) { throw new FlowNodeException(node, "找不到DataIO_task中的with_label字段"); } - boolean withLabel = taskConfig.getJObject("params").getBooleanValue("with_label"); + boolean withLabel = taskConfig.getBooleanValue("with_label"); List members = params.getMembers(); - + JObject output = JObject.create(); for (MemberFeatureInfoModel member : members) { if (CacheObjects.getMemberId().equals(member.getMemberId())) { List features = member.getFeatures(); @@ -85,18 +81,18 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT } }); - resultObj.append("fields", fields); + 
output.append("fields", fields); break; } } - resultObj.append("with_label", withLabel) + output + .append("with_label", withLabel) .append("save_dataset", true); - taskParam.put("params", resultObj); - return taskParam; + return output; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureStatisticsComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureStatisticsComponent.java index afe35e022..dd0074398 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureStatisticsComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureStatisticsComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -27,17 +27,17 @@ import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; import com.welab.wefe.board.service.dto.entity.MemberModel; -import com.welab.wefe.board.service.dto.kernel.KernelTask; import com.welab.wefe.board.service.dto.kernel.Member; +import com.welab.wefe.board.service.dto.kernel.machine_learning.KernelTask; import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.TaskResultType; import org.apache.commons.collections4.CollectionUtils; import org.springframework.beans.BeanUtils; import org.springframework.stereotype.Service; @@ -87,11 +87,8 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); - // Need to use dataIO data set FlowGraphNode dataIONode = graph.findOneNodeFromParent(node, ComponentType.DataIO); - TaskMySqlModel dataIOTask = findTaskFromPretasks(preTasks, dataIONode); if (dataIONode == null) { throw new FlowNodeException(node, "请添加DataIO策略!"); @@ -105,15 +102,10 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT } } - JObject featureStatisticsParam = JObject.create(); - featureStatisticsParam.append("percentage_list", percentileList); - - // Local non-local federation for local testing - featureStatisticsParam.put("work_mode", 
params.workMode); - - taskParam.put("params", featureStatisticsParam); - - return taskParam; + JObject output = JObject.create() + .append("percentage_list", percentileList) + .append("work_mode", params.workMode); + return output; } @Override @@ -188,10 +180,7 @@ public KernelTask getTaskMembers(FlowGraph graph, FlowGraphNode node) { if ("local".equals(params.getWorkMode())) { params.getMembers().forEach(x -> { if (x.getMemberId().equals(CacheObjects.getMemberId())) { - Member member = new Member(); - member.setMemberId(x.getMemberId()); - member.setMemberName(x.getMemberName()); - member.setMemberRole(x.getMemberRole()); + Member member = Member.forMachineLearning(x.getMemberId(), x.getMemberRole()); members.add(member); } }); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureTransformComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureTransformComponent.java index 76630df3d..5a8fe70b1 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureTransformComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FeatureTransformComponent.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -34,10 +34,10 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; /** * @author Winter @@ -79,9 +79,8 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT }); - taskParam.put("params", JObject.create().append("transform_rules", transformRules.toJSONString())); - - return taskParam; +// taskParam.put("params", JObject.create().append("transform_rules", transformRules.toJSONString())); + return JObject.create().append("transform_rules", transformRules.toJSONString()); } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FillMissingValueComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FillMissingValueComponent.java index 08c15a34c..c535bad5c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FillMissingValueComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/FillMissingValueComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -25,19 +25,19 @@ import com.welab.wefe.board.service.component.base.io.OutputItem; import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; -import com.welab.wefe.board.service.dto.kernel.KernelTask; import com.welab.wefe.board.service.dto.kernel.Member; +import com.welab.wefe.board.service.dto.kernel.machine_learning.KernelTask; import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.board.service.util.ModelMapper; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskResultType; import org.apache.commons.collections4.CollectionUtils; import org.springframework.beans.BeanUtils; import org.springframework.stereotype.Service; @@ -82,8 +82,6 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new 
JSONObject(); - // Need to find DataIO data set FlowGraphNode dataIONode = graph.findOneNodeFromParent(node, ComponentType.DataIO); TaskMySqlModel dataIOTask = findTaskFromPretasks(preTasks, dataIONode); @@ -93,10 +91,10 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT // Get the withLabel field in the dataIO node JObject taskConfig = JObject.create(dataIOTask.getTaskConf()); - if (taskConfig.getJObject("params") == null) { + if (taskConfig == null) { throw new FlowNodeException(node, "找不到DataIO_task中的with_label字段"); } - boolean withLabel = taskConfig.getJObject("params").getBooleanValue("with_label"); + boolean withLabel = taskConfig.getBooleanValue("with_label"); List members = params.members; @@ -117,14 +115,12 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT break; } } - JObject resultObj = JObject.create() + JObject output = JObject.create() .append("features", featuresObj.toJSONString()) .append("with_label", withLabel) .append("save_dataset", true); - taskParam.put("params", resultObj); - - return taskParam; + return output; } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzFeatureBinningComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzFeatureBinningComponent.java index a2c85ef27..d10ec70c8 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzFeatureBinningComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzFeatureBinningComponent.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,16 +16,6 @@ package com.welab.wefe.board.service.component.feature; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Collectors; - -import org.apache.commons.collections4.CollectionUtils; -import org.springframework.beans.BeanUtils; -import org.springframework.stereotype.Service; - import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.component.DataIOComponent; import com.welab.wefe.board.service.component.base.AbstractComponent; @@ -39,12 +29,21 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskResultType; +import org.apache.commons.collections4.CollectionUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; @Service public class HorzFeatureBinningComponent extends AbstractComponent { @@ -101,10 +100,7 @@ protected JSONObject createTaskParams(FlowGraph graph, 
List preT JObject binningParam = JObject.create() .append("bin_num", bin_num) .append("bin_names", bin_names); - - taskParam.put("params", binningParam); - - return taskParam; + return binningParam; } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzOneHotComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzOneHotComponent.java index a9e52029b..ed55f74dc 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzOneHotComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzOneHotComponent.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -34,8 +34,8 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; @Service public class HorzOneHotComponent extends AbstractComponent { @@ -49,7 +49,6 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, VertOneHotComponent.Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); List members = params.getMembers(); List transformColNames = new ArrayList<>(); @@ -63,10 +62,8 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT }); } }); - taskParam.put("params", - JObject.create().append("transform_col_names", 
transformColNames).append("save_dataset", true)); - return taskParam; + return JObject.create().append("transform_col_names", transformColNames).append("save_dataset", true); } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzStatisticComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzStatisticComponent.java index 1adb4ca95..bcf436f5d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzStatisticComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/HorzStatisticComponent.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,6 @@ package com.welab.wefe.board.service.component.feature; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import org.apache.commons.collections4.CollectionUtils; -import org.springframework.beans.BeanUtils; -import org.springframework.stereotype.Service; - import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.component.base.AbstractComponent; import com.welab.wefe.board.service.component.base.filter.OutputDataTypesOutputFilter; @@ -38,10 +30,17 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.util.JObject; +import 
com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.TaskResultType; +import org.apache.commons.collections4.CollectionUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.stereotype.Service; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; @Service public class HorzStatisticComponent extends AbstractComponent { @@ -61,19 +60,18 @@ public boolean canSelectFeatures() { return true; } - @Override - protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, - Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); - List members = params.members; - for (MemberFeatureInfoModel member : members) { - if (CacheObjects.getMemberId().equals(member.getMemberId())) { - List features = member.features; - taskParam.put("params", JObject.create("col_names", features)); - } - } - return taskParam; - } + @Override + protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, + Params params) throws FlowNodeException { + List members = params.members; + for (MemberFeatureInfoModel member : members) { + if (CacheObjects.getMemberId().equals(member.getMemberId())) { + List features = member.features; + return JObject.create("col_names", features); + } + } + return JObject.create(); + } @Override protected List getAllResult(String taskId) { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/MixBinningComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/MixBinningComponent.java index c368c18a8..ffb984770 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/MixBinningComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/MixBinningComponent.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the 
License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,16 +16,6 @@ package com.welab.wefe.board.service.component.feature; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Collectors; - -import org.apache.commons.collections4.CollectionUtils; -import org.springframework.beans.BeanUtils; -import org.springframework.stereotype.Service; - import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.component.DataIOComponent; import com.welab.wefe.board.service.component.base.AbstractComponent; @@ -39,12 +29,21 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskResultType; +import org.apache.commons.collections4.CollectionUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; /** * @author Winter @@ -89,7 +88,6 @@ public ComponentType taskType() { protected JSONObject createTaskParams(FlowGraph graph, List preTasks, 
FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); List members = params.members; int bin_num = 10; List bin_names = new ArrayList<>(); @@ -104,9 +102,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT } JObject binningParam = JObject.create().append("bin_num", bin_num).append("bin_names", bin_names); - taskParam.put("params", binningParam); - - return taskParam; + return binningParam; } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/MixStatisticComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/MixStatisticComponent.java index 99c787eb9..21b344e24 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/MixStatisticComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/MixStatisticComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,14 +16,6 @@ package com.welab.wefe.board.service.component.feature; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import org.apache.commons.collections4.CollectionUtils; -import org.springframework.beans.BeanUtils; -import org.springframework.stereotype.Service; - import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.component.base.AbstractComponent; import com.welab.wefe.board.service.component.base.filter.OutputDataTypesOutputFilter; @@ -37,10 +29,17 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.TaskResultType; +import org.apache.commons.collections4.CollectionUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.stereotype.Service; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; /** * @author Winter @@ -49,7 +48,7 @@ public class MixStatisticComponent extends AbstractComponent { @Override protected void checkBeforeBuildTask(FlowGraph graph, List preTasks, FlowGraphNode node, - Params params) { + Params params) { } @Override @@ -61,19 +60,18 @@ public ComponentType taskType() { public boolean canSelectFeatures() { return true; } - + @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, - Params params) { - JSONObject taskParam = new JSONObject(); - List members = params.members; - for (MemberFeatureInfoModel member : members) { - if (CacheObjects.getMemberId().equals(member.getMemberId())) { - List features = member.features; - taskParam.put("params", 
JObject.create("col_names", features)); - } - } - return taskParam; + Params params) { + + MemberFeatureInfoModel me = params.members + .stream() + .filter(member -> CacheObjects.getMemberId().equals(member.getMemberId())) + .findFirst() + .orElse(null); + + return JObject.create("col_names", me.features); } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertFilterComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertFilterComponent.java index 6d4b3bf39..fb02c9e68 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertFilterComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertFilterComponent.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -34,10 +34,10 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.wefe.enums.ComponentType; @Service public class VertFilterComponent extends AbstractComponent { @@ -50,7 +50,6 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); JObject resultObj = JObject.create(); 
params.getMembers().forEach(member -> { if (CacheObjects.getMemberId().equals(member.getMemberId()) @@ -59,8 +58,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT } }); - taskParam.put("params", resultObj); - return taskParam; + return resultObj; } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertOneHotComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertOneHotComponent.java index 39b463aea..8dc25d199 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertOneHotComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertOneHotComponent.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -35,10 +35,10 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.wefe.enums.ComponentType; @Service public class VertOneHotComponent extends AbstractComponent { @@ -65,10 +65,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT }); } }); - taskParam.put("params", - JObject.create().append("transform_col_names", transformColNames).append("save_dataset", true)); - - return taskParam; + return JObject.create().append("transform_col_names", transformColNames).append("save_dataset", true); } 
@Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertPCAComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertPCAComponent.java index 2b6a76d08..b14032421 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertPCAComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertPCAComponent.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -34,9 +34,9 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.TaskResultType; @Service public class VertPCAComponent extends AbstractComponent { @@ -57,9 +57,6 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, VertOneHotComponent.Params params) throws FlowNodeException { - - JSONObject taskParam = new JSONObject(); - JObject resultObj = JObject.create(); List featureList = new ArrayList<>(); params.getMembers().forEach(member -> { @@ -70,10 +67,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT } }); resultObj.append("column_names", featureList); - - taskParam.put("params", resultObj); - - return taskParam; + return resultObj; } @Override diff 
--git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertPearsonComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertPearsonComponent.java index b262be347..bc5f0e451 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertPearsonComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/feature/VertPearsonComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -29,12 +29,12 @@ import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.TaskResultType; import org.springframework.stereotype.Service; import java.util.ArrayList; @@ -63,9 +63,6 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - - JSONObject taskParam = new JSONObject(); - // Need to find DataIO data set FlowGraphNode dataIONode = graph.findOneNodeFromParent(node, ComponentType.DataIO); TaskMySqlModel dataIOTask = findTaskFromPretasks(preTasks, dataIONode); @@ -73,26 +70,24 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT throw new FlowNodeException(node, "请添加DataIO组件!"); } - JObject resultObj = JObject.create(); + JObject output = JObject.create(); params.getMembers().forEach(x -> { - if (x.getMemberId().equals(CacheObjects.getMemberId())) { + if (x.getMemberId().equals(CacheObjects.getMemberId()) && graph.getJob().getMyRole() == x.getMemberRole()) { List features = new ArrayList<>(); x.getFeatures().forEach(feature -> { if (StringUtil.isNotEmpty(feature.getMethod())) { features.add(feature.getName()); } }); - resultObj.append("column_names", features); + output.append("column_names", features); } }); - resultObj.append("cross_parties", params.isCrossParties()); - - taskParam.put("params", resultObj); + output.append("cross_parties", params.isCrossParties()); - return 
taskParam; + return output; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/AbstractModelingComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/AbstractModelingComponent.java index 6000cd0ca..b4bb384b2 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/AbstractModelingComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/AbstractModelingComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -21,10 +21,10 @@ import com.welab.wefe.board.service.component.base.io.*; import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.TaskResultType; import org.springframework.beans.BeanUtils; import java.util.ArrayList; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzLRComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzLRComponent.java index 6dc313438..a8c5258fa 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzLRComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzLRComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -26,11 +26,11 @@ import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.dto.AbstractLRInput; +import com.welab.wefe.common.wefe.enums.ComponentType; import org.springframework.stereotype.Service; import java.util.Arrays; @@ -56,10 +56,8 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); - - JObject vertLRParam = JObject.create(); - vertLRParam.append("penalty", params.otherParam.penalty) + JObject output = JObject.create(); + output.append("penalty", params.otherParam.penalty) .append("tol", params.otherParam.tol) .append("alpha", params.otherParam.alpha) .append("optimizer", params.otherParam.optimizer) @@ -76,9 +74,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT .append("shuffle", params.getCvParam().isShuffle()) .append("need_cv", params.getCvParam().isNeedCv()); - taskParam.put("params", vertLRParam); - - return taskParam; + return output; } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzNNComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzNNComponent.java index e1db9c59e..1b1b69d30 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzNNComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzNNComponent.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,6 @@ package com.welab.wefe.board.service.component.modeling; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.springframework.stereotype.Service; - import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.component.base.io.IODataType; import com.welab.wefe.board.service.component.base.io.InputMatcher; @@ -32,10 +26,15 @@ import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; @Service public class HorzNNComponent extends AbstractModelingComponent { @@ -48,7 +47,7 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); +// JSONObject taskParam = new JSONObject(); JObject horzNNParam = JObject.create(); horzNNParam.append("encode_label", false).append("max_iter", params.maxIter).append("batch_size", params.batchSize); @@ -67,8 +66,8 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT JObject nnDefine = JObject.create().append("class_name", "Sequential").append("layers", params.nnDefine.layers); 
horzNNParam.append("nn_define", nnDefine).append("config_type", "keras"); - taskParam.put("params", horzNNParam); - return taskParam; +// taskParam.put("params", horzNNParam); + return horzNNParam; } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzSecureBoostComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzSecureBoostComponent.java index 677da2a7c..c18ffdc27 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzSecureBoostComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/HorzSecureBoostComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -26,11 +26,11 @@ import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.dto.AbstractSecureBoostInput; +import com.welab.wefe.common.wefe.enums.ComponentType; import org.springframework.stereotype.Service; import java.util.Arrays; @@ -55,16 +55,13 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); - - JObject horzSecureBoostParam = JObject.create(); + JObject output = JObject.create(); JObject treeParam = JObject.create().append("criterion_method", "xgboost") .append("criterion_params", params.getTreeParam().getCriterionParams()) .append("max_depth", params.getTreeParam().getMaxDepth()) .append("min_sample_split", params.getTreeParam().getMinSampleSplit()) .append("min_impurity_split", params.getTreeParam().getMinImpuritySplit()) - .append("min_leaf_node", params.getTreeParam().getMinLeafNode()) - .append("max_split_nodes", params.getTreeParam().getMaxSplitNodes()); + .append("min_leaf_node", params.getTreeParam().getMinLeafNode()); JObject objectiveParam = JObject.create().append("objective", params.getObjectiveParam().getObjective()) .append("params", params.getObjectiveParam().getParams()); @@ -75,7 +72,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT .append("need_cv", params.getCvParam().isNeedCv()); - horzSecureBoostParam.append("task_type", params.otherParam.taskType) + output.append("task_type", params.otherParam.taskType) .append("learning_rate", params.otherParam.learningRate) .append("num_trees", 
params.otherParam.numTrees) .append("subsample_feature_rate", params.otherParam.subsampleFeatureRate) @@ -86,9 +83,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT .append("objective_param", objectiveParam) .append("cv_param", cvParam); - taskParam.put("params", horzSecureBoostParam); - - return taskParam; + return output; } @Override @@ -143,7 +138,7 @@ public static class OtherParam extends AbstractCheckModel { private float subsampleFeatureRate; @Check(name = "多次迭代无变化是允许停止", require = true) private boolean nIterNoChange; - @Check(name = "收敛阀值", require = true) + @Check(name = "收敛阈值", require = true) private float tol; @Check(name = "最大分箱数", require = true) private int binNum; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/MixLrComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/MixLrComponent.java index bb22b5b3f..8e5cfb54b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/MixLrComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/MixLrComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -25,11 +25,11 @@ import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.dto.AbstractLRInput; +import com.welab.wefe.common.wefe.enums.ComponentType; import org.springframework.stereotype.Service; import java.util.Arrays; @@ -55,10 +55,8 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) { - JSONObject taskParam = new JSONObject(); - - JObject vertLRParam = JObject.create(); - vertLRParam.append("penalty", params.otherParam.penalty) + JObject output = JObject.create(); + output.append("penalty", params.otherParam.penalty) .append("tol", params.otherParam.tol) .append("alpha", params.otherParam.alpha) .append("optimizer", params.otherParam.optimizer) @@ -75,9 +73,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT .append("shuffle", params.getCvParam().isShuffle()) .append("need_cv", params.getCvParam().isNeedCv()); - taskParam.put("params", vertLRParam); - - return taskParam; + return output; } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/MixSecureBoostComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/MixSecureBoostComponent.java index 461ecd911..70bb5b147 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/MixSecureBoostComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/MixSecureBoostComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. 
All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -27,11 +27,11 @@ import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.dto.AbstractSecureBoostInput; +import com.welab.wefe.common.wefe.enums.ComponentType; import org.springframework.stereotype.Service; import java.util.Arrays; @@ -55,17 +55,13 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - - JSONObject taskParam = new JSONObject(); - - JObject vertSecureBoostParam = JObject.create(); + JObject output = JObject.create(); JObject treeParam = JObject.create().append("criterion_method", "xgboost") .append("criterion_params", params.getTreeParam().getCriterionParams()) .append("max_depth", params.getTreeParam().getMaxDepth()) .append("min_sample_split", params.getTreeParam().getMinSampleSplit()) .append("min_impurity_split", params.getTreeParam().getMinImpuritySplit()) - .append("min_leaf_node", params.getTreeParam().getMinLeafNode()) - .append("max_split_nodes", params.getTreeParam().getMaxSplitNodes()); + .append("min_leaf_node", 
params.getTreeParam().getMinLeafNode()); JObject objectiveParam = JObject.create().append("objective", params.getObjectiveParam().getObjective()) .append("params", params.getObjectiveParam().getParams()); @@ -78,7 +74,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT .append("shuffle", params.getCvParam().isShuffle()) .append("need_cv", params.getCvParam().isNeedCv()); - vertSecureBoostParam.append("task_type", params.otherParam.taskType) + output.append("task_type", params.otherParam.taskType) .append("learning_rate", params.otherParam.learningRate) .append("num_trees", params.otherParam.numTrees) .append("subsample_feature_rate", params.otherParam.subsampleFeatureRate) @@ -92,9 +88,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT .append("encrypt_param", encryptParam) .append("cv_param", cvParam); - taskParam.put("params", vertSecureBoostParam); - - return taskParam; + return output; } @Override @@ -178,7 +172,7 @@ public static class OtherParam extends AbstractCheckModel { private float subsampleFeatureRate; @Check(name = "多次迭代无变化是允许停止", require = true) private boolean nIterNoChange; - @Check(name = "收敛阀值", require = true) + @Check(name = "收敛阈值", require = true) private float tol; @Check(name = "最大分箱数", require = true) private int binNum; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertLRComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertLRComponent.java index 0771972bb..59e7639f6 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertLRComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertLRComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,27 +16,29 @@ package com.welab.wefe.board.service.component.modeling; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.springframework.stereotype.Service; + import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.component.base.io.IODataType; import com.welab.wefe.board.service.component.base.io.InputMatcher; import com.welab.wefe.board.service.component.base.io.Names; import com.welab.wefe.board.service.component.base.io.OutputItem; +import com.welab.wefe.board.service.database.entity.job.JobMemberMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.dto.AbstractLRInput; -import org.springframework.stereotype.Service; - -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.TaskResultType; /** * @author lonnie @@ -50,6 +52,11 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas if (intersectionNode == null) { 
throw new FlowNodeException(node, "请在前面添加样本对齐组件。"); } + List jobMembers = graph.getMembers(); + long memberCount = jobMembers.size(); + if (memberCount > 2 && "sshe-lr".equalsIgnoreCase(params.getOtherParam().getLrMethod())) { + throw new FlowNodeException(node, "sshe-lr 只支持两个参与方"); + } } @@ -61,10 +68,8 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); - - JObject vertLRParam = JObject.create(); - vertLRParam.append("penalty", params.otherParam.penalty) + JObject output = JObject.create(); + output.append("penalty", params.otherParam.penalty) .append("tol", params.otherParam.tol) .append("alpha", params.otherParam.alpha) .append("optimizer", params.otherParam.optimizer) @@ -83,11 +88,10 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT .append("key_length", 1024) .append("n_splits", params.getCvParam().getnSplits()) .append("shuffle", params.getCvParam().isShuffle()) - .append("need_cv", params.getCvParam().isNeedCv()); - - taskParam.put("params", vertLRParam); + .append("need_cv", params.getCvParam().isNeedCv()) + .append("lr_method", params.getOtherParam().getLrMethod()); - return taskParam; + return output; } @Override @@ -148,6 +152,8 @@ public void setOtherParam(OtherParam otherParam) { } public static class OtherParam extends AbstractCheckModel { + @Check(name = "LR算法", require = true) + private String lrMethod; @Check(name = "惩罚方式", require = true) private String penalty; @@ -187,7 +193,15 @@ public static class OtherParam extends AbstractCheckModel { @Check(name = "提前结束的迭代次数", require = true) private int earlyStoppingRounds; - public String getPenalty() { + public String getLrMethod() { + return lrMethod; + } + + public void setLrMethod(String lrMethod) { + this.lrMethod = lrMethod; + } + + public String getPenalty() { return penalty; } diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertNNComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertNNComponent.java index 5020780a0..0a6abe974 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertNNComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertNNComponent.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,6 @@ package com.welab.wefe.board.service.component.modeling; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.springframework.stereotype.Service; - import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.component.base.io.IODataType; import com.welab.wefe.board.service.component.base.io.InputMatcher; @@ -33,11 +27,16 @@ import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; @Service public class VertNNComponent extends AbstractModelingComponent { @@ -55,7 +54,7 
@@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); +// JSONObject taskParam = new JSONObject(); JObject vertNNParam = JObject.create(); vertNNParam.append("epochs", params.epochs).append("interactive_layer_lr", params.interactiveLayerLr) @@ -84,9 +83,9 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT vertNNParam.append("config_type", "keras"); - taskParam.put("params", vertNNParam); +// taskParam.put("params", vertNNParam); - return taskParam; + return vertNNParam; } @Override diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertSecureBoostComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertSecureBoostComponent.java index ab3f0f606..69e17939e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertSecureBoostComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/modeling/VertSecureBoostComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -22,16 +22,17 @@ import com.welab.wefe.board.service.component.base.io.InputMatcher; import com.welab.wefe.board.service.component.base.io.Names; import com.welab.wefe.board.service.component.base.io.OutputItem; +import com.welab.wefe.board.service.database.entity.job.JobMemberMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.common.enums.ComponentType; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.dto.AbstractSecureBoostInput; +import com.welab.wefe.common.wefe.enums.ComponentType; import org.springframework.stereotype.Service; import java.util.Arrays; @@ -48,6 +49,12 @@ protected void checkBeforeBuildTask(FlowGraph graph, List preTas if (intersectionNode == null) { throw new FlowNodeException(node, "请在前面添加样本对齐组件。"); } + + List jobMembers = graph.getMembers(); + long memberCount = jobMembers.size(); + if (memberCount > 2 && "layered".equalsIgnoreCase(params.otherParam.workMode)) { + throw new FlowNodeException(node, "layered模式 只支持两个参与方"); + } } @@ -59,16 +66,13 @@ public ComponentType taskType() { @Override protected JSONObject createTaskParams(FlowGraph graph, List preTasks, FlowGraphNode node, Params params) throws FlowNodeException { - JSONObject taskParam = new JSONObject(); - - JObject vertSecureBoostParam = JObject.create(); + JObject output = JObject.create(); JObject treeParam = JObject.create().append("criterion_method", "xgboost") .append("criterion_params", params.getTreeParam().getCriterionParams()) .append("max_depth", params.getTreeParam().getMaxDepth()) .append("min_sample_split", 
params.getTreeParam().getMinSampleSplit()) .append("min_impurity_split", params.getTreeParam().getMinImpuritySplit()) - .append("min_leaf_node", params.getTreeParam().getMinLeafNode()) - .append("max_split_nodes", params.getTreeParam().getMaxSplitNodes()); + .append("min_leaf_node", params.getTreeParam().getMinLeafNode()); JObject objectiveParam = JObject.create().append("objective", params.getObjectiveParam().getObjective()) .append("params", params.getObjectiveParam().getParams()); @@ -81,7 +85,7 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT .append("shuffle", params.getCvParam().isShuffle()) .append("need_cv", params.getCvParam().isNeedCv()); - vertSecureBoostParam.append("task_type", params.otherParam.taskType) + output.append("task_type", params.otherParam.taskType) .append("learning_rate", params.otherParam.learningRate) .append("num_trees", params.otherParam.numTrees) .append("subsample_feature_rate", params.otherParam.subsampleFeatureRate) @@ -93,11 +97,16 @@ protected JSONObject createTaskParams(FlowGraph graph, List preT .append("tree_param", treeParam) .append("objective_param", objectiveParam) .append("encrypt_param", encryptParam) - .append("cv_param", cvParam); - - taskParam.put("params", vertSecureBoostParam); - - return taskParam; + .append("cv_param", cvParam).append("work_mode", params.otherParam.workMode); + if ("layered".equalsIgnoreCase(params.otherParam.workMode)) { + output.append("promoter_depth", params.otherParam.promoterDepth).append("provider_depth", + params.otherParam.providerDepth); + } else if ("skip".equalsIgnoreCase(params.otherParam.workMode)) { + output.append("tree_num_per_member", params.otherParam.treeNumPerMember); + } else if ("dp".equalsIgnoreCase(params.otherParam.workMode)) { + output.append("epsilon", params.otherParam.epsilon); + } + return output; } @Override @@ -137,6 +146,8 @@ public static class Params extends AbstractSecureBoostInput { @Check(require = true) private EncryptParam encryptParam; 
+ + public EncryptParam getEncryptParam() { return encryptParam; @@ -181,7 +192,7 @@ public static class OtherParam extends AbstractCheckModel { private float subsampleFeatureRate; @Check(name = "多次迭代无变化是允许停止", require = true) private boolean nIterNoChange; - @Check(name = "收敛阀值", require = true) + @Check(name = "收敛阈值", require = true) private float tol; @Check(name = "最大分箱数", require = true) private int binNum; @@ -189,6 +200,22 @@ public static class OtherParam extends AbstractCheckModel { private int validationFreqs; @Check(name = "允许提前结束的最小迭代次数", require = true) private int earlyStoppingRounds; + @Check(name = "工作模式") + private String workMode = "normal"; // normal、layered、skip + + // 当work_mode==layered时,需要下面两个参数 + @Check(name = "promoter层数") + private int promoterDepth; + @Check(name = "provider层数") + private int providerDepth; + + // 当work_mode==skip时,需要下面这个参数 + @Check(name = "单方每次构建树的数量") + private int treeNumPerMember; + + // 当work_mode==dp时 隐私预算 + @Check(name = "隐私预算") // 1.22 + private float epsilon; public String getTaskType() { return taskType; @@ -260,6 +287,46 @@ public int getEarlyStoppingRounds() { public void setEarlyStoppingRounds(int earlyStoppingRounds) { this.earlyStoppingRounds = earlyStoppingRounds; + } + + public String getWorkMode() { + return workMode; + } + + public void setWorkMode(String workMode) { + this.workMode = workMode; + } + + public int getPromoterDepth() { + return promoterDepth; + } + + public void setPromoterDepth(int promoterDepth) { + this.promoterDepth = promoterDepth; + } + + public int getProviderDepth() { + return providerDepth; + } + + public void setProviderDepth(int providerDepth) { + this.providerDepth = providerDepth; + } + + public int getTreeNumPerMember() { + return treeNumPerMember; + } + + public void setTreeNumPerMember(int treeNumPerMember) { + this.treeNumPerMember = treeNumPerMember; + } + + public float getEpsilon() { + return epsilon; + } + + public void setEpsilon(float epsilon) { + this.epsilon = 
epsilon; } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/component/temp/AbstractValidationDataSetLoaderComponent.java b/board/board-service/src/main/java/com/welab/wefe/board/service/component/temp/AbstractValidationDataSetLoaderComponent.java index 6700ceb5a..726f6a022 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/component/temp/AbstractValidationDataSetLoaderComponent.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/component/temp/AbstractValidationDataSetLoaderComponent.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/config/WebSocketConfig.java b/board/board-service/src/main/java/com/welab/wefe/board/service/config/WebSocketConfig.java index 016552c1e..e80df8d9c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/config/WebSocketConfig.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/config/WebSocketConfig.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/constant/BloomfilterAddMethod.java b/board/board-service/src/main/java/com/welab/wefe/board/service/constant/BloomfilterAddMethod.java new file mode 100644 index 000000000..36a62292a --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/constant/BloomfilterAddMethod.java @@ -0,0 +1,37 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.constant; + +/** + * How to add bloom_filter + * + * @author jacky.jiang + */ +public enum BloomfilterAddMethod { + /** + * + */ + HttpUpload, + /** + * + */ + LocalFile, + /** + * + */ + Database +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/constant/ChatConstant.java b/board/board-service/src/main/java/com/welab/wefe/board/service/constant/ChatConstant.java index 696d1b9a6..48effe871 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/constant/ChatConstant.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/constant/ChatConstant.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/constant/Config.java b/board/board-service/src/main/java/com/welab/wefe/board/service/constant/Config.java index a4ecd930e..0b381f10c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/constant/Config.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/constant/Config.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,9 +17,7 @@ package com.welab.wefe.board.service.constant; import com.welab.wefe.common.data.storage.common.DBType; -import com.welab.wefe.common.enums.JobBackendType; -import com.welab.wefe.common.enums.env.EnvBranch; -import com.welab.wefe.common.enums.env.EnvName; +import com.welab.wefe.common.web.config.CommonConfig; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.context.annotation.PropertySource; @@ -33,51 +31,45 @@ @Component @PropertySource(value = {"file:${config.path}"}, encoding = "utf-8") @ConfigurationProperties -public class Config { +public class Config extends CommonConfig { - @Value("${wefe.union.base-url}") - private String UNION_BASE_URL; + @Value("${wefe.job.work_mode}") + private Integer workMode; + @Value("${db.storage.type}") + private DBType dbType; - @Value("${wefe.file.upload.dir}") - private String fileUploadDir; + @Value("${fc.access_key_id:xxx}") + private String fcAccessKeyId; + @Value("${fc.access_key_secret:xxx}") + private String fcAccessKeySecret; - @Value("${wefe.job.work_mode}") - private Integer workMode; + @Value("${verification.code.send.channel:email}") + private String verificationCodeSendChannel; - @Value("${wefe.job.backend}") - private JobBackendType backend; + @Value("${sms.aliyun.sign.name:xxx}") + private String smsAliyunSignName; - @Value("${db.storage.type}") - private DBType dbType; + @Value("${sms.aliyun.account.forget.password.verification.code.template.code:xxx}") + private String 
smsAliyunAccountForgetPasswordVerificationCodeTemplateCode; - @Value("${env.name}") - private EnvName envName; + @Value("${sms.aliyun.member.register.verification.code.template.code:xxx}") + private String smsAliyunMemberregisterVerificationCodeTemplateCode; - /** - * The branch of the environment, different branches will have different functions. - *

- * online_demo: You can only delete data created by yourself(eg:flow、member、data_set) - */ - @Value("${env.branch:master}") - private EnvBranch envBranch; + @Value("${email.account.forget.password.subject:忘记密码}") + private String emailAccountForgetPasswordSubject; - public String getUNION_BASE_URL() { - return UNION_BASE_URL; - } + @Value("${email.account.forget.password.content:您正在执行忘记密码操作。您的验证码是#code#,2分钟内有效,请勿泄漏于他人!}") + private String emailAccountForgetPasswordContent; - public void setUNION_BASE_URL(String UNION_BASE_URL) { - this.UNION_BASE_URL = UNION_BASE_URL; - } + @Value("${sm4.secret.key:}") + private String sm4SecretKey; - public String getFileUploadDir() { - return fileUploadDir; - } + @Value("${encrypt.phone.number.open:false}") + private boolean encryptPhoneNumberOpen; - public void setFileUploadDir(String fileUploadDir) { - this.fileUploadDir = fileUploadDir; - } + // region getter/setter public Integer getWorkMode() { return workMode; @@ -87,14 +79,6 @@ public void setWorkMode(Integer workMode) { this.workMode = workMode; } - public JobBackendType getBackend() { - return backend; - } - - public void setBackend(JobBackendType backend) { - this.backend = backend; - } - public DBType getDbType() { return dbType; } @@ -103,24 +87,86 @@ public void setDbType(DBType dbType) { this.dbType = dbType; } - public EnvName getEnvName() { - return envName; + public String getVerificationCodeSendChannel() { + return verificationCodeSendChannel; + } + + public void setVerificationCodeSendChannel(String verificationCodeSendChannel) { + this.verificationCodeSendChannel = verificationCodeSendChannel; + } + + public String getSmsAliyunSignName() { + return smsAliyunSignName; + } + + public void setSmsAliyunSignName(String smsAliyunSignName) { + this.smsAliyunSignName = smsAliyunSignName; } - public void setEnvName(EnvName envName) { - this.envName = envName; + public String getSmsAliyunAccountForgetPasswordVerificationCodeTemplateCode() { + return 
smsAliyunAccountForgetPasswordVerificationCodeTemplateCode; } - public EnvBranch getEnvBranch() { - return envBranch; + public void setSmsAliyunAccountForgetPasswordVerificationCodeTemplateCode(String smsAliyunAccountForgetPasswordVerificationCodeTemplateCode) { + this.smsAliyunAccountForgetPasswordVerificationCodeTemplateCode = smsAliyunAccountForgetPasswordVerificationCodeTemplateCode; } - public void setEnvBranch(EnvBranch envBranch) { - this.envBranch = envBranch; + public String getSmsAliyunMemberregisterVerificationCodeTemplateCode() { + return smsAliyunMemberregisterVerificationCodeTemplateCode; } - public boolean isOnlineDemo() { - return envBranch == EnvBranch.online_demo; + public void setSmsAliyunMemberregisterVerificationCodeTemplateCode(String smsAliyunMemberregisterVerificationCodeTemplateCode) { + this.smsAliyunMemberregisterVerificationCodeTemplateCode = smsAliyunMemberregisterVerificationCodeTemplateCode; } + public String getEmailAccountForgetPasswordSubject() { + return emailAccountForgetPasswordSubject; + } + + public void setEmailAccountForgetPasswordSubject(String emailAccountForgetPasswordSubject) { + this.emailAccountForgetPasswordSubject = emailAccountForgetPasswordSubject; + } + + public String getEmailAccountForgetPasswordContent() { + return emailAccountForgetPasswordContent; + } + + public void setEmailAccountForgetPasswordContent(String emailAccountForgetPasswordContent) { + this.emailAccountForgetPasswordContent = emailAccountForgetPasswordContent; + } + + public String getFcAccessKeyId() { + return fcAccessKeyId; + } + + public void setFcAccessKeyId(String fcAccessKeyId) { + this.fcAccessKeyId = fcAccessKeyId; + } + + public String getFcAccessKeySecret() { + return fcAccessKeySecret; + } + + public void setFcAccessKeySecret(String fcAccessKeySecret) { + this.fcAccessKeySecret = fcAccessKeySecret; + } + + public String getSm4SecretKey() { + return sm4SecretKey; + } + + public void setSm4SecretKey(String sm4SecretKey) { + 
this.sm4SecretKey = sm4SecretKey; + } + + public boolean isEncryptPhoneNumberOpen() { + return encryptPhoneNumberOpen; + } + + public void setEncryptPhoneNumberOpen(boolean encryptPhoneNumberOpen) { + this.encryptPhoneNumberOpen = encryptPhoneNumberOpen; + } + + // endregion + } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/constant/DataSetAddMethod.java b/board/board-service/src/main/java/com/welab/wefe/board/service/constant/DataSetAddMethod.java index 8f866f56b..690024aab 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/constant/DataSetAddMethod.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/constant/DataSetAddMethod.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/constant/ServiceStatus.java b/board/board-service/src/main/java/com/welab/wefe/board/service/constant/ServiceStatus.java index 4343a1bfe..362accac5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/constant/ServiceStatus.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/constant/ServiceStatus.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/DataSourceConfig.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/DataSourceConfig.java index 024a69f78..39c1a78dc 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/DataSourceConfig.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/DataSourceConfig.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/AccountMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/AccountMySqlModel.java deleted file mode 100644 index cfa7e4601..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/AccountMySqlModel.java +++ /dev/null @@ -1,159 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.database.entity; - -import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.AuditStatus; - -import javax.persistence.Entity; -import javax.persistence.EnumType; -import javax.persistence.Enumerated; - -/** - * @author Zane - */ -@Entity(name = "account") -public class AccountMySqlModel extends AbstractBaseMySqlModel { - - /** - * 手机号 - */ - private String phoneNumber; - /** - * 密码 - */ - private String password; - /** - * 盐 - */ - private String salt; - /** - * 昵称 - */ - private String nickname; - /** - * 邮箱 - */ - private String email; - /** - * 是否是超级管理员;超级管理员通常是第一个创建并初始化系统的那个人 - */ - private Boolean superAdminRole; - /** - * 是否是管理员;管理员有更多权限,比如设置 member 是否对外可见。 - */ - private Boolean adminRole; - /** - * 审核状态 - */ - @Enumerated(EnumType.STRING) - private AuditStatus auditStatus; - /** - * 审核意见 - */ - private String auditComment; - - /** - * 是否可用 - */ - private Boolean enable; - - - //region getter/setter - - public String getPhoneNumber() { - return phoneNumber; - } - - public void setPhoneNumber(String phoneNumber) { - this.phoneNumber = phoneNumber; - } - - public String getPassword() { - return password; - } - - public void setPassword(String password) { - this.password = password; - } - - public String getSalt() { - return salt; - } - - public void setSalt(String salt) { - this.salt = salt; - } - - public String getNickname() { - return nickname; - } - - public void setNickname(String nickname) { - this.nickname = nickname; - } - - public 
String getEmail() { - return email; - } - - public void setEmail(String email) { - this.email = email; - } - - public Boolean getSuperAdminRole() { - return superAdminRole; - } - - public void setSuperAdminRole(Boolean superAdminRole) { - this.superAdminRole = superAdminRole; - } - - public Boolean getAdminRole() { - return adminRole; - } - - public void setAdminRole(Boolean adminRole) { - this.adminRole = adminRole; - } - - public AuditStatus getAuditStatus() { - return auditStatus; - } - - public void setAuditStatus(AuditStatus auditStatus) { - this.auditStatus = auditStatus; - } - - public String getAuditComment() { - return auditComment; - } - - public void setAuditComment(String auditComment) { - this.auditComment = auditComment; - } - - public Boolean getEnable() { - return enable; - } - - public void setEnable(Boolean enable) { - this.enable = enable; - } - - //endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/AccountMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/AccountMysqlModel.java new file mode 100644 index 000000000..db5452a31 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/AccountMysqlModel.java @@ -0,0 +1,204 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.entity; + +import com.alibaba.fastjson.JSONArray; +import com.vladmihalcea.hibernate.type.json.JsonStringType; +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.board.service.database.listener.AccountMysqlModelListener; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import org.hibernate.annotations.Type; +import org.hibernate.annotations.TypeDef; + +import javax.persistence.*; +import java.util.Date; + +/** + * @author Zane + */ +@Entity(name = "account") +@TypeDef(name = "json", typeClass = JsonStringType.class) +@EntityListeners(AccountMysqlModelListener.class) +public class AccountMysqlModel extends AbstractBaseMySqlModel { + + /** + * 手机号 + */ + private String phoneNumber; + /** + * 密码 + */ + private String password; + /** + * 盐 + */ + private String salt; + /** + * 昵称 + */ + private String nickname; + /** + * 邮箱 + */ + private String email; + /** + * 是否是超级管理员;超级管理员通常是第一个创建并初始化系统的那个人 + */ + private Boolean superAdminRole; + /** + * 是否是管理员;管理员有更多权限,比如设置 member 是否对外可见。 + */ + private Boolean adminRole; + /** + * 审核状态 + */ + @Enumerated(EnumType.STRING) + private AuditStatus auditStatus; + /** + * 审核意见 + */ + private String auditComment; + + /** + * 是否可用 + */ + private Boolean enable; + /** + * 是否已注销 + */ + private boolean cancelled; + /** + * 最后活动时间 + */ + private Date lastActionTime; + /** + * 历史曾用密码 + */ + @Type(type = "json") + @Column(columnDefinition = "json") + private JSONArray historyPasswordList; + + + //region getter/setter + + public String getPhoneNumber() throws StatusCodeWithException { + return phoneNumber; + } + + public void setPhoneNumber(String phoneNumber) throws StatusCodeWithException { + this.phoneNumber = phoneNumber; + } + + public String getPassword() { + return password; + } + + public void setPassword(String password) { + this.password = password; + } + + 
public String getSalt() { + return salt; + } + + public void setSalt(String salt) { + this.salt = salt; + } + + public String getNickname() { + return nickname; + } + + public void setNickname(String nickname) { + this.nickname = nickname; + } + + public String getEmail() { + return email; + } + + public void setEmail(String email) { + this.email = email; + } + + public Boolean getSuperAdminRole() { + return superAdminRole; + } + + public void setSuperAdminRole(Boolean superAdminRole) { + this.superAdminRole = superAdminRole; + } + + public Boolean getAdminRole() { + return adminRole; + } + + public void setAdminRole(Boolean adminRole) { + this.adminRole = adminRole; + } + + public AuditStatus getAuditStatus() { + return auditStatus; + } + + public void setAuditStatus(AuditStatus auditStatus) { + this.auditStatus = auditStatus; + } + + public String getAuditComment() { + return auditComment; + } + + public void setAuditComment(String auditComment) { + this.auditComment = auditComment; + } + + public Boolean getEnable() { + return enable; + } + + public void setEnable(Boolean enable) { + this.enable = enable; + } + + public boolean isCancelled() { + return cancelled; + } + + public void setCancelled(boolean cancelled) { + this.cancelled = cancelled; + } + + public Date getLastActionTime() { + return lastActionTime; + } + + public void setLastActionTime(Date lastActionTime) { + this.lastActionTime = lastActionTime; + } + + public JSONArray getHistoryPasswordList() { + return historyPasswordList; + } + + public void setHistoryPasswordList(JSONArray historyPasswordList) { + this.historyPasswordList = historyPasswordList; + } + + //endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/BlacklistMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/BlacklistMysqlModel.java index 2fa905a07..7b6d6e061 100644 --- 
a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/BlacklistMysqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/BlacklistMysqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/DataOutputInfoMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/DataOutputInfoMysqlModel.java index fb84af4ac..a675cfa0e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/DataOutputInfoMysqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/DataOutputInfoMysqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/DataSourceMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/DataSourceMysqlModel.java new file mode 100644 index 000000000..fd41083a5 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/DataSourceMysqlModel.java @@ -0,0 +1,124 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.entity; + +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.common.wefe.enums.DatabaseType; + +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; + +/** + * 数据来源,目的:从指定的数据库中读取数据,并上传到ck作为原始的数据集 + * + * @author Johnny.lin + */ +@Entity(name = "data_source") +public class DataSourceMysqlModel extends AbstractBaseMySqlModel { + /** + * 数据源名称 + */ + private String name; + + /** + * 数据库类型,枚举(hive、impala、mysql) + */ + @Enumerated(EnumType.STRING) + private DatabaseType databaseType; + + /** + * 数据库IP地址 + */ + private String host; + + /** + * 端口 + */ + private Integer port; + + /** + * 要连接的数据库名称 + */ + private String databaseName; + + /** + * 用户名 + */ + private String userName; + + /** + * 密码 + */ + private String password; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public DatabaseType getDatabaseType() { + return databaseType; + } + + public void setDatabaseType(DatabaseType databaseType) { + this.databaseType = databaseType; + } + + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public Integer getPort() { + return port; + } + + public void setPort(Integer port) { + this.port = port; + } + + public String getDatabaseName() { + return databaseName; + } + + public void setDatabaseName(String databaseName) { + this.databaseName = databaseName; + } + + public String getUserName() { + return userName; + } + + public void setUserName(String userName) { + this.userName = userName; + } + + public String getPassword() { + return password; + } + + public void setPassword(String password) { + this.password = password; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/GlobalConfigMySqlModel.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/GlobalConfigMySqlModel.java deleted file mode 100644 index f3723c8b5..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/GlobalConfigMySqlModel.java +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.database.entity; - -import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; - -import javax.persistence.Column; -import javax.persistence.Entity; - -/** - * @author Zane - */ -@Entity(name = "global_config") -public class GlobalConfigMySqlModel extends AbstractBaseMySqlModel { - /** - * 配置项所在的组 - */ - @Column(name = "`group`") - private String group; - /** - * 配置项名称 - */ - private String name; - /** - * 配置项的值 - */ - private String value; - /** - * 配置项的解释说明 - */ - private String comment; - - // region getter/setter - - public String getGroup() { - return group; - } - - public void setGroup(String group) { - this.group = group; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public String getValue() { - return value; - } - - public void setValue(String value) { - this.value = value; - } - - public String getComment() { - return comment; - } - - public void setComment(String comment) { - this.comment = comment; - } 
- - // endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/GlobalConfigMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/GlobalConfigMysqlModel.java new file mode 100644 index 000000000..a6a7e222b --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/GlobalConfigMysqlModel.java @@ -0,0 +1,85 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.entity; + +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.board.service.database.listener.GlobalConfigMysqlModelListener; + +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EntityListeners; + +/** + * @author Zane + */ +@Entity(name = "global_config") +@EntityListeners(GlobalConfigMysqlModelListener.class) +public class GlobalConfigMysqlModel extends AbstractBaseMySqlModel { + /** + * 配置项所在的组 + */ + @Column(name = "`group`") + private String group; + /** + * 配置项名称 + */ + private String name; + /** + * 配置项的值 + */ + private String value; + /** + * 配置项的解释说明 + */ + private String comment; + + // region getter/setter + + public String getGroup() { + return group; + } + + public void setGroup(String group) { + this.group = group; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getValue() { + return value; + } + + public void setValue(String value) { + this.value = value; + } + + public String getComment() { + return comment; + } + + public void setComment(String comment) { + this.comment = comment; + } + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/MessageMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/MessageMysqlModel.java index 5ba04c882..931297cae 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/MessageMysqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/MessageMysqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,8 +17,8 @@ package com.welab.wefe.board.service.database.entity; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.MessageLevel; -import com.welab.wefe.common.enums.ProducerType; +import com.welab.wefe.common.wefe.enums.MessageLevel; +import com.welab.wefe.common.wefe.enums.ProducerType; import javax.persistence.Entity; import javax.persistence.EnumType; @@ -49,7 +49,7 @@ public class MessageMysqlModel extends AbstractBaseMySqlModel { */ private String content; /** - * 未读 + * 是否未读 */ private Boolean unread; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/OperationLogMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/OperationLogMysqlModel.java index 74453295d..64290d019 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/OperationLogMysqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/OperationLogMysqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -47,11 +47,6 @@ public class OperationLogMysqlModel extends AbstractMySqlModel { */ private String operatorId; - /** - * 操作人员手机号 - */ - private String operatorPhone; - /** * 请求token */ @@ -114,14 +109,6 @@ public void setOperatorId(String operatorId) { this.operatorId = operatorId; } - public String getOperatorPhone() { - return operatorPhone; - } - - public void setOperatorPhone(String operatorPhone) { - this.operatorPhone = operatorPhone; - } - public String getToken() { return token; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/OutputModelMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/OutputModelMysqlModel.java index 97274b6bf..fc1ebea18 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/OutputModelMysqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/OutputModelMysqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/TrackingMetricMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/TrackingMetricMysqlModel.java index 7ac0e06e7..262955f84 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/TrackingMetricMysqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/TrackingMetricMysqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/VerificationCodeMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/VerificationCodeMysqlModel.java new file mode 100644 index 000000000..872d476ff --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/VerificationCodeMysqlModel.java @@ -0,0 +1,122 @@ +/** + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.database.entity; + +import com.welab.wefe.board.service.database.entity.base.AbstractMySqlModel; +import com.welab.wefe.board.service.database.listener.AccountMysqlModelListener; +import com.welab.wefe.board.service.database.listener.VerificationCodeMysqlModelListener; +import com.welab.wefe.common.wefe.enums.VerificationCodeBusinessType; +import com.welab.wefe.common.wefe.enums.VerificationCodeSendChannel; + +import javax.persistence.Entity; +import javax.persistence.EntityListeners; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; + +/** + * Verification code model + * + * @author aaron.li + * @date 2022/1/19 17:45 + **/ +@Entity(name = "verification_code") +@EntityListeners(VerificationCodeMysqlModelListener.class) +public class VerificationCodeMysqlModel extends AbstractMySqlModel { + /** + * Business id, This field can be used to associate business information + */ + private String bizId; + /** + * mobile + */ + private String mobile; + /** + * Verification code + */ + private String code; + /** + * Whether the verification code is sent successfully. 
true or false + */ + private String success; + /** + * Verification code send channel + */ + @Enumerated(EnumType.STRING) + private VerificationCodeSendChannel sendChannel; + /** + * Verification code business type + */ + @Enumerated(EnumType.STRING) + private VerificationCodeBusinessType businessType; + private String respContent; + + public String getMobile() { + return mobile; + } + + public void setMobile(String mobile) { + this.mobile = mobile; + } + + public String getCode() { + return code; + } + + public void setCode(String code) { + this.code = code; + } + + public String getSuccess() { + return success; + } + + public void setSuccess(String success) { + this.success = success; + } + + public VerificationCodeSendChannel getSendChannel() { + return sendChannel; + } + + public void setSendChannel(VerificationCodeSendChannel sendChannel) { + this.sendChannel = sendChannel; + } + + public VerificationCodeBusinessType getBusinessType() { + return businessType; + } + + public void setBusinessType(VerificationCodeBusinessType businessType) { + this.businessType = businessType; + } + + public String getRespContent() { + return respContent; + } + + public void setRespContent(String respContent) { + this.respContent = respContent; + } + + public String getBizId() { + return bizId; + } + + public void setBizId(String bizId) { + this.bizId = bizId; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/base/AbstractBaseMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/base/AbstractBaseMySqlModel.java index b6fcc17be..9f0744504 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/base/AbstractBaseMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/base/AbstractBaseMySqlModel.java @@ -45,35 +45,15 @@ public AbstractBaseMySqlModel() { setCreatedBy(CurrentAccount.id()); } - //region getter/setter 
- - public String getCreatedBy() { - return createdBy; - } - - public void setCreatedBy(String createdBy) { - this.createdBy = createdBy; - } - public void setCreatedBy(AbstractApiInput input) { String id = getOperatorId(createdBy, input); - setCreatedBy(id); - } - - public String getUpdatedBy() { - return updatedBy; - } - - public void setUpdatedBy(String updatedBy) { - this.updatedBy = updatedBy; - super.setUpdatedTime(new Date()); + this.createdBy = id; } public void setUpdatedBy(AbstractApiInput input) { String id = getOperatorId(updatedBy, input); setUpdatedBy(id); } - public String getOperatorId(AbstractApiInput input) { return getOperatorId(null, input); } @@ -91,5 +71,24 @@ public String getOperatorId(String operatorId, AbstractApiInput input) { return result; } + //region getter/setter + + public String getCreatedBy() { + return createdBy; + } + + public void setCreatedBy(String createdBy) { + this.createdBy = createdBy; + } + + public String getUpdatedBy() { + return updatedBy; + } + + public void setUpdatedBy(String updatedBy) { + this.updatedBy = updatedBy; + super.setUpdatedTime(new Date()); + } + //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/base/AbstractMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/base/AbstractMySqlModel.java index d47892897..2746621d0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/base/AbstractMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/base/AbstractMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/ChatLastAccountMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/ChatLastAccountMysqlModel.java index c8f5762bd..fd069fc97 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/ChatLastAccountMysqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/ChatLastAccountMysqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/ChatUnreadMessageMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/ChatUnreadMessageMySqlModel.java index e33fa7bbe..c69d47f89 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/ChatUnreadMessageMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/ChatUnreadMessageMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/MemberChatMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/MemberChatMySqlModel.java index 1d8c690ab..5951e5099 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/MemberChatMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/MemberChatMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/MessageQueueMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/MessageQueueMySqlModel.java index 5831c16cb..25bc4f244 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/MessageQueueMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/chat/MessageQueueMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,8 +17,8 @@ package com.welab.wefe.board.service.database.entity.chat; import com.welab.wefe.board.service.database.entity.base.AbstractMySqlModel; -import com.welab.wefe.common.enums.GatewayActionType; -import com.welab.wefe.common.enums.ProducerType; +import com.welab.wefe.common.wefe.enums.GatewayActionType; +import com.welab.wefe.common.wefe.enums.ProducerType; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/BloomFilterMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/BloomFilterMysqlModel.java new file mode 100644 index 000000000..3d9a920ff --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/BloomFilterMysqlModel.java @@ -0,0 +1,164 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.database.entity.data_resource; + +import com.welab.wefe.board.service.constant.BloomfilterAddMethod; +import com.welab.wefe.common.wefe.enums.DataResourceType; + +import javax.persistence.*; + +/** + * @author zane + * @date 2021/12/1 + */ +@Entity(name = "bloom_filter") +@Table(name = "bloom_filter") +public class BloomFilterMysqlModel extends DataResourceMysqlModel { + /** + * 密钥e + */ + @Column(name = "rsa_e") + private String rsaE; + /** + * 密钥n + */ + @Column(name = "rsa_n") + private String rsaN; + /** + * 密钥e + */ + @Column(name = "rsa_d") + private String rsaD; + /** + * 密钥p + */ + @Column(name = "rsa_p") + private String rsaP; + /** + * 密钥q + */ + @Column(name = "rsa_q") + private String rsaQ; + /** + * 数据源id + */ + private String dataSourceId; + /** + * 数据源地址 + */ + private String sourcePath; + /** + * 主键hash生成方法 + */ + private String hashFunction; + /** + * 布隆过滤器添加方式 + */ + @Enumerated(EnumType.STRING) + private BloomfilterAddMethod addMethod; + /** + * sql语句 + */ + private String sqlScript; + + public BloomFilterMysqlModel() { + super.setDataResourceType(DataResourceType.BloomFilter); + } + + // region getter/setter + + public String getRsaE() { + return rsaE; + } + + public void setRsaE(String rsaE) { + this.rsaE = rsaE; + } + + public String getRsaN() { + return rsaN; + } + + public void setRsaN(String rsaN) { + this.rsaN = rsaN; + } + + public String getRsaD() { + return rsaD; + } + + public void setRsaD(String rsaD) { + this.rsaD = rsaD; + } + + public String getDataSourceId() { + return dataSourceId; + } + + public void setDataSourceId(String dataSourceId) { + this.dataSourceId = dataSourceId; + } + + public String getSourcePath() { + return sourcePath; + } + + public void setSourcePath(String sourcePath) { + this.sourcePath = sourcePath; + } + + public String getHashFunction() { + return hashFunction; + } + + public void setHashFunction(String hashFunction) { + this.hashFunction = hashFunction; + } + + 
public BloomfilterAddMethod getAddMethod() { + return addMethod; + } + + public void setAddMethod(BloomfilterAddMethod addMethod) { + this.addMethod = addMethod; + } + + public String getSqlScript() { + return sqlScript; + } + + public void setSqlScript(String sqlScript) { + this.sqlScript = sqlScript; + } + + public String getRsaP() { + return rsaP; + } + + public void setRsaP(String rsaP) { + this.rsaP = rsaP; + } + + public String getRsaQ() { + return rsaQ; + } + + public void setRsaQ(String rsaQ) { + this.rsaQ = rsaQ; + } + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/DataResourceMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/DataResourceMysqlModel.java new file mode 100644 index 000000000..b12cc9d02 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/DataResourceMysqlModel.java @@ -0,0 +1,321 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.database.entity.data_resource; + +import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.annotation.JSONField; +import com.vladmihalcea.hibernate.type.json.JsonStringType; +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.DataResourceStorageType; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DataResourcePublicLevel; +import org.hibernate.annotations.Type; +import org.hibernate.annotations.TypeDef; + +import javax.persistence.*; +import java.io.File; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * @author zane + * @date 2021/12/1 + */ +@Entity(name = "data_resource") +@Table(name = "data_resource") +@Inheritance(strategy = InheritanceType.JOINED) +@TypeDef(name = "json", typeClass = JsonStringType.class) +public class DataResourceMysqlModel extends AbstractBaseMySqlModel { + /** + * 资源名称 + */ + private String name; + /** + * 资源类型 + */ + @Enumerated(EnumType.STRING) + private DataResourceType dataResourceType; + /** + * 描述 + */ + private String description; + /** + * 标签 + */ + private String tags; + /** + * 存储类型 + */ + @Enumerated(EnumType.STRING) + private DataResourceStorageType storageType; + /** + * 资源在存储中的命名空间;库名、目录路径) + */ + private String storageNamespace; + /** + * 资源在存储中的名称;表名、文件名) + */ + private String storageResourceName; + /** + * 总数据量 + */ + private long totalDataCount; + /** + * 资源的可见性 + */ + @Enumerated(EnumType.STRING) + private DataResourcePublicLevel publicLevel; + /** + * 可见成员列表;只有在列表中的联邦成员才可以看到该资源的基本信息 + */ + private String publicMemberList; + /** + * 该资源在多少个job中被使用 + */ + private int usageCountInJob; + /** + * 该资源在多少个flow中被使用 + */ + private int usageCountInFlow; + /** + * 该资源在多少个project中被使用 + */ + private int usageCountInProject; + /** + * 该资源被多少个其他成员被使用 + */ + private int 
usageCountInMember; + /** + * 是否是衍生资源 + */ + private boolean derivedResource; + /** + * 衍生来源,枚举;原始、对齐、分箱) + */ + @Enumerated(EnumType.STRING) + private ComponentType derivedFrom; + /** + * 衍生来源流程id + */ + private String derivedFromFlowId; + /** + * 衍生来源任务id + */ + private String derivedFromJobId; + /** + * 衍生来源子任务id + */ + private String derivedFromTaskId; + /** + * 该数据资源相关的统计信息 + */ + @Type(type = "json") + @Column(columnDefinition = "json") + private JSONObject statisticalInformation; + + // region custom method + + /** + * 获取资源文件对象 + */ + @JSONField(serialize = false) + public File file() { + return Paths.get( + this.getStorageNamespace(), + this.getStorageResourceName() + ) + .toFile(); + } + + /** + * 获取资源所在目录 + */ + @JSONField(serialize = false) + public Path dir() { + return Paths.get( + this.getStorageNamespace() + ); + } + + // endregion + + + // region getter/setter + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public String getTags() { + return tags; + } + + public void setTags(String tags) { + this.tags = tags; + } + + public DataResourceStorageType getStorageType() { + return storageType; + } + + public void setStorageType(DataResourceStorageType storageType) { + this.storageType = storageType; + } + + public String getStorageNamespace() { + return storageNamespace; + } + + public void setStorageNamespace(String storageNamespace) { + this.storageNamespace = storageNamespace; + } + + public String getStorageResourceName() { + return storageResourceName; + } + + public void setStorageResourceName(String storageResourceName) { + 
this.storageResourceName = storageResourceName; + } + + public Long getTotalDataCount() { + return totalDataCount; + } + + public void setTotalDataCount(long totalDataCount) { + this.totalDataCount = totalDataCount; + } + + public DataResourcePublicLevel getPublicLevel() { + return publicLevel; + } + + public void setPublicLevel(DataResourcePublicLevel publicLevel) { + this.publicLevel = publicLevel; + } + + public String getPublicMemberList() { + return publicMemberList; + } + + public void setPublicMemberList(String publicMemberList) { + this.publicMemberList = publicMemberList; + } + + public int getUsageCountInJob() { + return usageCountInJob; + } + + public void setUsageCountInJob(int usageCountInJob) { + this.usageCountInJob = usageCountInJob; + } + + public int getUsageCountInFlow() { + return usageCountInFlow; + } + + public void setUsageCountInFlow(int usageCountInFlow) { + this.usageCountInFlow = usageCountInFlow; + } + + public int getUsageCountInProject() { + return usageCountInProject; + } + + public void setUsageCountInProject(int usageCountInProject) { + this.usageCountInProject = usageCountInProject; + } + + public int getUsageCountInMember() { + return usageCountInMember; + } + + public void setUsageCountInMember(int usageCountInMember) { + this.usageCountInMember = usageCountInMember; + } + + public boolean isDerivedResource() { + return derivedResource; + } + + public void setDerivedResource(boolean derivedResource) { + this.derivedResource = derivedResource; + } + + public ComponentType getDerivedFrom() { + return derivedFrom; + } + + public void setDerivedFrom(ComponentType derivedFrom) { + this.derivedFrom = derivedFrom; + } + + public String getDerivedFromFlowId() { + return derivedFromFlowId; + } + + public void setDerivedFromFlowId(String derivedFromFlowId) { + this.derivedFromFlowId = derivedFromFlowId; + } + + public String getDerivedFromJobId() { + return derivedFromJobId; + } + + public void setDerivedFromJobId(String derivedFromJobId) 
{ + this.derivedFromJobId = derivedFromJobId; + } + + public String getDerivedFromTaskId() { + return derivedFromTaskId; + } + + public void setDerivedFromTaskId(String derivedFromTaskId) { + this.derivedFromTaskId = derivedFromTaskId; + } + + public JSONObject getStatisticalInformation() { + return statisticalInformation; + } + + public void setStatisticalInformation(JSONObject statisticalInformation) { + this.statisticalInformation = statisticalInformation; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/DataResourceUploadTaskMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/DataResourceUploadTaskMysqlModel.java new file mode 100644 index 000000000..3bd6fc61c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/DataResourceUploadTaskMysqlModel.java @@ -0,0 +1,159 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.database.entity.data_resource; + +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DataResourceUploadStatus; + +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.Table; + +/** + * @author zane + * @date 2021/12/1 + */ +@Entity(name = "data_resource_upload_task") +@Table(name = "data_resource_upload_task") +public class DataResourceUploadTaskMysqlModel extends AbstractBaseMySqlModel { + /** + * 数据资源id + */ + private String dataResourceId; + /** + * 数据资源名称 + */ + private String dataResourceName; + /** + * 资源类型 + */ + private DataResourceType dataResourceType; + /** + * 总数据行数 + */ + private Long totalDataCount; + /** + * 已写入数据行数 + */ + private long completedDataCount; + /** + * 任务进度百分比 + */ + private Integer progressRatio; + /** + * 预计剩余耗时 + */ + private long estimateRemainingTime; + /** + * 无效数据量;主键重复条数) + */ + private long invalidDataCount; + /** + * 错误消息 + */ + private String errorMessage; + /** + * 状态:上传中、已完成、已失败 + */ + @Enumerated(EnumType.STRING) + private DataResourceUploadStatus status; + + // region getter/setter + + public String getDataResourceId() { + return dataResourceId; + } + + public void setDataResourceId(String dataResourceId) { + this.dataResourceId = dataResourceId; + } + + public String getDataResourceName() { + return dataResourceName; + } + + public void setDataResourceName(String dataResourceName) { + this.dataResourceName = dataResourceName; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public Long getTotalDataCount() { + return totalDataCount; + } + + public void setTotalDataCount(Long totalDataCount) { + this.totalDataCount 
= totalDataCount; + } + + public long getCompletedDataCount() { + return completedDataCount; + } + + public void setCompletedDataCount(long completedDataCount) { + this.completedDataCount = completedDataCount; + } + + public Integer getProgressRatio() { + return progressRatio; + } + + public void setProgressRatio(Integer progressRatio) { + this.progressRatio = progressRatio; + } + + public long getEstimateRemainingTime() { + return estimateRemainingTime; + } + + public void setEstimateRemainingTime(long estimateRemainingTime) { + this.estimateRemainingTime = estimateRemainingTime; + } + + public long getInvalidDataCount() { + return invalidDataCount; + } + + public void setInvalidDataCount(long invalidDataCount) { + this.invalidDataCount = invalidDataCount; + } + + public String getErrorMessage() { + return errorMessage; + } + + public void setErrorMessage(String errorMessage) { + this.errorMessage = errorMessage; + } + + public DataResourceUploadStatus getStatus() { + return status; + } + + public void setStatus(DataResourceUploadStatus status) { + this.status = status; + } + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/ImageDataSetMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/ImageDataSetMysqlModel.java new file mode 100644 index 000000000..3ead9b3e3 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/ImageDataSetMysqlModel.java @@ -0,0 +1,123 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.database.entity.data_resource; + + +import com.alibaba.fastjson.annotation.JSONField; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DeepLearningJobType; + +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import javax.persistence.Table; +import java.util.List; +import java.util.TreeSet; + +/** + * @author zane + * @date 2021/12/1 + */ +@Entity(name = "image_data_set") +@Table(name = "image_data_set") +public class ImageDataSetMysqlModel extends DataResourceMysqlModel { + /** + * 任务类型;物体检测...) 
+ */ + @Enumerated(EnumType.STRING) + private DeepLearningJobType forJobType; + /** + * label;列表 + */ + private String labelList; + /** + * 已标注数量 + */ + private Long labeledCount; + /** + * 是否已标注完毕 + */ + private boolean labelCompleted; + /** + * 数据集大小 + */ + private Long filesSize; + + public ImageDataSetMysqlModel() { + super.setDataResourceType(DataResourceType.ImageDataSet); + } + + @JSONField(serialize = false) + public TreeSet getLabelSet() { + TreeSet labelSet = new TreeSet<>(); + if (StringUtil.isEmpty(labelList)) { + return labelSet; + } + + List list = StringUtil.splitWithoutEmptyItem(labelList, ","); + for (String label : list) { + labelSet.add(label); + } + return labelSet; + } + + // region getter/setter + + + public DeepLearningJobType getForJobType() { + return forJobType; + } + + public void setForJobType(DeepLearningJobType forJobType) { + this.forJobType = forJobType; + } + + public String getLabelList() { + return labelList; + } + + public void setLabelList(String labelList) { + this.labelList = labelList; + } + + public Long getLabeledCount() { + return labeledCount; + } + + public void setLabeledCount(Long labeledCount) { + this.labeledCount = labeledCount; + } + + public boolean isLabelCompleted() { + return labelCompleted; + } + + public void setLabelCompleted(boolean labelCompleted) { + this.labelCompleted = labelCompleted; + } + + public Long getFilesSize() { + return filesSize; + } + + public void setFilesSize(Long filesSize) { + this.filesSize = filesSize; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/TableDataSetMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/TableDataSetMysqlModel.java new file mode 100644 index 000000000..f6fa4b0a3 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_resource/TableDataSetMysqlModel.java @@ -0,0 +1,174 @@ 
+/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.database.entity.data_resource; + + +import com.welab.wefe.common.wefe.enums.DataResourceType; + +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.Table; + +/** + * @author zane + * @date 2021/12/1 + */ +@Entity(name = "table_data_set") +@Table(name = "table_data_set") +public class TableDataSetMysqlModel extends DataResourceMysqlModel { + /** + * 数据集字段列表 + */ + private String columnNameList; + /** + * 数据集列数 + */ + private Integer columnCount; + /** + * 主键字段 + */ + private String primaryKeyColumn; + /** + * 特征列表 + */ + private String featureNameList; + /** + * 特征数量 + */ + private Integer featureCount; + /** + * 是否包含;Y 值 + */ + @Column(name = "contains_y") + private boolean containsY; + /** + * y列名称列表 + */ + private String yNameList; + /** + * y列的数量 + */ + private Integer yCount; + /** + * 正样本的值 + */ + private String positiveSampleValue; + /** + * 正例数量 + */ + private Long yPositiveSampleCount; + /** + * 正例比例 + */ + private Double yPositiveSampleRatio; + + public TableDataSetMysqlModel() { + super.setDataResourceType(DataResourceType.TableDataSet); + } + + // region getter/setter + + public String getColumnNameList() { + return columnNameList; + } + + public void setColumnNameList(String columnNameList) { + this.columnNameList = columnNameList; + } + + public Integer 
getColumnCount() { + return columnCount; + } + + public void setColumnCount(Integer columnCount) { + this.columnCount = columnCount; + } + + public String getPrimaryKeyColumn() { + return primaryKeyColumn; + } + + public void setPrimaryKeyColumn(String primaryKeyColumn) { + this.primaryKeyColumn = primaryKeyColumn; + } + + public String getFeatureNameList() { + return featureNameList; + } + + public void setFeatureNameList(String featureNameList) { + this.featureNameList = featureNameList; + } + + public Integer getFeatureCount() { + return featureCount; + } + + public void setFeatureCount(Integer featureCount) { + this.featureCount = featureCount; + } + + public boolean isContainsY() { + return containsY; + } + + public void setContainsY(boolean containsY) { + this.containsY = containsY; + } + + public String getyNameList() { + return yNameList; + } + + public void setyNameList(String yNameList) { + this.yNameList = yNameList; + } + + public Integer getyCount() { + return yCount; + } + + public void setyCount(Integer yCount) { + this.yCount = yCount; + } + + public String getPositiveSampleValue() { + return positiveSampleValue; + } + + public void setPositiveSampleValue(String positiveSampleValue) { + this.positiveSampleValue = positiveSampleValue; + } + + public Long getyPositiveSampleCount() { + return yPositiveSampleCount; + } + + public void setyPositiveSampleCount(Long yPositiveSampleCount) { + this.yPositiveSampleCount = yPositiveSampleCount; + } + + public Double getyPositiveSampleRatio() { + return yPositiveSampleRatio; + } + + public void setyPositiveSampleRatio(Double yPositiveSampleRatio) { + this.yPositiveSampleRatio = yPositiveSampleRatio; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/DataSetColumnMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/DataSetColumnMysqlModel.java index 80ba160cc..fe13bb97b 100644 --- 
a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/DataSetColumnMysqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/DataSetColumnMysqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,7 +19,7 @@ import com.alibaba.fastjson.JSONObject; import com.vladmihalcea.hibernate.type.json.JsonStringType; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.ColumnDataType; +import com.welab.wefe.common.wefe.enums.ColumnDataType; import org.hibernate.annotations.Type; import org.hibernate.annotations.TypeDef; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/DataSetMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/DataSetMysqlModel.java deleted file mode 100644 index c392db0e3..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/DataSetMysqlModel.java +++ /dev/null @@ -1,356 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.database.entity.data_set; - -import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.DataSetPublicLevel; - -import javax.persistence.Column; -import javax.persistence.Entity; -import javax.persistence.EnumType; -import javax.persistence.Enumerated; - -/** - * @author Zane - */ -@Entity(name = "data_set") -public class DataSetMysqlModel extends AbstractBaseMySqlModel { - - /** - * 数据集名称 - */ - private String name; - /** - * 标签 - */ - private String tags; - /** - * 描述 - */ - private String description; - /** - * 存储类型 - */ - private String storageType; - /** - * 命名空间 - */ - private String namespace; - /** - * 表名 - */ - private String tableName; - /** - * 数据行数 - */ - private Long rowCount; - /** - * 主键字段 - */ - private String primaryKeyColumn; - /** - * 数据集列数 - */ - private Integer columnCount; - /** - * 数据集字段列表 - */ - private String columnNameList; - /** - * 特征数量 - */ - private Integer featureCount; - /** - * 特征列表 - */ - private String featureNameList; - /** - * 是否包含 Y 值 - */ - @Column(name = "contains_y") - private Boolean containsY; - /** - * y列的数量 - */ - private Integer yCount; - /** - * y列名称列表 - */ - private String yNameList; - /** - * 数据集的可见性 - */ - @Enumerated(EnumType.STRING) - private DataSetPublicLevel publicLevel; - /** - * 可见成员列表,只有在列表中的联邦成员才可以看到该数据集的基本信息 - */ - private String publicMemberList; - /** - * 使用次数 - */ - private Integer usageCountInJob = 0; - /** - * 使用次数 - */ - private 
Integer usageCountInFlow = 0; - /** - * 使用次数 - */ - private Integer usageCountInProject = 0; - /** - * 来源类型,枚举(原始、对齐、分箱) - */ - @Enumerated(EnumType.STRING) - private ComponentType sourceType; - /** - * 来源流程id - */ - private String sourceFlowId; - /** - * 来源任务id - */ - private String sourceJobId; - /** - * 来源子任务id - */ - private String sourceTaskId; - - /** - * 正例样本数量 - */ - private Long yPositiveExampleCount = 0L; - - /** - * 正例样本比例 - */ - private Double yPositiveExampleRatio = 0D; - - - //region getter/setter - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public String getTags() { - return tags; - } - - public void setTags(String tags) { - this.tags = tags; - } - - public String getDescription() { - return description; - } - - public void setDescription(String description) { - this.description = description; - } - - public String getStorageType() { - return storageType; - } - - public void setStorageType(String storageType) { - this.storageType = storageType; - } - - public String getNamespace() { - return namespace; - } - - public void setNamespace(String namespace) { - this.namespace = namespace; - } - - public String getTableName() { - return tableName; - } - - public void setTableName(String tableName) { - this.tableName = tableName; - } - - public Long getRowCount() { - return rowCount; - } - - public void setRowCount(Long rowCount) { - this.rowCount = rowCount; - } - - public String getPrimaryKeyColumn() { - return primaryKeyColumn; - } - - public void setPrimaryKeyColumn(String primaryKeyColumn) { - this.primaryKeyColumn = primaryKeyColumn; - } - - public Integer getColumnCount() { - return columnCount; - } - - public void setColumnCount(Integer columnCount) { - this.columnCount = columnCount; - } - - public String getColumnNameList() { - return columnNameList; - } - - public void setColumnNameList(String columnNameList) { - this.columnNameList = columnNameList; - } - - public Integer 
getFeatureCount() { - return featureCount; - } - - public void setFeatureCount(Integer featureCount) { - this.featureCount = featureCount; - } - - public String getFeatureNameList() { - return featureNameList; - } - - public void setFeatureNameList(String featureNameList) { - this.featureNameList = featureNameList; - } - - public Boolean getContainsY() { - return containsY; - } - - public void setContainsY(Boolean containsY) { - this.containsY = containsY; - } - - public Integer getyCount() { - return yCount; - } - - public void setyCount(Integer yCount) { - this.yCount = yCount; - } - - public String getyNameList() { - return yNameList; - } - - public void setyNameList(String yNameList) { - this.yNameList = yNameList; - } - - public DataSetPublicLevel getPublicLevel() { - return publicLevel; - } - - public void setPublicLevel(DataSetPublicLevel publicLevel) { - this.publicLevel = publicLevel; - } - - public String getPublicMemberList() { - return publicMemberList; - } - - public void setPublicMemberList(String publicMemberList) { - this.publicMemberList = publicMemberList; - } - - public Integer getUsageCountInJob() { - return usageCountInJob; - } - - public void setUsageCountInJob(Integer usageCountInJob) { - this.usageCountInJob = usageCountInJob; - } - - public Integer getUsageCountInFlow() { - return usageCountInFlow; - } - - public void setUsageCountInFlow(Integer usageCountInFlow) { - this.usageCountInFlow = usageCountInFlow; - } - - public Integer getUsageCountInProject() { - return usageCountInProject; - } - - public void setUsageCountInProject(Integer usageCountInProject) { - this.usageCountInProject = usageCountInProject; - } - - public ComponentType getSourceType() { - return sourceType; - } - - public void setSourceType(ComponentType sourceType) { - this.sourceType = sourceType; - } - - public String getSourceFlowId() { - return sourceFlowId; - } - - public void setSourceFlowId(String sourceFlowId) { - this.sourceFlowId = sourceFlowId; - } - - public 
String getSourceJobId() { - return sourceJobId; - } - - public void setSourceJobId(String sourceJobId) { - this.sourceJobId = sourceJobId; - } - - public String getSourceTaskId() { - return sourceTaskId; - } - - public void setSourceTaskId(String sourceTaskId) { - this.sourceTaskId = sourceTaskId; - } - - public Long getyPositiveExampleCount() { - return yPositiveExampleCount; - } - - public void setyPositiveExampleCount(Long yPositiveExampleCount) { - this.yPositiveExampleCount = yPositiveExampleCount; - } - - public Double getyPositiveExampleRatio() { - return yPositiveExampleRatio; - } - - public void setyPositiveExampleRatio(Double yPositiveExampleRatio) { - this.yPositiveExampleRatio = yPositiveExampleRatio; - } - - //endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/DataSetTaskMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/DataSetTaskMysqlModel.java deleted file mode 100644 index f8a08dbc1..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/DataSetTaskMysqlModel.java +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.database.entity.data_set; - -import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; - -import javax.persistence.Entity; - -/** - * @author lonnie - */ -@Entity(name = "data_set_task") -public class DataSetTaskMysqlModel extends AbstractBaseMySqlModel { - - /** - * 数据集名 - */ - private String dataSetName; - - /** - * 数据集id - */ - private String dataSetId; - - /** - * 总数据行数 - */ - private long totalRowCount = 0; - /** - * 已写入数据行数 - */ - private long addedRowCount = 0; - - /** - * 任务进度百分比 - */ - private int progress; - - /** - * 预计剩余耗时 - */ - private long estimateTime; - - /** - * 主键重复条数 - */ - private long repeatIdRowCount; - /** - * 错误消息 - */ - private String errorMessage; - - // region getter/setter - - - public String getDataSetName() { - return dataSetName; - } - - public void setDataSetName(String dataSetName) { - this.dataSetName = dataSetName; - } - - public String getDataSetId() { - return dataSetId; - } - - public void setDataSetId(String dataSetId) { - this.dataSetId = dataSetId; - } - - public long getTotalRowCount() { - return totalRowCount; - } - - public void setTotalRowCount(long totalRowCount) { - this.totalRowCount = totalRowCount; - } - - public long getAddedRowCount() { - return addedRowCount; - } - - public void setAddedRowCount(long addedRowCount) { - this.addedRowCount = addedRowCount; - } - - public int getProgress() { - return progress; - } - - public void setProgress(int progress) { - this.progress = progress; - } - - public long getEstimateTime() { - return estimateTime; - } - - public void setEstimateTime(long estimateTime) { - this.estimateTime = estimateTime; - } - - public long getRepeatIdRowCount() { - return repeatIdRowCount; - } - - public void setRepeatIdRowCount(long repeatIdRowCount) { - this.repeatIdRowCount = repeatIdRowCount; - } - - public String getErrorMessage() { - return errorMessage; - } - - public void setErrorMessage(String errorMessage) { - 
this.errorMessage = errorMessage; - } - - // endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/ImageDataSetSampleMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/ImageDataSetSampleMysqlModel.java new file mode 100644 index 000000000..07223be28 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/data_set/ImageDataSetSampleMysqlModel.java @@ -0,0 +1,166 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.entity.data_set; + +import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.annotation.JSONField; +import com.vladmihalcea.hibernate.type.json.JsonStringType; +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.common.util.StringUtil; +import org.hibernate.annotations.Type; +import org.hibernate.annotations.TypeDef; + +import javax.persistence.Column; +import javax.persistence.Entity; +import java.util.List; +import java.util.TreeSet; + +/** + * @author Zane + */ +@Entity(name = "image_data_set_sample") +@TypeDef(name = "json", typeClass = JsonStringType.class) +public class ImageDataSetSampleMysqlModel extends AbstractBaseMySqlModel { + + /** + * 数据集id + */ + private String dataSetId; + /** + * 文件名 + */ + private String fileName; + /** + * 包含文件名的文件路径 + */ + private String filePath; + /** + * 文件大小 + */ + private long fileSize; + /** + * label + */ + private String labelList; + /** + * 是否已标注 + */ + private boolean labeled; + /** + * json 形式的标注信息 + */ + @Type(type = "json") + @Column(columnDefinition = "json") + private JSONObject labelInfo; + + /** + * xml 形式的标注信息 + */ + private String xmlAnnotation; + + @JSONField(serialize = false) + public TreeSet getLabelSet() { + TreeSet labelSet = new TreeSet<>(); + if (StringUtil.isEmpty(labelList)) { + return labelSet; + } + + List list = StringUtil.splitWithoutEmptyItem(labelList, ","); + for (String label : list) { + labelSet.add(label); + } + return labelSet; + } + + public void setLabelList(String labelList) { + + // 在 labelList 前后加上逗号,用于sql方便匹配单个 label。 + if (labelList != null) { + if (!labelList.startsWith(",")) { + labelList = "," + labelList; + } + if (!labelList.endsWith(",")) { + labelList = labelList + ","; + } + } + this.labelList = labelList; + } + + //region getter/setter + + public String getDataSetId() { + return dataSetId; + } + + public void setDataSetId(String dataSetId) { + 
this.dataSetId = dataSetId; + } + + public String getFileName() { + return fileName; + } + + public void setFileName(String fileName) { + this.fileName = fileName; + } + + public String getFilePath() { + return filePath; + } + + public void setFilePath(String filePath) { + this.filePath = filePath; + } + + public long getFileSize() { + return fileSize; + } + + public void setFileSize(long fileSize) { + this.fileSize = fileSize; + } + + public String getLabelList() { + return labelList; + } + + public boolean isLabeled() { + return labeled; + } + + public void setLabeled(boolean labeled) { + this.labeled = labeled; + } + + public JSONObject getLabelInfo() { + return labelInfo; + } + + public void setLabelInfo(JSONObject labelInfo) { + this.labelInfo = labelInfo; + } + + public String getXmlAnnotation() { + return xmlAnnotation; + } + + public void setXmlAnnotation(String xmlAnnotation) { + this.xmlAnnotation = xmlAnnotation; + } + + //endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowActionLogMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowActionLogMySqlModel.java index 2773fd29c..5f49841b0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowActionLogMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowActionLogMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,8 +17,8 @@ package com.welab.wefe.board.service.database.entity.flow; import com.welab.wefe.board.service.database.entity.base.AbstractMySqlModel; -import com.welab.wefe.common.enums.GatewayActionType; -import com.welab.wefe.common.enums.ProducerType; +import com.welab.wefe.common.wefe.enums.GatewayActionType; +import com.welab.wefe.common.wefe.enums.ProducerType; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowActionQueueMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowActionQueueMySqlModel.java index 97543517f..0c2dbee59 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowActionQueueMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowActionQueueMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,8 +17,8 @@ package com.welab.wefe.board.service.database.entity.flow; import com.welab.wefe.board.service.database.entity.base.AbstractMySqlModel; -import com.welab.wefe.common.enums.FlowActionType; -import com.welab.wefe.common.enums.ProducerType; +import com.welab.wefe.common.wefe.enums.FlowActionType; +import com.welab.wefe.common.wefe.enums.ProducerType; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowTemplateMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowTemplateMySqlModel.java index a6117ab1d..15e5bc123 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowTemplateMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/flow/FlowTemplateMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,7 +17,7 @@ package com.welab.wefe.board.service.database.entity.flow; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/ExportProgressMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/ExportProgressMySqlModel.java new file mode 100644 index 000000000..be772f3cb --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/ExportProgressMySqlModel.java @@ -0,0 +1,118 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.database.entity.fusion; + + + +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.board.service.fusion.enums.ExportStatus; + +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; + +/** + * @author hunter.zhao + */ +@Entity(name = "fusion_result_export_progress") +public class ExportProgressMySqlModel extends AbstractBaseMySqlModel { + /** + * 融合任务businessId + */ + String businessId; + + /** + * 导出表名 + */ + String tableName; + + /** + * 进度 + */ + int progress; + + /** + * 导出总数 + */ + int totalDataCount; + + /** + * 已导出数量 + */ + int processedCount; + + long finishTime; + + @Enumerated(EnumType.STRING) + ExportStatus status; + + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public String getTableName() { + return tableName; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + public int getProgress() { + return progress; + } + + public void setProgress(int progress) { + this.progress = progress; + } + + public int getTotalDataCount() { + return totalDataCount; + } + + public void setTotalDataCount(int totalDataCount) { + this.totalDataCount = totalDataCount; + } + + public int getProcessedCount() { + return processedCount; + } + + public void setProcessedCount(int processedCount) { + this.processedCount = processedCount; + } + + public long getFinishTime() { + return finishTime; + } + + public void setFinishTime(long finishTime) { + this.finishTime = finishTime; + } + + public ExportStatus getStatus() { + return status; + } + + public void setStatus(ExportStatus status) { + this.status = status; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FieldInfoMySqlModel.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FieldInfoMySqlModel.java new file mode 100644 index 000000000..a5e9504f9 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FieldInfoMySqlModel.java @@ -0,0 +1,91 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.database.entity.fusion; + +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.common.wefe.enums.HashOptions; + +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; + +/** + * @author hunter.zhao + */ +@Entity(name = "fusion_field_info") +public class FieldInfoMySqlModel extends AbstractBaseMySqlModel { + private String businessId; + + private String columns; + + @Enumerated(EnumType.STRING) + private HashOptions options; + + private int fristIndex; + + private int endIndex; + + private int position; + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public String getColumns() { + return columns; + } + + public void setColumns(String columns) { + this.columns = columns; + } + + public HashOptions getOptions() { + return options; + } + + public void setOptions(HashOptions options) { + this.options = 
options; + } + + public int getFristIndex() { + return fristIndex; + } + + public void setFristIndex(int fristIndex) { + this.fristIndex = fristIndex; + } + + public int getEndIndex() { + return endIndex; + } + + public void setEndIndex(int endIndex) { + this.endIndex = endIndex; + } + + public int getPosition() { + return position; + } + + public void setPosition(int position) { + this.position = position; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FusionActuatorInfoMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FusionActuatorInfoMySqlModel.java new file mode 100644 index 000000000..6f16caf9e --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FusionActuatorInfoMySqlModel.java @@ -0,0 +1,72 @@ +package com.welab.wefe.board.service.database.entity.fusion; + +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.fusion.core.enums.FusionTaskStatus; + +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; + +/** + * @author hunter.zhao + */ +@Entity(name = "fusion_actuator_info") +public class FusionActuatorInfoMySqlModel extends AbstractBaseMySqlModel { + String type; + + @Enumerated(EnumType.STRING) + FusionTaskStatus status; + + int progress; + + String businessId; + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public FusionTaskStatus getStatus() { + return status; + } + + public void setStatus(FusionTaskStatus status) { + this.status = status; + } + + public int getProgress() { + return progress; + } + + public void setProgress(int progress) { + this.progress = progress; + } + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FusionResultMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FusionResultMySqlModel.java new file mode 100644 index 000000000..c181db5b6 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FusionResultMySqlModel.java @@ -0,0 +1,40 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.database.entity.fusion; + +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; + +import javax.persistence.Entity; +import java.util.Date; + +/** + * @author hunter.zhao + */ +@Entity(name = "fusion_result") +public class FusionResultMySqlModel extends AbstractBaseMySqlModel { + String taskId; + + String name; + + String rows; + + Date startTime; + + Date endTime; + + long spend; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FusionTaskMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FusionTaskMySqlModel.java new file mode 100644 index 000000000..24d24c079 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/FusionTaskMySqlModel.java @@ -0,0 +1,321 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.entity.fusion; + +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.fusion.core.enums.AlgorithmType; +import com.welab.wefe.fusion.core.enums.FusionTaskStatus; +import com.welab.wefe.fusion.core.enums.PSIActuatorRole; + +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; + +/** + * @author hunter.zhao + */ +@Entity(name = "fusion_task") +public class FusionTaskMySqlModel extends AbstractBaseMySqlModel { + + String projectId; + + String businessId; + + String name; + + @Enumerated(EnumType.STRING) + FusionTaskStatus status; + + String error; + + String dstMemberId; + + String dataResourceId; + + @Enumerated(EnumType.STRING) + @Column(name = "data_resource_type") + DataResourceType dataResourceType; + + /** + * Number of rows of data resources + */ + Long rowCount; + + String hashFunction; + + String partnerDataResourceId; + + @Enumerated(EnumType.STRING) + @Column(name = "partner_data_resource_type") + DataResourceType partnerDataResourceType; + + /** + * Number of rows of data resources + */ + Long partnerRowCount; + + String partnerHashFunction; + + /** + * Whether the trace + */ + public boolean isTrace; + + /** + * Traces the field + */ + public String traceColumn; + + + @Enumerated(EnumType.STRING) + @Column(name = "psi_actuator_role") + PSIActuatorRole psiActuatorRole; + + @Enumerated(EnumType.STRING) + @Column(name = "my_role") + JobMemberRole myRole; + + @Enumerated(EnumType.STRING) + @Column(name = "algorithm") + AlgorithmType algorithm; + + + /** + * Number of fusion + */ + public Long fusionCount = 0L; + + /** + * Number of fusion + */ + public Long processedCount = 0L; + + /** + * Number of fusion + */ + public Long dataCount = 0L; + + public long 
spend; + + public String description; + + public String comment; + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public String getPartnerDataResourceId() { + return partnerDataResourceId; + } + + public void setPartnerDataResourceId(String partnerDataResourceId) { + this.partnerDataResourceId = partnerDataResourceId; + } + + public DataResourceType getPartnerDataResourceType() { + return partnerDataResourceType; + } + + public void setPartnerDataResourceType(DataResourceType partnerDataResourceType) { + this.partnerDataResourceType = partnerDataResourceType; + } + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public FusionTaskStatus getStatus() { + return status; + } + + public void setStatus(FusionTaskStatus status) { + this.status = status; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public String getDstMemberId() { + return dstMemberId; + } + + public void setDstMemberId(String dstMemberId) { + this.dstMemberId = dstMemberId; + } + + public String getDataResourceId() { + return dataResourceId; + } + + public void setDataResourceId(String dataResourceId) { + this.dataResourceId = dataResourceId; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public boolean isTrace() { + return isTrace; + } + + public void setTrace(boolean trace) { + isTrace = trace; + } + + public String getTraceColumn() { + return traceColumn; + } + + public void setTraceColumn(String traceColumn) { + this.traceColumn = traceColumn; + } + + 
public Long getRowCount() { + return rowCount; + } + + public void setRowCount(Long rowCount) { + this.rowCount = rowCount; + } + + public PSIActuatorRole getPsiActuatorRole() { + return psiActuatorRole; + } + + public void setPsiActuatorRole(PSIActuatorRole psiActuatorRole) { + this.psiActuatorRole = psiActuatorRole; + } + + public AlgorithmType getAlgorithm() { + return algorithm; + } + + public void setAlgorithm(AlgorithmType algorithm) { + this.algorithm = algorithm; + } + + + public Long getFusionCount() { + return fusionCount; + } + + public void setFusionCount(Long fusionCount) { + this.fusionCount = fusionCount; + } + + public long getSpend() { + return spend; + } + + public void setSpend(long spend) { + this.spend = spend; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public String getComment() { + return comment; + } + + public void setComment(String comment) { + this.comment = comment; + } + + public JobMemberRole getMyRole() { + return myRole; + } + + public void setMyRole(JobMemberRole myRole) { + this.myRole = myRole; + } + + public Long getPartnerRowCount() { + return partnerRowCount; + } + + public void setPartnerRowCount(Long partnerRowCount) { + this.partnerRowCount = partnerRowCount; + } + + public Long getProcessedCount() { + return processedCount; + } + + public void setProcessedCount(Long processedCount) { + this.processedCount = processedCount; + } + + public Long getDataCount() { + return dataCount; + } + + public void setDataCount(Long dataCount) { + this.dataCount = dataCount; + } + + public String getHashFunction() { + return hashFunction; + } + + public void setHashFunction(String hashFunction) { + this.hashFunction = hashFunction; + } + + public String getPartnerHashFunction() { + return partnerHashFunction; + } + + public void setPartnerHashFunction(String partnerHashFunction) { + this.partnerHashFunction = 
partnerHashFunction; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/bloomfilter/BloomFilterColumnMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/bloomfilter/BloomFilterColumnMysqlModel.java new file mode 100644 index 000000000..05789e45c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/bloomfilter/BloomFilterColumnMysqlModel.java @@ -0,0 +1,132 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.entity.fusion.bloomfilter; + +import com.alibaba.fastjson.JSONObject; +import com.vladmihalcea.hibernate.type.json.JsonStringType; +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.common.wefe.enums.ColumnDataType; +import org.hibernate.annotations.Type; +import org.hibernate.annotations.TypeDef; + +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; + +/** + * @author jacky.jiang + */ +@Entity(name = "bloom_filter_column") +@TypeDef(name = "json", typeClass = JsonStringType.class) +public class BloomFilterColumnMysqlModel extends AbstractBaseMySqlModel { + + /** + * 过滤器Id + */ + private String bloomFilterId; + /** + * 字段序号 + */ + @Column(name = "`index`") + private Integer index; + /** + * 字段名称 + */ + private String name; + /** + * 数据类型 + */ + @Enumerated(EnumType.STRING) + private ColumnDataType dataType; + /** + * 注释 + */ + private String comment; + /** + * 空值数据行数 + */ + private Long emptyRows; + /** + * 数值分布 + */ + @Type(type = "json") + @Column(columnDefinition = "json") + private JSONObject valueDistribution; + + //region getter/setter + + + public String getBloomFilterId() { + return bloomFilterId; + } + + public void setBloomFilterId(String bloomFilterId) { + this.bloomFilterId = bloomFilterId; + } + + public Integer getIndex() { + return index; + } + + public void setIndex(Integer index) { + this.index = index; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public ColumnDataType getDataType() { + return dataType; + } + + public void setDataType(ColumnDataType dataType) { + this.dataType = dataType; + } + + public String getComment() { + return comment; + } + + public void setComment(String comment) { + this.comment = comment; + } + + public Long getEmptyRows() { + return emptyRows; + } + + 
public void setEmptyRows(Long emptyRows) { + this.emptyRows = emptyRows; + } + + public JSONObject getValueDistribution() { + return valueDistribution; + } + + public void setValueDistribution(JSONObject valueDistribution) { + this.valueDistribution = valueDistribution; + } + + + //endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/bloomfilter/BloomFilterTaskMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/bloomfilter/BloomFilterTaskMysqlModel.java new file mode 100644 index 000000000..5c6e0e968 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/fusion/bloomfilter/BloomFilterTaskMysqlModel.java @@ -0,0 +1,142 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.entity.fusion.bloomfilter; + +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; + +import javax.persistence.Entity; + +/** + * @author jacky.jiang + */ +@Entity(name = "bloom_filter_task") +public class BloomFilterTaskMysqlModel extends AbstractBaseMySqlModel { + + /** + * 过滤器名 + */ + private String bloomFilterName; + + /** + * 过滤器id + */ + private String bloomFilterId; + + /** + * 总数据行数 + */ + private long totalRowCount = 0; + /** + * 已写入数据行数 + */ + private long addedRowCount = 0; + + /** + * 任务进度百分比 + */ + private int progress; + + /** + * 预计剩余耗时 + */ + private long estimateTime; + + /** + * 主键重复条数 + */ + private long repeatIdRowCount; + /** + * 错误消息 + */ + private String errorMessage; + + // region getter/setter + + public String getBloomfilterName() { + return bloomFilterName; + } + + public void setBloomfilterName(String bloomfilterName) { + this.bloomFilterName = bloomfilterName; + } + + public String getBloomFilterName() { + return bloomFilterName; + } + + public void setBloomFilterName(String bloomFilterName) { + this.bloomFilterName = bloomFilterName; + } + + public String getBloomFilterId() { + return bloomFilterId; + } + + public void setBloomFilterId(String bloomFilterId) { + this.bloomFilterId = bloomFilterId; + } + + public long getTotalRowCount() { + return totalRowCount; + } + + public void setTotalRowCount(long totalRowCount) { + this.totalRowCount = totalRowCount; + } + + public long getAddedRowCount() { + return addedRowCount; + } + + public void setAddedRowCount(long addedRowCount) { + this.addedRowCount = addedRowCount; + } + + public int getProgress() { + return progress; + } + + public void setProgress(int progress) { + this.progress = progress; + } + + public long getEstimateTime() { + return estimateTime; + } + + public void setEstimateTime(long estimateTime) { + this.estimateTime = estimateTime; + } + + public long getRepeatIdRowCount() { + return 
repeatIdRowCount; + } + + public void setRepeatIdRowCount(long repeatIdRowCount) { + this.repeatIdRowCount = repeatIdRowCount; + } + + public String getErrorMessage() { + return errorMessage; + } + + public void setErrorMessage(String errorMessage) { + this.errorMessage = errorMessage; + } + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/JobMemberMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/JobMemberMySqlModel.java index 1af60bf61..278e63ebc 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/JobMemberMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/JobMemberMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,7 +17,7 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/JobMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/JobMySqlModel.java index 915fab320..870b856a5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/JobMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/JobMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,9 +17,9 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.JobStatus; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.JobStatus; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ModelOotRecordMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ModelOotRecordMysqlModel.java index 3a2d4264e..87adc5d6c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ModelOotRecordMysqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ModelOotRecordMysqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectDataSetMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectDataSetMySqlModel.java index b6de66cf2..87f3891f7 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectDataSetMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectDataSetMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,9 +17,10 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import javax.persistence.Entity; import javax.persistence.EnumType; @@ -79,7 +80,11 @@ public class ProjectDataSetMySqlModel extends AbstractBaseMySqlModel { * 来源子任务id */ private String sourceTaskId; - + /** + * 数据集类型 + */ + @Enumerated(EnumType.STRING) + private DataResourceType dataResourceType; //region getter/setter @@ -164,6 +169,14 @@ public void setSourceTaskId(String sourceTaskId) { this.sourceTaskId = sourceTaskId; } + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectFlowMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectFlowMySqlModel.java index 02aba4b78..6c2d520c6 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectFlowMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectFlowMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,9 +17,10 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.ProjectFlowStatus; +import com.welab.wefe.common.wefe.enums.DeepLearningJobType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectFlowStatus; import javax.persistence.Entity; import javax.persistence.EnumType; @@ -43,6 +44,11 @@ public class ProjectFlowMySqlModel extends AbstractBaseMySqlModel { */ @Enumerated(EnumType.STRING) private FederatedLearningType federatedLearningType; + /** + * 深度学习任务类型 + */ + @Enumerated(EnumType.STRING) + private DeepLearningJobType deepLearningJobType; /** * 项目ID */ @@ -101,6 +107,14 @@ public void setFederatedLearningType(FederatedLearningType federatedLearningType this.federatedLearningType = federatedLearningType; } + public DeepLearningJobType getDeepLearningJobType() { + return deepLearningJobType; + } + + public void setDeepLearningJobType(DeepLearningJobType deepLearningJobType) { + this.deepLearningJobType = deepLearningJobType; + } + public String getMessage() { return message; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectFlowNodeMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectFlowNodeMySqlModel.java index 1112d841d..2ebe34f7c 100644 
--- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectFlowNodeMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectFlowNodeMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,7 +17,7 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.ComponentType; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMemberAuditMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMemberAuditMySqlModel.java index 625cb94c2..0e71732b0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMemberAuditMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMemberAuditMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,7 +17,7 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.AuditStatus; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMemberMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMemberMySqlModel.java index b8d7e2136..78d543c71 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMemberMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMemberMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,8 +17,8 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMySqlModel.java index 2fa4ad024..cd539dd57 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/ProjectMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,8 +17,9 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectType; import javax.persistence.Entity; import javax.persistence.EnumType; @@ -140,6 +141,12 @@ public class ProjectMySqlModel extends AbstractBaseMySqlModel { */ private String flowStatusStatistics; + /** + * 项目类型 + */ + @Enumerated(EnumType.STRING) + private ProjectType projectType; + //region getter/setter public boolean isDeleted() { @@ -326,5 +333,13 @@ public void setClosedTime(Date closedTime) { this.closedTime = closedTime; } + public ProjectType getProjectType() { + return projectType; + } + + public void setProjectType(ProjectType projectType) { + this.projectType = projectType; + } + //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskContextMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskContextMySqlModel.java index ea1ea8122..df0b9d002 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskContextMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskContextMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskMySqlModel.java index 935425922..8b704e93e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,9 +17,9 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskStatus; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskStatus; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskProgressMysqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskProgressMysqlModel.java index 217e02493..080198a77 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskProgressMysqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskProgressMysqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,8 +17,8 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskResultMySqlModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskResultMySqlModel.java index 45ca6c9a2..f3c583d92 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskResultMySqlModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/entity/job/TaskResultMySqlModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,8 +17,8 @@ package com.welab.wefe.board.service.database.entity.job; import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import javax.persistence.Entity; import javax.persistence.EnumType; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/listener/AccountMysqlModelListener.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/listener/AccountMysqlModelListener.java new file mode 100644 index 000000000..cc44eafa1 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/listener/AccountMysqlModelListener.java @@ -0,0 +1,58 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.listener; + +import com.welab.wefe.board.service.database.entity.AccountMysqlModel; +import com.welab.wefe.board.service.util.BoardSM4Util; +import com.welab.wefe.common.exception.StatusCodeWithException; + +import javax.persistence.PostLoad; +import javax.persistence.PrePersist; +import javax.persistence.PreUpdate; + +public class AccountMysqlModelListener { + + /** + * before save + */ + @PrePersist + public void prePersist(Object entity) throws StatusCodeWithException { + if (null != entity) { + AccountMysqlModel model = (AccountMysqlModel) entity; + model.setPhoneNumber(BoardSM4Util.encryptPhoneNumber(model.getPhoneNumber())); + } + } + + /** + * before update + */ + @PreUpdate + public void preUpdate(Object entity) throws StatusCodeWithException { + prePersist(entity); + } + + /** + * query + */ + @PostLoad + public void postLoad(Object entity) throws StatusCodeWithException { + if (null != entity) { + AccountMysqlModel model = (AccountMysqlModel) entity; + model.setPhoneNumber(BoardSM4Util.decryptPhoneNumber(model.getPhoneNumber())); + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/listener/GlobalConfigMysqlModelListener.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/listener/GlobalConfigMysqlModelListener.java new file mode 100644 index 000000000..b23ac7fe1 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/listener/GlobalConfigMysqlModelListener.java @@ -0,0 +1,67 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.database.listener; + +import com.welab.wefe.board.service.database.entity.GlobalConfigMysqlModel; +import com.welab.wefe.board.service.database.entity.VerificationCodeMysqlModel; +import com.welab.wefe.board.service.service.globalconfig.BaseGlobalConfigService; +import com.welab.wefe.board.service.util.BoardSM4Util; +import com.welab.wefe.common.exception.StatusCodeWithException; + +import javax.persistence.PostLoad; +import javax.persistence.PrePersist; +import javax.persistence.PreUpdate; + +public class GlobalConfigMysqlModelListener { + /** + * before save + */ + @PrePersist + public void prePersist(Object entity) throws StatusCodeWithException { + if (null != entity) { + GlobalConfigMysqlModel model = (GlobalConfigMysqlModel) entity; + if (BaseGlobalConfigService.Group.MEMBER_INFO.equals(model.getGroup()) + && "member_mobile".equals(model.getName())) { + model.setValue(BoardSM4Util.encryptCommonText(model.getValue())); + } + } + } + + /** + * before update + */ + @PreUpdate + public void preUpdate(Object entity) throws StatusCodeWithException { + prePersist(entity); + } + + /** + * query + */ + @PostLoad + public void postLoad(Object entity) throws StatusCodeWithException { + if (null != entity) { + GlobalConfigMysqlModel model = (GlobalConfigMysqlModel) entity; + if (BaseGlobalConfigService.Group.MEMBER_INFO.equals(model.getGroup()) + && "member_mobile".equals(model.getName())) { + if (BoardSM4Util.isEncryptText(model.getValue())) { + 
model.setValue(BoardSM4Util.decryptCommonText(model.getValue())); + } + } + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/listener/VerificationCodeMysqlModelListener.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/listener/VerificationCodeMysqlModelListener.java new file mode 100644 index 000000000..b730f5b17 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/listener/VerificationCodeMysqlModelListener.java @@ -0,0 +1,58 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.listener; + +import com.welab.wefe.board.service.database.entity.VerificationCodeMysqlModel; +import com.welab.wefe.board.service.util.BoardSM4Util; +import com.welab.wefe.common.exception.StatusCodeWithException; + +import javax.persistence.PostLoad; +import javax.persistence.PrePersist; +import javax.persistence.PreUpdate; + +public class VerificationCodeMysqlModelListener { + + /** + * before save + */ + @PrePersist + public void prePersist(Object entity) throws StatusCodeWithException { + if (null != entity) { + VerificationCodeMysqlModel model = (VerificationCodeMysqlModel) entity; + model.setMobile(BoardSM4Util.encryptPhoneNumber(model.getMobile())); + } + } + + /** + * before update + */ + @PreUpdate + public void preUpdate(Object entity) throws StatusCodeWithException { + prePersist(entity); + } + + /** + * query + */ + @PostLoad + public void postLoad(Object entity) throws StatusCodeWithException { + if (null != entity) { + VerificationCodeMysqlModel model = (VerificationCodeMysqlModel) entity; + model.setMobile(BoardSM4Util.decryptPhoneNumber(model.getMobile())); + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/AccountRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/AccountRepository.java index ed2bbc909..b905204c4 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/AccountRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/AccountRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,22 +16,45 @@ package com.welab.wefe.board.service.database.repository; -import com.welab.wefe.board.service.database.entity.AccountMySqlModel; +import com.welab.wefe.board.service.database.entity.AccountMysqlModel; import com.welab.wefe.board.service.database.repository.base.BaseRepository; import org.springframework.data.jpa.repository.Modifying; import org.springframework.data.jpa.repository.Query; import org.springframework.stereotype.Repository; +import org.springframework.transaction.annotation.Transactional; /** * @author Zane */ @Repository -public interface AccountRepository extends BaseRepository { +public interface AccountRepository extends BaseRepository { - AccountMySqlModel findByPhoneNumber(String phoneNumber); + AccountMysqlModel findByPhoneNumber(String phoneNumber); @Modifying(clearAutomatically = true) @Query(value = "update account a set a.superAdminRole = false,a.adminRole = false where a.id =?1 ") void cancelSuperAdmin(String id); + + @Transactional + @Modifying(clearAutomatically = true) + @Query(value = "update #{#entityName} set last_action_time = now() where id =?1 ", nativeQuery = true) + void updateLastActionTime(String id); + + + /** + * 禁用 90 天未活动的账号 + */ + @Transactional + @Modifying(clearAutomatically = true) + @Query(value = "update #{#entityName} set enable=false where DATEDIFF(now(),last_action_time)>90", nativeQuery = true) + int disableAccountWithoutAction90Days(); + + /** + * 注销 180 天未活动的账号 + */ + @Transactional + @Modifying(clearAutomatically = true) + @Query(value = "update #{#entityName} set cancelled=true where DATEDIFF(now(),last_action_time)>180", 
nativeQuery = true) + int cancelAccountWithoutAction180Days(); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/BlacklistRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/BlacklistRepository.java index 3c9f30fed..503974f32 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/BlacklistRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/BlacklistRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ChatLastAccountRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ChatLastAccountRepository.java index 2ef3515ad..189dfa2ca 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ChatLastAccountRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ChatLastAccountRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ChatUnreadMessageRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ChatUnreadMessageRepository.java index fad5fc1db..6396aae03 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ChatUnreadMessageRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ChatUnreadMessageRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataOutputInfoRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataOutputInfoRepository.java index 29ee8d791..f024849a2 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataOutputInfoRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataOutputInfoRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSetColumnRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSetColumnRepository.java index b9834ea74..d645e9b7c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSetColumnRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSetColumnRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -28,7 +28,7 @@ @Repository public interface DataSetColumnRepository extends BaseRepository { - @Modifying + @Modifying(clearAutomatically = true) @Transactional void deleteByDataSetId(String dataSetId); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSetRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSetRepository.java deleted file mode 100644 index 49cccd21c..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSetRepository.java +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.database.repository; - -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; -import com.welab.wefe.board.service.database.repository.base.BaseRepository; -import org.springframework.data.jpa.repository.Modifying; -import org.springframework.data.jpa.repository.Query; -import org.springframework.stereotype.Repository; -import org.springframework.transaction.annotation.Transactional; - -import java.util.List; - -/** - * @author Zane - */ -@Repository -public interface DataSetRepository extends BaseRepository { - - @Query(value = "select tags,count(tags) as count from #{#entityName} where tags<>'' group by tags;", nativeQuery = true) - List listAllTags(); - - @Query(value = "select count(*) from #{#entityName} where name=?1", nativeQuery = true) - int countByName(String name); - - @Query(value = "select count(*) from #{#entityName} where name=?1 and id<>?2", nativeQuery = true) - int countByName(String name, String id); - - @Modifying - @Transactional - @Query(value = "update data_set set usage_count_in_project=(select count(*) from project_data_set where data_set_id=?1 and audit_status='agree') where id=?1", nativeQuery = true) - void updateUsageCountInProject(String dataSetId); -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSetTaskRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSetTaskRepository.java deleted file mode 100644 index 8baed0fd4..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSetTaskRepository.java +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.database.repository; - -import com.welab.wefe.board.service.database.entity.data_set.DataSetTaskMysqlModel; -import com.welab.wefe.board.service.database.repository.base.BaseRepository; -import org.springframework.stereotype.Repository; - -/** - * @author lonnie - */ -@Repository -public interface DataSetTaskRepository extends BaseRepository { -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSourceRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSourceRepository.java index 636d891c3..9e2c69417 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSourceRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/DataSourceRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,7 +16,7 @@ package com.welab.wefe.board.service.database.repository; -import com.welab.wefe.board.service.database.entity.DataSourceMySqlModel; +import com.welab.wefe.board.service.database.entity.DataSourceMysqlModel; import com.welab.wefe.board.service.database.repository.base.BaseRepository; import org.springframework.data.jpa.repository.Query; import org.springframework.stereotype.Repository; @@ -25,7 +25,7 @@ * @author Johnny.lin */ @Repository -public interface DataSourceRepository extends BaseRepository { +public interface DataSourceRepository extends BaseRepository { @Query(value = "select count(*) from #{#entityName} where name=?1", nativeQuery = true) int countByName(String name); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FeatureJobMemberRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FeatureJobMemberRepository.java index 7adbaeb1b..327c8b399 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FeatureJobMemberRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FeatureJobMemberRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowActionLogRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowActionLogRepository.java index 750fd63fb..d5fb2a5fd 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowActionLogRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowActionLogRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowActionQueueRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowActionQueueRepository.java index 11afed1bf..f44544f2b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowActionQueueRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowActionQueueRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowTemplateRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowTemplateRepository.java index 05c97cbbd..7f099da06 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowTemplateRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/FlowTemplateRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/GlobalConfigRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/GlobalConfigRepository.java index 87fc55274..8587113a1 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/GlobalConfigRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/GlobalConfigRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,7 +16,7 @@ package com.welab.wefe.board.service.database.repository; -import com.welab.wefe.board.service.database.entity.GlobalConfigMySqlModel; +import com.welab.wefe.board.service.database.entity.GlobalConfigMysqlModel; import com.welab.wefe.board.service.database.repository.base.BaseRepository; import org.springframework.stereotype.Repository; @@ -26,7 +26,7 @@ * @author Zane */ @Repository -public interface GlobalConfigRepository extends BaseRepository { +public interface GlobalConfigRepository extends BaseRepository { - List findByGroup(String group); + List findByGroup(String group); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ImageDataSetSampleRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ImageDataSetSampleRepository.java new file mode 100644 index 000000000..25e573da2 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ImageDataSetSampleRepository.java @@ -0,0 +1,50 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.database.repository; + +import com.welab.wefe.board.service.database.entity.data_set.ImageDataSetSampleMysqlModel; +import com.welab.wefe.board.service.database.repository.base.BaseRepository; +import org.springframework.data.jpa.repository.Modifying; +import org.springframework.data.jpa.repository.Query; +import org.springframework.stereotype.Repository; +import org.springframework.transaction.annotation.Transactional; + +import java.util.List; + +/** + * @author zane + * @date 2021/11/10 + */ +@Repository +public interface ImageDataSetSampleRepository extends BaseRepository { + @Modifying(clearAutomatically = true) + @Transactional + void deleteByDataSetId(String dataSetId); + + @Query(value = "select label_list from #{#entityName} where data_set_id=?1 and labeled=true;", nativeQuery = true) + List getAllLabelList(String dataSetId); + + @Query(value = "select label_list from #{#entityName} where data_set_id=?1 and labeled=true group by label_list;", nativeQuery = true) + List getAllDistinctLabelList(String dataSetId); + + + @Query(value = "select count(*) from #{#entityName} where data_set_id=?1 and labeled=true", nativeQuery = true) + long getLabeledCount(String dataSetId); + + @Query(value = "select count(*) from #{#entityName} where data_set_id=?1", nativeQuery = true) + long getSampleCount(String dataSetId); + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/JobMemberRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/JobMemberRepository.java index 65d37943f..cf6cc48a5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/JobMemberRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/JobMemberRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/JobRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/JobRepository.java index 5dc0d0d2c..e91fca2b5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/JobRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/JobRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MemberChatRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MemberChatRepository.java index 17a1a3f53..71ad3120c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MemberChatRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MemberChatRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -47,7 +47,7 @@ public interface MemberChatRepository extends BaseRepository queryChatList(@Param("selfMemberId") String selfMemberId); - @Modifying + @Modifying(clearAutomatically = true) @Query(value = "update member_chat set status = :newStatus where from_account_id = :fromAccountId " + "and to_account_id = :toAccountId and status = :status", nativeQuery = true) void updateMessageStatus(@Param("fromAccountId") String fromAccountId, @Param("toAccountId") String toAccountId, @Param("status") int status, @Param("newStatus") int newStatus); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MessageQueueRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MessageQueueRepository.java index 53047141a..461c45de6 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MessageQueueRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MessageQueueRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MessageRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MessageRepository.java index af46f3c33..8b68df2bc 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MessageRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/MessageRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ModelOotRecordRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ModelOotRecordRepository.java index 44778ac9c..43388b11b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ModelOotRecordRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ModelOotRecordRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/OperationLogRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/OperationLogRepository.java index 596004b19..68f8cadb9 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/OperationLogRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/OperationLogRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/OutputModelRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/OutputModelRepository.java index 84382fc62..690f5eb1c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/OutputModelRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/OutputModelRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectDataSetRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectDataSetRepository.java index cdfbf556f..befb4ee02 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectDataSetRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectDataSetRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -37,7 +37,7 @@ public interface ProjectDataSetRepository extends BaseRepository nodeIds); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectFlowRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectFlowRepository.java index 2bfecfdc8..00f41fd07 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectFlowRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectFlowRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectMemberAuditRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectMemberAuditRepository.java index 2626d1eeb..22e9b84fa 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectMemberAuditRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/ProjectMemberAuditRepository.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -34,7 +34,7 @@ public interface ProjectMemberAuditRepository extends BaseRepository { +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepository.java index 254a2ba1b..f87e8333f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepository.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepository.java @@ -1,4 +1,4 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepositoryFactoryBean.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepositoryFactoryBean.java index 5f9f64510..9a3f8a79a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepositoryFactoryBean.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepositoryFactoryBean.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepositoryImpl.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepositoryImpl.java index 34cc5f350..996dc810b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepositoryImpl.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/BaseRepositoryImpl.java @@ -1,4 +1,4 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/RepositoryManager.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/RepositoryManager.java new file mode 100644 index 000000000..7e25bb2f2 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/base/RepositoryManager.java @@ -0,0 +1,63 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.database.repository.base; + +import com.welab.wefe.board.service.database.entity.base.AbstractMySqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel; +import com.welab.wefe.board.service.database.repository.data_resource.DataResourceRepository; +import com.welab.wefe.common.util.ClassUtils; +import com.welab.wefe.common.util.ReflectionsUtil; +import com.welab.wefe.common.web.Launcher; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * @author zane + * @date 2021/11/11 + */ +public class RepositoryManager { + /** + * AbstractMySqlModel : BaseRepository + */ + private static final Map, Class> MAP = new HashMap(); + + public static T get(Class mysqlModelClass) { + if (MAP.isEmpty()) { + List> list = ReflectionsUtil + .getClassesImplementing(BaseRepository.class, "com.welab.wefe") + .stream() + .filter(x -> x.isInterface()) + .collect(Collectors.toList()); + + for (Class repoClass : list) { + Class entityClass = ClassUtils.getGenericClass(repoClass, 0); + if (entityClass != null) { + MAP.put(entityClass, repoClass); + } + } + } + + // 由于 DataResourceRepository 使用了泛型声明 + // 无法获取到具体的 GenericClass + // 所以这里手动 put 一下 + MAP.put(DataResourceMysqlModel.class, DataResourceRepository.class); + + return (T) Launcher.getBean(MAP.get(mysqlModelClass)); + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/BloomFilterRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/BloomFilterRepository.java new file mode 100644 index 000000000..70a9dadad --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/BloomFilterRepository.java @@ -0,0 +1,27 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.database.repository.data_resource; + +import com.welab.wefe.board.service.database.entity.data_resource.BloomFilterMysqlModel; +import org.springframework.stereotype.Repository; + +/** + * @author zane + * @date 2021/12/1 + */ +@Repository("bloomFilterRepository") +public interface BloomFilterRepository extends DataResourceRepository { +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/DataResourceRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/DataResourceRepository.java new file mode 100644 index 000000000..76a1754f8 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/DataResourceRepository.java @@ -0,0 +1,50 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.database.repository.data_resource; + +import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel; +import com.welab.wefe.board.service.database.repository.base.BaseRepository; +import org.springframework.data.jpa.repository.Modifying; +import org.springframework.data.jpa.repository.Query; +import org.springframework.stereotype.Repository; +import org.springframework.transaction.annotation.Transactional; + +import java.util.List; + +/** + * @author Zane + */ +@Repository("dataResourceRepository") +public interface DataResourceRepository extends BaseRepository { + + @Query(value = "select tags,count(tags) as count from #{#entityName} where data_resource_type=?1 and tags<>'' group by tags;", nativeQuery = true) + List listAllTags(String resourceType); + + @Query(value = "select tags,count(tags) as count from #{#entityName} where tags<>'' group by tags;", nativeQuery = true) + List listAllTags(); + + @Query(value = "select count(*) from #{#entityName} where name=?1", nativeQuery = true) + int countByName(String name); + + @Query(value = "select count(*) from #{#entityName} where name=?1 and id<>?2", nativeQuery = true) + int countByName(String name, String id); + + @Modifying(clearAutomatically = true) + @Transactional + @Query(value = "update #{#entityName} set usage_count_in_project=(select count(*) from project_data_set where data_set_id=?1 and audit_status='agree') where id=?1", nativeQuery = true) + void updateUsageCountInProject(String dataSetId); +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/DataResourceUploadTaskRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/DataResourceUploadTaskRepository.java new file mode 100644 index 000000000..56530c0a6 --- 
/dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/DataResourceUploadTaskRepository.java @@ -0,0 +1,28 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.database.repository.data_resource; + +import com.welab.wefe.board.service.database.entity.data_resource.DataResourceUploadTaskMysqlModel; +import com.welab.wefe.board.service.database.repository.base.BaseRepository; +import org.springframework.stereotype.Repository; + +/** + * @author zane + * @date 2021/12/1 + */ +@Repository +public interface DataResourceUploadTaskRepository extends BaseRepository { +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/ImageDataSetRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/ImageDataSetRepository.java new file mode 100644 index 000000000..8335ff881 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/ImageDataSetRepository.java @@ -0,0 +1,27 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.database.repository.data_resource; + +import com.welab.wefe.board.service.database.entity.data_resource.ImageDataSetMysqlModel; +import org.springframework.stereotype.Repository; + +/** + * @author zane + * @date 2021/12/1 + */ +@Repository("imageDataSetRepository") +public interface ImageDataSetRepository extends DataResourceRepository { +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/TableDataSetRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/TableDataSetRepository.java new file mode 100644 index 000000000..b41615d40 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/data_resource/TableDataSetRepository.java @@ -0,0 +1,27 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.database.repository.data_resource; + +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; +import org.springframework.stereotype.Repository; + +/** + * @author zane + * @date 2021/12/1 + */ +@Repository("tableDataSetRepository") +public interface TableDataSetRepository extends DataResourceRepository { +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/BloomFilterColumnRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/BloomFilterColumnRepository.java new file mode 100644 index 000000000..ab85d3dd1 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/BloomFilterColumnRepository.java @@ -0,0 +1,34 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.repository.fusion; + +import com.welab.wefe.board.service.database.entity.fusion.bloomfilter.BloomFilterColumnMysqlModel; +import com.welab.wefe.board.service.database.repository.base.BaseRepository; +import org.springframework.data.jpa.repository.Modifying; +import org.springframework.stereotype.Repository; +import org.springframework.transaction.annotation.Transactional; + +/** + * @author jacky.jiang + */ +@Repository +public interface BloomFilterColumnRepository extends BaseRepository { + + @Modifying(clearAutomatically = true) + @Transactional + void deleteByBloomFilterId(String bloomFilterId); +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/BloomFilterTaskRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/BloomFilterTaskRepository.java new file mode 100644 index 000000000..973d2b549 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/BloomFilterTaskRepository.java @@ -0,0 +1,28 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.repository.fusion; + +import com.welab.wefe.board.service.database.entity.fusion.bloomfilter.BloomFilterTaskMysqlModel; +import com.welab.wefe.board.service.database.repository.base.BaseRepository; +import org.springframework.stereotype.Repository; + +/** + * @author lonnie + */ +@Repository +public interface BloomFilterTaskRepository extends BaseRepository { +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/ExportProgressRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/ExportProgressRepository.java new file mode 100644 index 000000000..03e7b302c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/ExportProgressRepository.java @@ -0,0 +1,35 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.repository.fusion; + + +import com.welab.wefe.board.service.database.entity.fusion.ExportProgressMySqlModel; +import com.welab.wefe.board.service.database.entity.fusion.FusionResultMySqlModel; +import com.welab.wefe.board.service.database.repository.base.BaseRepository; +import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; +import org.springframework.stereotype.Repository; + +/** + * @author Hunter + */ +@Repository +public interface ExportProgressRepository extends BaseRepository { + @Query(value = "select * from #{#entityName} where business_id=?1 order by created_time desc limit 1", nativeQuery = true) + ExportProgressMySqlModel findLastByBusinessId(@Param("businessId") String businessId); + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FieldInfoRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FieldInfoRepository.java new file mode 100644 index 000000000..ed91bb474 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FieldInfoRepository.java @@ -0,0 +1,28 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.repository.fusion; + +import com.welab.wefe.board.service.database.entity.fusion.FieldInfoMySqlModel; +import com.welab.wefe.board.service.database.repository.base.BaseRepository; +import org.springframework.stereotype.Repository; + +/** + * @author hunter.zhao + */ +@Repository +public interface FieldInfoRepository extends BaseRepository { +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FusionActuatorInfoRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FusionActuatorInfoRepository.java new file mode 100644 index 000000000..b063748a1 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FusionActuatorInfoRepository.java @@ -0,0 +1,29 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.database.repository.fusion; + + + +import com.welab.wefe.board.service.database.entity.fusion.FusionActuatorInfoMySqlModel; +import com.welab.wefe.board.service.database.repository.base.BaseRepository; +import org.springframework.stereotype.Repository; + +/** + * @author hunter.zhao + */ +@Repository +public interface FusionActuatorInfoRepository extends BaseRepository { +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FusionResultRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FusionResultRepository.java new file mode 100644 index 000000000..ad06dc17c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FusionResultRepository.java @@ -0,0 +1,29 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.repository.fusion; + + +import com.welab.wefe.board.service.database.entity.fusion.FusionResultMySqlModel; +import com.welab.wefe.board.service.database.repository.base.BaseRepository; +import org.springframework.stereotype.Repository; + +/** + * @author Hunter + */ +@Repository +public interface FusionResultRepository extends BaseRepository { +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FusionTaskRepository.java b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FusionTaskRepository.java new file mode 100644 index 000000000..e43522dbc --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/database/repository/fusion/FusionTaskRepository.java @@ -0,0 +1,29 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.database.repository.fusion; + + +import com.welab.wefe.board.service.database.entity.fusion.FusionTaskMySqlModel; +import com.welab.wefe.board.service.database.repository.base.BaseRepository; +import org.springframework.stereotype.Repository; + +/** + * @author Hunter + */ +@Repository +public interface FusionTaskRepository extends BaseRepository { +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/base/PagingInput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/base/PagingInput.java index 0a76fd7ea..20c34e25a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/base/PagingInput.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/base/PagingInput.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/base/PagingOutput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/base/PagingOutput.java index 9ee8d45b2..8b673be55 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/base/PagingOutput.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/base/PagingOutput.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,8 +16,8 @@ package com.welab.wefe.board.service.dto.base; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.web.dto.AbstractApiOutput; +import com.welab.wefe.common.web.util.ModelMapper; import java.util.List; import java.util.stream.Collectors; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/AbstractOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/AbstractOutputModel.java index a5db2f1c6..2d4c52277 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/AbstractOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/AbstractOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,6 +17,7 @@ package com.welab.wefe.board.service.dto.entity; import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.dto.AbstractApiOutput; import org.apache.commons.lang3.StringUtils; @@ -27,37 +28,40 @@ */ public class AbstractOutputModel extends AbstractApiOutput { - /** - * 全局唯一标识 - */ + @Check(name = "全局唯一标识") private String id; - /** - * 创建人 - */ + @Check(name = "创建人") private String createdBy; - /** - * 创建时间 - */ + @Check(name = "创建时间") private Date createdTime; - /** - * 更新人 - */ + @Check(name = "更新人") private String updatedBy; - /** - * 更新时间 - */ + @Check(name = "更新时间") private Date updatedTime; - /** - * 创建者昵称 - */ + @Check(name = "创建者昵称") private String creatorNickname; - /** - * 修改者昵称 - */ + @Check(name = "修改者昵称") private String updaterNickname; + + public void setCreatedBy(String createdBy) { + this.createdBy = createdBy; + this.creatorNickname = CacheObjects.getNickname(createdBy); + if (StringUtils.isBlank(this.creatorNickname)) { + this.creatorNickname = CacheObjects.getMemberName(createdBy); + } + } + + public void setUpdatedBy(String updatedBy) { + this.updatedBy = updatedBy; + this.updaterNickname = CacheObjects.getNickname(updatedBy); + if (StringUtils.isBlank(this.updaterNickname)) { + this.updaterNickname = CacheObjects.getMemberName(updatedBy); + } + } + //region getter/setter public String getId() { @@ -72,14 +76,6 @@ public String getCreatedBy() { return createdBy; } - public void setCreatedBy(String createdBy) { - this.createdBy = createdBy; - this.creatorNickname = CacheObjects.getNickname(createdBy); - if (StringUtils.isBlank(this.creatorNickname)) { - this.creatorNickname = CacheObjects.getMemberName(createdBy); - } - } - public Date getCreatedTime() { return createdTime; } @@ -92,14 +88,6 @@ public String getUpdatedBy() { return updatedBy; } - public void setUpdatedBy(String updatedBy) { - this.updatedBy = updatedBy; - 
this.updaterNickname = CacheObjects.getNickname(updatedBy); - if (StringUtils.isBlank(this.updaterNickname)) { - this.updaterNickname = CacheObjects.getMemberName(updatedBy); - } - } - public Date getUpdatedTime() { return updatedTime; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/AccountOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/AccountOutputModel.java index d93b99cbd..a3a7bd34c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/AccountOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/AccountOutputModel.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,49 +16,45 @@ package com.welab.wefe.board.service.dto.entity; -import com.welab.wefe.common.enums.AuditStatus; + +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.Masker; import com.welab.wefe.common.web.CurrentAccount; +import com.welab.wefe.common.wefe.enums.AuditStatus; + +import java.util.Date; /** * @author Zane */ public class AccountOutputModel extends AbstractOutputModel { - /** - * 手机号 - */ + @Check(name = "手机号") private String phoneNumber; - /** - * 昵称 - */ + @Check(name = "昵称") private String nickname; - /** - * 邮箱 - */ + @Check(name = "邮箱") private String email; - /** - * 是否是超级管理员;超级管理员通常是第一个创建并初始化系统的那个人 - */ + @Check(name = "是否是超级管理员;超级管理员通常是第一个创建并初始化系统的那个人") private Boolean superAdminRole; - /** - * 是否是管理员;管理员有更多权限,比如设置 member 是否对外可见。 - */ + @Check(name = "是否是管理员;管理员有更多权限,比如设置 member 是否对外可见。") private Boolean adminRole; - /** - * 审核状态 - */ + @Check(name = "审核状态") 
private AuditStatus auditStatus; - /** - * 审核意见 - */ + @Check(name = "审核意见") private String auditComment; + @Check(name = "是否可用") + private Boolean enable; /** - * 是否可用 + * 是否已注销 */ - private Boolean enable; + private boolean cancelled; + /** + * 最后活动时间 + */ + private Date lastActionTime; public String getEmail() { if (!CurrentAccount.isAdmin()) { @@ -135,6 +131,22 @@ public void setEnable(Boolean enable) { this.enable = enable; } + public boolean isCancelled() { + return cancelled; + } + + public void setCancelled(boolean cancelled) { + this.cancelled = cancelled; + } + + public Date getLastActionTime() { + return lastActionTime; + } + + public void setLastActionTime(Date lastActionTime) { + this.lastActionTime = lastActionTime; + } + //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/BlacklistOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/BlacklistOutputModel.java index 9e9a5cf2c..edee8b0ce 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/BlacklistOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/BlacklistOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,6 +16,8 @@ package com.welab.wefe.board.service.dto.entity; +import com.welab.wefe.common.fieldvalidate.annotation.Check; + import java.util.Date; /** @@ -25,29 +27,19 @@ public class BlacklistOutputModel { private String id; - /** - * Member id - */ + @Check(name = "Member id") private String memberId; - /** - * Member name - */ + @Check(name = "Member name") private String memberName; - /** - * Remark - */ + @Check(name = "Remark") private String remark; - /** - * Creator - */ + @Check(name = "Creator") private String createdBy; - /** - * Created time - */ + @Check(name = "Created time") private Date createdTime; public String getId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/BloomFilterDataResourceListOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/BloomFilterDataResourceListOutputModel.java new file mode 100644 index 000000000..1c6ce20d1 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/BloomFilterDataResourceListOutputModel.java @@ -0,0 +1,72 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.dto.entity; + +import com.welab.wefe.board.service.dto.entity.project.data_set.ProjectDataResourceOutputModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.JobMemberRole; + +import java.util.List; + +/** + * @author jacky.jiang + */ +public class BloomFilterDataResourceListOutputModel extends AbstractOutputModel { + + @Check(name = "项目ID") + private String projectId; + + @Check(name = "我方身份;枚举(promoter/provider)") + private JobMemberRole myRole; + + @Check(name = "我方成员ID") + private String memberId; + + private List dataSetList; + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public JobMemberRole getMyRole() { + return myRole; + } + + public void setMyRole(JobMemberRole myRole) { + this.myRole = myRole; + } + + public String getMemberId() { + return memberId; + } + + public void setMemberId(String memberId) { + this.memberId = memberId; + } + + public List getDataSetList() { + return dataSetList; + } + + public void setDataSetList(List dataSetList) { + this.dataSetList = dataSetList; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ChatLastAccountOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ChatLastAccountOutputModel.java index 4e6d4e024..f3cab3fb3 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ChatLastAccountOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ChatLastAccountOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,6 +16,8 @@ package com.welab.wefe.board.service.dto.entity; +import com.welab.wefe.common.fieldvalidate.annotation.Check; + /** * Recent chat account object output entity * @@ -23,45 +25,27 @@ **/ public class ChatLastAccountOutputModel extends AbstractOutputModel { - /** - * account id - */ + @Check(name = "account id") private String accountId; - /** - * account name - */ + @Check(name = "account name") private String accountName; - /** - * member id - */ + @Check(name = "member id") private String memberId; - /** - * member name - */ + @Check(name = "member name") private String memberName; - /** - * liaison member id - */ + @Check(name = "liaison member id") private String liaisonMemberId; - /** - * liaison member name - */ + @Check(name = "liaison member name") private String liaisonMemberName; - /** - * liaison account id - */ + @Check(name = "liaison account id") private String liaisonAccountId; - /** - * liaison account name - */ + @Check(name = "liaison account name") private String liaisonAccountName; - /** - * unread num - */ + @Check(name = "unread num") private Integer unreadNum = 0; public String getAccountId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/DataIoTaskFeatureInfoOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/DataIoTaskFeatureInfoOutputModel.java index b2b6269b6..579b43b0b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/DataIoTaskFeatureInfoOutputModel.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/DataIoTaskFeatureInfoOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,8 +16,10 @@ package com.welab.wefe.board.service.dto.entity; -import com.welab.wefe.common.enums.JobMemberRole; + +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.dto.AbstractApiOutput; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import java.util.List; @@ -27,25 +29,15 @@ * @author aaron.li **/ public class DataIoTaskFeatureInfoOutputModel extends AbstractApiOutput { - /** - * 成员ID - */ + @Check(name = "成员ID") private String memberId; - /** - * 成员名称 - */ + @Check(name = "成员名称") private String memberName; - /** - * 角色 - */ + @Check(name = "角色") private JobMemberRole role; - /** - * 数据集ID - */ + @Check(name = "数据集ID") private String dataSetId; - /** - * 选择入模的特征列 - */ + @Check(name = "选择入模的特征列") private List features; public String getMemberId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/DataOutputInfoOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/DataOutputInfoOutputModel.java deleted file mode 100644 index 20a45a3f2..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/DataOutputInfoOutputModel.java +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.dto.entity; - -/** - * @author aaron.li - **/ -public class DataOutputInfoOutputModel extends AbstractOutputModel { - /** - * 组件名称 - */ - private String componentName; - - /** - * 模型id - */ - private String partyModelId; - /** - * 模型版本 - */ - private String modelVersion; - - public String getPartyModelId() { - return partyModelId; - } - - public void setPartyModelId(String partyModelId) { - this.partyModelId = partyModelId; - } - - public String getModelVersion() { - return modelVersion; - } - - public void setModelVersion(String modelVersion) { - this.modelVersion = modelVersion; - } - - public String getComponentName() { - return componentName; - } - - public void setComponentName(String componentName) { - this.componentName = componentName; - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/DataSetTaskOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/DataSetTaskOutputModel.java deleted file mode 100644 index 33457b402..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/DataSetTaskOutputModel.java +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.dto.entity; - -/** - * @author zane.luo - */ -public class DataSetTaskOutputModel extends AbstractOutputModel { - - /** - * 数据集名 - */ - private String dataSetName; - - /** - * 数据集id - */ - private String dataSetId; - - /** - * 总数据行数 - */ - private long totalRowCount = 0; - /** - * 已写入数据行数 - */ - private long addedRowCount = 0; - - /** - * 任务进度百分比 - */ - private int progress; - - /** - * 预计剩余耗时 - */ - private long estimateTime; - - /** - * 主键重复条数 - */ - private long repeatIdRowCount; - /** - * 错误消息 - */ - private String errorMessage; - - // region getter/setter - - - public String getDataSetName() { - return dataSetName; - } - - public void setDataSetName(String dataSetName) { - this.dataSetName = dataSetName; - } - - public String getDataSetId() { - return dataSetId; - } - - public void setDataSetId(String dataSetId) { - this.dataSetId = dataSetId; - } - - public long getTotalRowCount() { - return totalRowCount; - } - - public void setTotalRowCount(long totalRowCount) { - this.totalRowCount = totalRowCount; - } - - public long getAddedRowCount() { - return addedRowCount; - } - - public void setAddedRowCount(long addedRowCount) { - this.addedRowCount = addedRowCount; - } - - public int getProgress() { - return progress; - } - - public void setProgress(int progress) { - this.progress = progress; - } - - public long getEstimateTime() { - return estimateTime; - } - - public void setEstimateTime(long estimateTime) { - this.estimateTime = estimateTime; - } - - public long getRepeatIdRowCount() { - return 
repeatIdRowCount; - } - - public void setRepeatIdRowCount(long repeatIdRowCount) { - this.repeatIdRowCount = repeatIdRowCount; - } - - public String getErrorMessage() { - return errorMessage; - } - - public void setErrorMessage(String errorMessage) { - this.errorMessage = errorMessage; - } - - // endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberChatOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberChatOutputModel.java index 83ccc2a41..c0eea1a77 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberChatOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberChatOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,6 +16,8 @@ package com.welab.wefe.board.service.dto.entity; +import com.welab.wefe.common.fieldvalidate.annotation.Check; + /** * 聊天详情输出对象 * @@ -23,35 +25,21 @@ **/ public class MemberChatOutputModel extends AbstractOutputModel { - /** - * 发送方的账号id - */ + @Check(name = "发送方的账号id") private String fromAccountId; - /** - * 发送方成员ID - */ + @Check(name = "发送方成员ID") private String fromMemberId; - /** - * 接收方的账号id - */ + @Check(name = "接收方的账号id") private String toAccountId; - /** - * 发送方成员名称 - */ + @Check(name = "发送方成员名称") private String toMemberId; - /** - * 聊天内容 - */ + @Check(name = "聊天内容") private String content; - /** - * 状态:(0:已读、1:未读、2、发送成功、3、发送失败) - */ + @Check(name = "状态:(0:已读、1:未读、2、发送成功、3、发送失败)") private Integer status; - /** - * 消息ID - */ + @Check(name = "消息ID") private String messageId; public String getFromAccountId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberFeatureInfoModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberFeatureInfoModel.java index 22d1cf83e..b0421ba78 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberFeatureInfoModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberFeatureInfoModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberModel.java index 0ff7090ec..0c31cc8b7 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,9 +16,10 @@ package com.welab.wefe.board.service.dto.entity; -import com.welab.wefe.common.enums.JobMemberRole; + import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.JobMemberRole; /** * @author lonnie diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberOutputModel.java index 07b65e69c..6b33a4606 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MemberOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MessageOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MessageOutputModel.java index 8680d50fb..357419f4c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MessageOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/MessageOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,32 +16,24 @@ package com.welab.wefe.board.service.dto.entity; -import com.welab.wefe.common.enums.MessageLevel; + +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.MessageLevel; /** * @author Zane */ public class MessageOutputModel extends AbstractOutputModel { - /** - * 消息生产者;枚举(board/gateway) - */ + @Check(name = "消息生产者;枚举(board/gateway)") private String producer; - /** - * 消息级别;枚举(info/success/error/warning) - */ + @Check(name = "消息级别;枚举(info/success/error/warning)") private MessageLevel level; - /** - * 标题 - */ + @Check(name = "标题") private String title; - /** - * 内容 - */ + @Check(name = "内容") private String content; - /** - * 未读 - */ + @Check(name = "未读") private Boolean unread; //region getter/setter diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/OperationLogOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/OperationLogOutputModel.java index 024303f5d..13aa79f03 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/OperationLogOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/OperationLogOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,65 +16,41 @@ package com.welab.wefe.board.service.dto.entity; -import org.apache.commons.lang3.StringUtils; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.common.fieldvalidate.annotation.Check; /** * @author eval **/ public class OperationLogOutputModel extends AbstractOutputModel { - /** - * 请求接口 - */ + @Check(name = "请求接口") private String logInterface; - /** - * 请求接口名称 - */ + @Check(name = "请求接口名称") private String interfaceName; - /** - * 请求IP - */ + @Check(name = "请求IP") private String requestIp; - /** - * 操作人员编号 - */ + @Check(name = "操作人员编号") private String operatorId; - /** - * 操作人员手机号 - */ - private String operatorPhone; - - /** - * 请求token - */ + @Check(name = "请求token") private String token; - /** - * 操作行为 - */ + @Check(name = "操作行为") private String logAction; - /** - * 请求结果编码 - */ + @Check(name = "请求结果编码") private int resultCode; - /** - * 请求结果 - */ + @Check(name = "请求结果") private String resultMessage; - /** - * 输出的手机号要脱敏 - */ - public void setOperatorPhone(String operatorPhone) { - this.operatorPhone = StringUtils.overlay(operatorPhone, "****", 3, 7); + public String getOperatorNickname() { + return CacheObjects.getNickname(operatorId); } - public String getLogInterface() { return logInterface; } @@ -107,10 +83,6 @@ public void setOperatorId(String operatorId) { this.operatorId = operatorId; } - public String getOperatorPhone() { - return operatorPhone; - } - public String getToken() { return token; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectDataSetInput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectDataSetInput.java index d0185a3ef..3bba5eaa9 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectDataSetInput.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectDataSetInput.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. 
All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,9 +16,11 @@ package com.welab.wefe.board.service.dto.entity; -import com.welab.wefe.common.enums.JobMemberRole; + import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; /** * @author zane.luo @@ -32,6 +34,10 @@ public class ProjectDataSetInput extends AbstractCheckModel { @Check(name = "数据集 Id", require = true) private String dataSetId; + @Check(name = "数据集类型", require = true) + private DataResourceType dataResourceType; + + //region getter/setter @@ -59,6 +65,13 @@ public void setDataSetId(String dataSetId) { this.dataSetId = dataSetId; } + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectMemberAuditOutput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectMemberAuditOutput.java index 343ceae42..68686752c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectMemberAuditOutput.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectMemberAuditOutput.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 
2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,34 +17,25 @@ package com.welab.wefe.board.service.dto.entity; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.AuditStatus; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.AuditStatus; /** * @author zane.luo */ public class ProjectMemberAuditOutput extends AbstractOutputModel { - /** - * 所属项目 Id 项目主键 - */ + @Check(name = "所属项目 Id 项目主键") private String projectId; - /** - * 成员 Id - */ + @Check(name = "成员 Id") private String memberId; - /** - * 审核人 - */ + @Check(name = "审核人") private String auditorId; - /** - * 审核结果;枚举值(adopt/disagree) - */ + @Check(name = "审核结果;枚举值(adopt/disagree)") private AuditStatus auditResult; - /** - * 审核意见 - */ + @Check(name = "审核意见") private String auditComment; public String getAuditorName() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectMemberInput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectMemberInput.java index 90defb9e6..ca4788f9d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectMemberInput.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/ProjectMemberInput.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,9 +16,10 @@ package com.welab.wefe.board.service.dto.entity; -import com.welab.wefe.common.enums.JobMemberRole; + import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import java.util.List; @@ -28,7 +29,7 @@ public class ProjectMemberInput extends AbstractCheckModel { @Check(name = "成员Id", require = true, messageOnEmpty = "请选择项目合作方") private String memberId; - @Check(name = "成员角色", hiddenForFrontEnd = true) + @Check(name = "成员角色", donotShow = true) private JobMemberRole memberRole; @Check(name = "数据集列表") diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/component/ComponentOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/component/ComponentOutputModel.java index df528cccf..0b5224822 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/component/ComponentOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/component/ComponentOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,21 +16,17 @@ package com.welab.wefe.board.service.dto.entity.component; +import com.welab.wefe.common.fieldvalidate.annotation.Check; + /** * @author aaron.li **/ public class ComponentOutputModel { - /** - * 组件唯一标识 - */ + @Check(name = "组件唯一标识") private String id; - /** - * 组件中文名称 - */ + @Check(name = "组件中文名称") private String name; - /** - * 描述 - */ + @Check(name = "描述") private String desc; public ComponentOutputModel(String id, String name, String desc) { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/BloomFilterOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/BloomFilterOutputModel.java new file mode 100644 index 000000000..2dde13ca0 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/BloomFilterOutputModel.java @@ -0,0 +1,81 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.entity.data_resource.output; + +import com.welab.wefe.board.service.constant.BloomfilterAddMethod; +import com.welab.wefe.common.fieldvalidate.annotation.Check; + +/** + * @author zane + * @date 2021/12/1 + */ +public class BloomFilterOutputModel extends DataResourceOutputModel { + @Check(name = "数据源id") + private String dataSourceId; + @Check(name = "数据源地址") + private String sourcePath; + @Check(name = "主键hash生成方法") + private String hashFunction; + @Check(name = "布隆过滤器添加方式") + private BloomfilterAddMethod addMethod; + @Check(name = "sql语句") + private String sqlScript; + + // region getter/setter + + public String getDataSourceId() { + return dataSourceId; + } + + public void setDataSourceId(String dataSourceId) { + this.dataSourceId = dataSourceId; + } + + public String getSourcePath() { + return sourcePath; + } + + public void setSourcePath(String sourcePath) { + this.sourcePath = sourcePath; + } + + public String getHashFunction() { + return hashFunction; + } + + public void setHashFunction(String hashFunction) { + this.hashFunction = hashFunction; + } + + public BloomfilterAddMethod getAddMethod() { + return addMethod; + } + + public void setAddMethod(BloomfilterAddMethod addMethod) { + this.addMethod = addMethod; + } + + public String getSqlScript() { + return sqlScript; + } + + public void setSqlScript(String sqlScript) { + this.sqlScript = sqlScript; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/DataResourceOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/DataResourceOutputModel.java new file mode 100644 index 000000000..a5135aba5 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/DataResourceOutputModel.java @@ -0,0 +1,276 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.entity.data_resource.output; + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.DataResourceStorageType; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DataResourcePublicLevel; + +import java.util.Map; +import java.util.TreeMap; + +/** + * @author zane + * @date 2021/12/1 + */ +public class DataResourceOutputModel extends AbstractOutputModel { + @Check(name = "资源名称") + private String name; + @Check(name = "资源类型") + private DataResourceType dataResourceType; + @Check(name = "描述") + private String description; + @Check(name = "标签") + private String tags; + @Check(name = "存储类型") + private DataResourceStorageType storageType; + @Check(name = "资源在存储中的命名空间;库名、目录路径)") + private String storageNamespace; + @Check(name = "资源在存储中的名称;表名、文件名)") + private String storageResourceName; + @Check(name = "总数据量") + private Long totalDataCount; + @Check(name = "资源的可见性") + private DataResourcePublicLevel publicLevel; + @Check(name = "可见成员列表;只有在列表中的联邦成员才可以看到该资源的基本信息") + private String publicMemberList; + @Check(name 
= "该资源在多少个job中被使用") + private Integer usageCountInJob; + @Check(name = "该资源在多少个flow中被使用") + private Integer usageCountInFlow; + @Check(name = "该资源在多少个project中被使用") + private Integer usageCountInProject; + @Check(name = "该资源被多少个其他成员被使用") + private Integer usageCountInMember; + @Check(name = "是否是衍生资源") + private boolean derivedResource; + @Check(name = "衍生来源,枚举(原始、对齐、分箱)") + private ComponentType derivedFrom; + @Check(name = "衍生来源流程id") + private String derivedFromFlowId; + @Check(name = "衍生来源任务id") + private String derivedFromJobId; + @Check(name = "衍生来源子任务id") + private String derivedFromTaskId; + @Check(name = "该数据资源相关的统计信息") + private JSONObject statisticalInformation; + @Check(name = "数据集是否已被删除") + private boolean deleted; + + public String getDerivedFromCn() { + if (derivedFrom != null) { + return derivedFrom.getLabel(); + } + return ""; + } + + public Map getPublicMemberInfoList() { + TreeMap map = new TreeMap<>(); + + if (publicMemberList == null) { + return map; + } + StringUtil + .splitWithoutEmptyItem(publicMemberList, ",") + .forEach(item -> { + map.put(item, CacheObjects.getMemberName(item)); + }); + + return map; + } + + public String getDataResourceId() { + return super.getId(); + } + + // region getter/setter + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public String getTags() { + return tags; + } + + public void setTags(String tags) { + this.tags = tags; + } + + public DataResourceStorageType getStorageType() { + return storageType; + } + + public void setStorageType(DataResourceStorageType storageType) { + this.storageType = 
storageType; + } + + public String getStorageNamespace() { + return storageNamespace; + } + + public void setStorageNamespace(String storageNamespace) { + this.storageNamespace = storageNamespace; + } + + public String getStorageResourceName() { + return storageResourceName; + } + + public void setStorageResourceName(String storageResourceName) { + this.storageResourceName = storageResourceName; + } + + public Long getTotalDataCount() { + return totalDataCount; + } + + public void setTotalDataCount(Long totalDataCount) { + this.totalDataCount = totalDataCount; + } + + public DataResourcePublicLevel getPublicLevel() { + return publicLevel; + } + + public void setPublicLevel(DataResourcePublicLevel publicLevel) { + this.publicLevel = publicLevel; + } + + public String getPublicMemberList() { + return publicMemberList; + } + + public void setPublicMemberList(String publicMemberList) { + this.publicMemberList = publicMemberList; + } + + public Integer getUsageCountInJob() { + return usageCountInJob; + } + + public void setUsageCountInJob(Integer usageCountInJob) { + this.usageCountInJob = usageCountInJob; + } + + public Integer getUsageCountInFlow() { + return usageCountInFlow; + } + + public void setUsageCountInFlow(Integer usageCountInFlow) { + this.usageCountInFlow = usageCountInFlow; + } + + public Integer getUsageCountInProject() { + return usageCountInProject; + } + + public void setUsageCountInProject(Integer usageCountInProject) { + this.usageCountInProject = usageCountInProject; + } + + public Integer getUsageCountInMember() { + return usageCountInMember; + } + + public void setUsageCountInMember(Integer usageCountInMember) { + this.usageCountInMember = usageCountInMember; + } + + public boolean isDerivedResource() { + return derivedResource; + } + + public void setDerivedResource(boolean derivedResource) { + this.derivedResource = derivedResource; + } + + public ComponentType getDerivedFrom() { + return derivedFrom; + } + + public void 
setDerivedFrom(ComponentType derivedFrom) { + this.derivedFrom = derivedFrom; + } + + public String getDerivedFromFlowId() { + return derivedFromFlowId; + } + + public void setDerivedFromFlowId(String derivedFromFlowId) { + this.derivedFromFlowId = derivedFromFlowId; + } + + public String getDerivedFromJobId() { + return derivedFromJobId; + } + + public void setDerivedFromJobId(String derivedFromJobId) { + this.derivedFromJobId = derivedFromJobId; + } + + public String getDerivedFromTaskId() { + return derivedFromTaskId; + } + + public void setDerivedFromTaskId(String derivedFromTaskId) { + this.derivedFromTaskId = derivedFromTaskId; + } + + public JSONObject getStatisticalInformation() { + return statisticalInformation; + } + + public void setStatisticalInformation(JSONObject statisticalInformation) { + this.statisticalInformation = statisticalInformation; + } + + public boolean isDeleted() { + return deleted; + } + + public void setDeleted(boolean deleted) { + this.deleted = deleted; + } + +// endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/DataResourceUploadTaskOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/DataResourceUploadTaskOutputModel.java new file mode 100644 index 000000000..5f1ada6ad --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/DataResourceUploadTaskOutputModel.java @@ -0,0 +1,132 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.entity.data_resource.output; + +import com.welab.wefe.board.service.database.entity.base.AbstractBaseMySqlModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DataResourceUploadStatus; + +/** + * @author zane + * @date 2021/12/1 + */ +public class DataResourceUploadTaskOutputModel extends AbstractBaseMySqlModel { + @Check(name = "数据资源id") + private String dataResourceId; + @Check(name = "数据资源名称") + private String dataResourceName; + @Check(name = "资源类型") + private DataResourceType dataResourceType; + @Check(name = "总数据行数") + private Long totalDataCount; + @Check(name = "已写入数据行数") + private Long completedDataCount; + @Check(name = "任务进度百分比") + private Integer progressRatio; + @Check(name = "预计剩余耗时") + private long estimateRemainingTime; + @Check(name = "无效数据量(主键重复条数)") + private long invalidDataCount; + @Check(name = "错误消息") + private String errorMessage; + @Check(name = "状态:上传中、已完成、已失败") + private DataResourceUploadStatus status; + + // region getter/setter + + public String getDataResourceId() { + return dataResourceId; + } + + public void setDataResourceId(String dataResourceId) { + this.dataResourceId = dataResourceId; + } + + public String getDataResourceName() { + return dataResourceName; + } + + public void setDataResourceName(String dataResourceName) { + this.dataResourceName = dataResourceName; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } 
+ + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public Long getTotalDataCount() { + return totalDataCount; + } + + public void setTotalDataCount(Long totalDataCount) { + this.totalDataCount = totalDataCount; + } + + public Long getCompletedDataCount() { + return completedDataCount; + } + + public void setCompletedDataCount(Long completedDataCount) { + this.completedDataCount = completedDataCount; + } + + public Integer getProgressRatio() { + return progressRatio; + } + + public void setProgressRatio(Integer progressRatio) { + this.progressRatio = progressRatio; + } + + public long getEstimateRemainingTime() { + return estimateRemainingTime; + } + + public void setEstimateRemainingTime(long estimateRemainingTime) { + this.estimateRemainingTime = estimateRemainingTime; + } + + public long getInvalidDataCount() { + return invalidDataCount; + } + + public void setInvalidDataCount(long invalidDataCount) { + this.invalidDataCount = invalidDataCount; + } + + public String getErrorMessage() { + return errorMessage; + } + + public void setErrorMessage(String errorMessage) { + this.errorMessage = errorMessage; + } + + public DataResourceUploadStatus getStatus() { + return status; + } + + public void setStatus(DataResourceUploadStatus status) { + this.status = status; + } + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/ImageDataSetOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/ImageDataSetOutputModel.java new file mode 100644 index 000000000..39ac04bfd --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/ImageDataSetOutputModel.java @@ -0,0 +1,102 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.entity.data_resource.output; + + +import com.alibaba.fastjson.annotation.JSONField; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.wefe.enums.DeepLearningJobType; + +import java.util.List; +import java.util.TreeSet; + +/** + * @author zane + * @date 2021/12/1 + */ +public class ImageDataSetOutputModel extends DataResourceOutputModel { + @Check(name = "任务类型(物体检测...)") + private DeepLearningJobType forJobType; + @Check(name = "label 列表") + private String labelList; + @Check(name = "已标注数量") + private Long labeledCount; + @Check(name = "是否已标注完毕") + private boolean labelCompleted; + @Check(name = "数据集大小") + private Long filesSize; + + @JSONField(serialize = false) + public TreeSet getLabelSet() { + TreeSet labelSet = new TreeSet<>(); + if (StringUtil.isEmpty(labelList)) { + return labelSet; + } + + List list = StringUtil.splitWithoutEmptyItem(labelList, ","); + for (String label : list) { + labelSet.add(label); + } + return labelSet; + } + + // region getter/setter + + + public DeepLearningJobType getForJobType() { + return forJobType; + } + + public void setForJobType(DeepLearningJobType forJobType) { + this.forJobType = forJobType; + } + + public String getLabelList() { + return labelList; + } + + public void setLabelList(String labelList) { + this.labelList = labelList; + } + 
+ public Long getLabeledCount() { + return labeledCount; + } + + public void setLabeledCount(Long labeledCount) { + this.labeledCount = labeledCount; + } + + public boolean isLabelCompleted() { + return labelCompleted; + } + + public void setLabelCompleted(boolean labelCompleted) { + this.labelCompleted = labelCompleted; + } + + public Long getFilesSize() { + return filesSize; + } + + public void setFilesSize(Long filesSize) { + this.filesSize = filesSize; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/TableDataSetOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/TableDataSetOutputModel.java new file mode 100644 index 000000000..7a0df35ee --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_resource/output/TableDataSetOutputModel.java @@ -0,0 +1,140 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.entity.data_resource.output; + +import com.welab.wefe.common.fieldvalidate.annotation.Check; + +/** + * @author zane + * @date 2021/12/1 + */ +public class TableDataSetOutputModel extends DataResourceOutputModel { + @Check(name = "数据集字段列表") + private String columnNameList; + @Check(name = "数据集列数") + private Integer columnCount; + @Check(name = "主键字段") + private String primaryKeyColumn; + @Check(name = "特征列表") + private String featureNameList; + @Check(name = "特征数量") + private Integer featureCount; + @Check(name = "是否包含 Y 值") + private boolean containsY; + @Check(name = "y列名称列表") + private String yNameList; + @Check(name = "y列的数量") + private Integer yCount; + @Check(name = "正样本的值") + private String positiveSampleValue; + @Check(name = "正例数量") + private Long yPositiveSampleCount; + @Check(name = "正例比例") + private Double yPositiveSampleRatio; + + // region getter/setter + + public String getColumnNameList() { + return columnNameList; + } + + public void setColumnNameList(String columnNameList) { + this.columnNameList = columnNameList; + } + + public Integer getColumnCount() { + return columnCount; + } + + public void setColumnCount(Integer columnCount) { + this.columnCount = columnCount; + } + + public String getPrimaryKeyColumn() { + return primaryKeyColumn; + } + + public void setPrimaryKeyColumn(String primaryKeyColumn) { + this.primaryKeyColumn = primaryKeyColumn; + } + + public String getFeatureNameList() { + return featureNameList; + } + + public void setFeatureNameList(String featureNameList) { + this.featureNameList = featureNameList; + } + + public Integer getFeatureCount() { + return featureCount; + } + + public void setFeatureCount(Integer featureCount) { + this.featureCount = featureCount; + } + + public boolean isContainsY() { + return containsY; + } + + public void setContainsY(boolean containsY) { + this.containsY = containsY; + } + + public String getyNameList() { + return yNameList; + } + + public void 
setyNameList(String yNameList) { + this.yNameList = yNameList; + } + + public Integer getyCount() { + return yCount; + } + + public void setyCount(Integer yCount) { + this.yCount = yCount; + } + + public String getPositiveSampleValue() { + return positiveSampleValue; + } + + public void setPositiveSampleValue(String positiveSampleValue) { + this.positiveSampleValue = positiveSampleValue; + } + + public Long getyPositiveSampleCount() { + return yPositiveSampleCount; + } + + public void setyPositiveSampleCount(Long yPositiveSampleCount) { + this.yPositiveSampleCount = yPositiveSampleCount; + } + + public Double getyPositiveSampleRatio() { + return yPositiveSampleRatio; + } + + public void setyPositiveSampleRatio(Double yPositiveSampleRatio) { + this.yPositiveSampleRatio = yPositiveSampleRatio; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetColumnInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetColumnInputModel.java index 9817e3248..ccc7ab193 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetColumnInputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetColumnInputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,21 +16,20 @@ package com.welab.wefe.board.service.dto.entity.data_set; -import com.welab.wefe.common.enums.ColumnDataType; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ColumnDataType; /** * @author Zane */ -public class DataSetColumnInputModel { +public class DataSetColumnInputModel extends AbstractCheckModel { - /** - * 字段名称 - */ + @Check(name = "字段名称") private String name; - /** - * 数据类型 - */ + @Check(name = "数据类型") private ColumnDataType dataType; /** * 注释 @@ -38,6 +37,16 @@ public class DataSetColumnInputModel { @Check(regex = "^.{0,250}$", messageOnInvalid = "注释太长啦~") private String comment; + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + + if (getDataType() == null) { + throw new StatusCodeWithException("请给字段【" + getName() + "】设置数据类型", StatusCode.PARAMETER_VALUE_INVALID); + } + } + + //region getter/setter public String getName() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetColumnOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetColumnOutputModel.java index 99d3efd2d..5abb8b9df 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetColumnOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetColumnOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,7 +17,8 @@ package com.welab.wefe.board.service.dto.entity.data_set; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.ColumnDataType; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ColumnDataType; import java.util.Map; @@ -26,33 +27,19 @@ */ public class DataSetColumnOutputModel extends AbstractOutputModel { - /** - * 数据集Id - */ + @Check(name = "数据集Id") private String dataSetId; - /** - * 字段序号 - */ + @Check(name = "字段序号") private Integer index; - /** - * 字段名称 - */ + @Check(name = "字段名称") private String name; - /** - * 数据类型 - */ + @Check(name = "数据类型") private ColumnDataType dataType; - /** - * 注释 - */ + @Check(name = "注释") private String comment; - /** - * 空值数据行数 - */ + @Check(name = "空值数据行数") private Long emptyRows; - /** - * 数值分布 - */ + @Check(name = "数值分布") private Map valueDistribution; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetOutputModel.java index 32602def9..737e419f1 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/DataSetOutputModel.java @@ -1,11 +1,11 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,15 @@ package com.welab.wefe.board.service.dto.entity.data_set; +import java.util.List; +import java.util.TreeMap; + import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.DataSetPublicLevel; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.StringUtil; - -import java.util.List; -import java.util.TreeMap; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.DataResourcePublicLevel; /** * @author Zane @@ -93,7 +93,7 @@ public class DataSetOutputModel extends AbstractOutputModel { /** * 数据集的可见性 */ - private DataSetPublicLevel publicLevel; + private DataResourcePublicLevel publicLevel; /** * 使用次数 */ @@ -116,12 +116,12 @@ public class DataSetOutputModel extends AbstractOutputModel { * 来源类型,枚举(原始、对齐、分箱) */ private ComponentType sourceType; - + /** * 来源类型,枚举(原始、对齐、分箱) */ private String sourceTypeCn; - + /** * 来源任务id */ @@ -174,11 +174,11 @@ public void setPublicMemberList(String publicMemberList) throws StatusCodeWithEx //region getter/setter - + public String getName() { return name; } - + public String getSourceTypeCn() { return sourceTypeCn; } @@ -303,11 +303,11 @@ public void setyNameList(String yNameList) { this.yNameList = yNameList; } - public DataSetPublicLevel getPublicLevel() { + public DataResourcePublicLevel getPublicLevel() { return publicLevel; } - public void setPublicLevel(DataSetPublicLevel publicLevel) { + public void setPublicLevel(DataResourcePublicLevel publicLevel) { this.publicLevel = publicLevel; } diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/ImageDataSetSampleOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/ImageDataSetSampleOutputModel.java new file mode 100644 index 000000000..3dcfdf216 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/data_set/ImageDataSetSampleOutputModel.java @@ -0,0 +1,105 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.entity.data_set; + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; + +/** + * @author zane + * @date 2021/11/12 + */ +public class ImageDataSetSampleOutputModel extends AbstractOutputModel { + + @Check(name = "数据集id") + private String dataSetId; + @Check(name = "文件名") + private String fileName; + @Check(name = "文件路径") + private String filePath; + @Check(name = "文件大小") + private long fileSize; + @Check(name = "label") + private String labelList; + @Check(name = "是否已标注") + private boolean labeled; + @Check(name = "json 形式的标注信息") + private JSONObject labelInfo; + + public String getLabelList() { + // 移除前后的逗号,不然前端会报错。 + return StringUtil.trim(labelList, ','); + } + + // region getter/setter + + public String getDataSetId() { + return dataSetId; + } + + public void setDataSetId(String dataSetId) { + this.dataSetId = dataSetId; + } + + public String getFileName() { + return fileName; + } + + public void setFileName(String fileName) { + this.fileName = fileName; + } + + public String getFilePath() { + return filePath; + } + + public void setFilePath(String filePath) { + this.filePath = filePath; + } + + public long getFileSize() { + return fileSize; + } + + public void setFileSize(long fileSize) { + this.fileSize = fileSize; + } + + public void setLabelList(String labelList) { + this.labelList = labelList; + } + + public boolean isLabeled() { + return labeled; + } + + public void setLabeled(boolean labeled) { + this.labeled = labeled; + } + + public JSONObject getLabelInfo() { + return labelInfo; + } + + public void setLabelInfo(JSONObject labelInfo) { + this.labelInfo = labelInfo; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobListOutputModel.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobListOutputModel.java index 590672cdb..19d23dcac 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobListOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobListOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,9 +17,10 @@ package com.welab.wefe.board.service.dto.entity.job; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.JobStatus; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.JobStatus; import java.util.Date; @@ -27,69 +28,37 @@ * @author zane.luo */ public class JobListOutputModel extends AbstractOutputModel { - /** - * 联邦任务类型(横向/纵向) - */ + @Check(name = "联邦任务类型(横向/纵向)") private FederatedLearningType federatedLearningType; - /** - * 项目ID - */ + @Check(name = "项目ID") private String projectId; - /** - * 流程ID - */ + @Check(name = "流程ID") private String flowId; - /** - * 任务ID - */ + @Check(name = "任务ID") private String jobId; - /** - * 名称 - */ + @Check(name = "名称") private String name; - /** - * 我方身份 
枚举(promoter/provider/arbiter) - */ + @Check(name = "我方身份 枚举(promoter/provider/arbiter)") private JobMemberRole myRole; - /** - * 状态 枚举 - */ + @Check(name = "状态 枚举") private JobStatus status; - /** - * 状态更新时间 - */ + @Check(name = "状态更新时间") private Date statusUpdatedTime; - /** - * 开始时间 - */ + @Check(name = "开始时间") private Date startTime; - /** - * 结束时间 - */ + @Check(name = "结束时间") private Date finishTime; - /** - * 进度 - */ + @Check(name = "进度") private Integer progress; - /** - * 进度更新时间 - */ + @Check(name = "进度更新时间") private Date progressUpdatedTime; - /** - * 消息备注 失败原因/备注 - */ + @Check(name = "消息备注 失败原因/备注") private String message; - /** - * 是否包含建模结果 - */ + @Check(name = "是否包含建模结果") private Boolean hasModelingResult; - /** - * 收藏/置顶/标记 - */ + @Check(name = "收藏/置顶/标记") private Boolean star; - /** - * 备注 - */ + @Check(name = "备注") private String remark; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobMemberOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobMemberOutputModel.java index 3a80c3d4f..936e672fd 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobMemberOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobMemberOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -18,35 +18,24 @@ import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.JobMemberRole; /** * @author seven.zeng */ public class JobMemberOutputModel extends AbstractOutputModel { - /** - * 项目Id - */ + @Check(name = "项目Id") private String projectId; - /** - * 流程Id - */ + @Check(name = "流程Id") private String flowId; - /** - * 任务Id - */ + @Check(name = "任务Id") private String jobId; - /** - * 在任务中的角色 枚举(promoter/provider/arbiter) - */ + @Check(name = "在任务中的角色 枚举(promoter/provider/arbiter)") private JobMemberRole jobRole; - /** - * 成员 Id - */ + @Check(name = "成员 Id") private String memberId; - /** - * 数据集 Id - */ + @Check(name = "数据集 Id") private String dataSetId; public String getMemberName() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobOutputModel.java index 35c2b9548..af77cf715 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/JobOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -19,9 +19,10 @@ import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.JobStatus; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.JobStatus; import java.util.Date; @@ -29,73 +30,39 @@ * @author zane.luo */ public class JobOutputModel extends AbstractOutputModel { - /** - * 联邦任务类型(横向/纵向) - */ + @Check(name = "联邦任务类型(横向/纵向)") private FederatedLearningType federatedLearningType; - /** - * 项目ID - */ + @Check(name = "项目ID") private String projectId; - /** - * 流程ID - */ + @Check(name = "流程ID") private String flowId; - /** - * 任务ID - */ + @Check(name = "任务ID") private String jobId; - /** - * 名称 - */ + @Check(name = "名称") private String name; - /** - * 我方身份 枚举(promoter/provider/arbiter) - */ + @Check(name = "我方身份 枚举(promoter/provider/arbiter)") private JobMemberRole myRole; - /** - * 状态 枚举 - */ + @Check(name = "状态 枚举") private JobStatus status; - /** - * 状态更新时间 - */ + @Check(name = "状态更新时间") private Date statusUpdatedTime; - /** - * 开始时间 - */ + @Check(name = "开始时间") private Date startTime; - /** - * 结束时间 - */ + @Check(name = "结束时间") private Date finishTime; - /** - * 进度 - */ + @Check(name = "进度") private Integer progress; - /** - * 进度更新时间 - */ + @Check(name = "进度更新时间") private Date progressUpdatedTime; - /** - * 消息备注 失败原因/备注 - */ + @Check(name = "消息备注 失败原因/备注") private String message; - /** - * 有向无环图 - */ + @Check(name = "有向无环图") private JSONObject graph; - /** - * 是否包含建模结果 - */ + @Check(name = "是否包含建模结果") private Boolean hasModelingResult; - /** - * 收藏/置顶/标记 - */ + @Check(name = "收藏/置顶/标记") private Boolean star; - /** - * 备注 - */ + @Check(name = "备注") private String 
remark; public JSONObject getGraph() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/PreviewJobNodeOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/PreviewJobNodeOutputModel.java index 38527680b..79e63add1 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/PreviewJobNodeOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/PreviewJobNodeOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,7 +16,9 @@ package com.welab.wefe.board.service.dto.entity.job; -import com.welab.wefe.common.enums.ComponentType; + +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ComponentType; import java.util.Map; @@ -24,33 +26,19 @@ * @author zane.luo */ public class PreviewJobNodeOutputModel { - /** - * 前端画布中的节点id,由前端生成 - */ + @Check(name = "前端画布中的节点id,由前端生成") private String nodeId; - /** - * 项目ID - */ + @Check(name = "项目ID") private String projectId; - /** - * 父节点 - */ + @Check(name = "父节点") private String parentNodeId; - /** - * 组件类型 - */ + @Check(name = "组件类型") private ComponentType componentType; - /** - * 深度 - */ + @Check(name = "深度") private Integer deep; - /** - * 在任务列表中的序号,如果为 null,表示该节点不会被执行。 - */ + @Check(name = "在任务列表中的序号,如果为 null,表示该节点不会被执行。") private Integer position; - /** - * 是否存在可用的历史缓存结果 - */ + @Check(name = "是否存在可用的历史缓存结果") private Boolean hasCacheResult; public Map input; public Map output; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/ProjectFlowNodeOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/ProjectFlowNodeOutputModel.java index f88d6639a..7a70f4d90 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/ProjectFlowNodeOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/ProjectFlowNodeOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,44 +18,29 @@ import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.ComponentType; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ComponentType; /** * @author zane.luo */ public class ProjectFlowNodeOutputModel extends AbstractOutputModel { - /** - * 是否是起始节点 - */ + @Check(name = "是否是起始节点") private boolean startNode; - /** - * 前端画布中的节点id,由前端生成 - */ + @Check(name = "前端画布中的节点id,由前端生成") private String nodeId; - /** - * 项目ID - */ + @Check(name = "项目ID") private String projectId; - /** - * 流程ID - */ + @Check(name = "流程ID") private String flowId; - /** - * 父节点 - */ + @Check(name = "父节点") private String parentNodeIdList; - /** - * 组件类型 - */ + @Check(name = "组件类型") private ComponentType componentType; - /** - * 组件参数 - */ + @Check(name = "组件参数") private JSONObject params; - /** - * 参数版本号 - */ + @Check(name = "参数版本号") private long paramsVersion; /** @@ -68,6 +53,18 @@ public String getComponentName() { return componentType.getLabel(); } + public JSONObject getParams() { + return params; + } + + public void setParams(String params) { + if (params == null) { + this.params = null; + } else { + this.params = JSONObject.parseObject(params); + } + } + //region getter/setter @@ -111,16 +108,6 @@ public void setComponentType(ComponentType componentType) { this.componentType = componentType; } - public JSONObject getParams() { - return params; - } - - public void setParams(String params) { - if (params != null) { - this.params = JSONObject.parseObject(params); - } - } - public 
long getParamsVersion() { return paramsVersion; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/RelateDataSetOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/RelateDataSetOutputModel.java index f7d30b046..fa41434c0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/RelateDataSetOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/RelateDataSetOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,60 +16,40 @@ package com.welab.wefe.board.service.dto.entity.job; -import com.welab.wefe.common.enums.JobMemberRole; + +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.dto.AbstractApiOutput; +import com.welab.wefe.common.wefe.enums.JobMemberRole; /** * @author lonnie */ public class RelateDataSetOutputModel extends AbstractApiOutput { - /** - * 任务 Id - */ + @Check(name = "任务 Id") private String businessId; - /** - * 成员 Id - */ + @Check(name = "成员 Id") private String memberId; - /** - * 成员名称 - */ + @Check(name = "成员名称") private String memberName; - /** - * 在任务中的角色;枚举(promoter/provider/arbiter) - */ + @Check(name = "在任务中的角色;枚举(promoter/provider/arbiter)") private JobMemberRole jobRole; - /** - * 数据集名称 - */ + @Check(name = "数据集名称") private String dataSetName; - /** - * 数据量 - */ + @Check(name = "数据量") private Long dataSetRows; - /** - * 特征列 - */ + @Check(name = "特征列") private String featureColumnList; - /** - * 主键列 - */ + @Check(name = "主键列") private String primaryKeyColumn; - /** - * 字段列表 - */ + @Check(name = "字段列表") private String columnNameList; - /** - * 来源数据集id - */ + @Check(name = "来源数据集id") private String sourceJobId; - /** - * 是否包含 Y 值 - */ + @Check(name = "是否包含 Y 值") private Boolean containsY; public String getBusinessId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskOutputModel.java index d5e6ab911..543dd49eb 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,8 +19,9 @@ import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.TaskStatus; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.TaskStatus; import java.util.Date; @@ -28,65 +29,35 @@ * @author seven.zeng */ public class TaskOutputModel extends AbstractOutputModel { - /** - * 名称 - */ + @Check(name = "名称") private String name; - /** - * 任务Id - */ + @Check(name = "任务Id") private String jobId; - /** - * 业务ID,多方唯一 - */ + @Check(name = "业务ID,多方唯一") private String taskId; - /** - * 流程号 - */ + @Check(name = "流程号") private String flowId; - /** - * 任务在流程中的节点Id - */ + @Check(name = "任务在流程中的节点Id") private String flowNodeId; - /** - * 子任务的父节点 - */ + @Check(name = "子任务的父节点") private String parentTaskIdList; - /** - * 子任务依赖 - */ + @Check(name = "子任务依赖") private String dependenceList; - /** - * 子任务类型;枚举(DataIO/Intersection/HeteroLR...) 
- */ + @Check(name = "子任务类型;枚举(DataIO/Intersection/HeteroLR...)") private ComponentType taskType; - /** - * 任务conf_json - */ + @Check(name = "任务conf_json") private JSONObject taskConf; - /** - * 状态;枚举(created/running/canceled/success/error) - */ + @Check(name = "状态;枚举(created/running/canceled/success/error)") private TaskStatus status; - /** - * 开始时间 - */ + @Check(name = "开始时间") private Date startTime; - /** - * 结束时间 - */ + @Check(name = "结束时间") private Date finishTime; - /** - * 消息备注;失败原因/备注 - */ + @Check(name = "消息备注;失败原因/备注") private String message; - /** - * 发生错误的详细原因,通常是堆栈信息。 - */ + @Check(name = "发生错误的详细原因,通常是堆栈信息。") private String errorCause; - /** - * task执行顺序 - */ + @Check(name = "task执行顺序") private Integer position; private Integer spend; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskOutputView.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskOutputView.java index aeb7d7b90..526a6e2ca 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskOutputView.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskOutputView.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -18,7 +18,8 @@ import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.board.service.util.ModelMapper; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.web.util.ModelMapper; import java.util.List; @@ -26,13 +27,9 @@ * @author zane.luo */ public class TaskOutputView extends AbstractOutputModel { - /** - * 由组件创建的 task - */ + @Check(name = "由组件创建的 task") private TaskOutputModel task; - /** - * task 输出的结果 - */ + @Check(name = "task 输出的结果") private List results; public TaskOutputView() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskProgressOuputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskProgressOuputModel.java index df337ab9d..63cd5c03e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskProgressOuputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskProgressOuputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,8 +17,9 @@ package com.welab.wefe.board.service.dto.entity.job; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import javax.persistence.EnumType; import javax.persistence.Enumerated; @@ -29,19 +30,13 @@ */ public class TaskProgressOuputModel extends AbstractOutputModel { - /** - * 项目id - */ + @Check(name = "项目id") private String projectId; - /** - * 流程号 - */ + @Check(name = "流程号") private String flowId; - /** - * 任务id - */ + @Check(name = "任务id") private String jobId; /** @@ -50,14 +45,10 @@ public class TaskProgressOuputModel extends AbstractOutputModel { @Enumerated(EnumType.STRING) private JobMemberRole role; - /** - * 流程节点id - */ + @Check(name = "流程节点id") private String flowNodeId; - /** - * 任务id - */ + @Check(name = "任务id") private String taskId; /** @@ -66,34 +57,22 @@ public class TaskProgressOuputModel extends AbstractOutputModel { @Enumerated(EnumType.STRING) private ComponentType taskType; - /** - * 预计总工程量 - */ + @Check(name = "预计总工程量") private int expectWorkAmount; - /** - * 实际总工程量 - */ + @Check(name = "实际总工程量") private int reallyWorkAmount; - /** - * 进度 - */ + @Check(name = "进度") private int progress; - /** - * 进度百分比 - */ + @Check(name = "进度百分比") private double progressRate; - /** - * updated_time - created_time,毫秒。 - */ + @Check(name = "updated_time - created_time,毫秒。") private int spend; - /** - * 预计结束时间 - */ + @Check(name = "预计结束时间") private Date expectEndTime; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskResultOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskResultOutputModel.java index 662965d9e..e26aa9aef 100644 --- 
a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskResultOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/TaskResultOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -22,84 +22,53 @@ import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskStatus; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskStatus; /** * @author zane.luo */ public class TaskResultOutputModel extends AbstractOutputModel { - /** - * 任务Id - */ + @Check(name = "任务Id") private String jobId; - /** - * 流程Id - */ + @Check(name = "流程Id") private String flowId; - /** - * 流程节点Id - */ + @Check(name = "流程节点Id") private String flowNodeId; - /** - * 子任务Id - */ + @Check(name = "子任务Id") private String taskId; - /** - * 任务名称,例如:vert_lr_0 - */ + @Check(name = "任务名称,例如:vert_lr_0") private String name; - /** - * 组件id - */ + @Check(name = "组件id") private ComponentType componentType; - /** - * 成员角色 - */ + @Check(name = "成员角色") private JobMemberRole role; - 
/** - * 类型,一个 task 会有多行不同类型的 result - */ + @Check(name = "类型,一个 task 会有多行不同类型的 result") private String type; - /** - * 执行结果 - */ + @Check(name = "执行结果") private JSONObject result; - /** - * 是否是可以导出到 serving 的模型 - */ + @Check(name = "是否是可以导出到 serving 的模型") private boolean servingModel; - /** - * task的状态 - */ + @Check(name = "task的状态") private TaskStatus status; - /** - * 开始时间 - */ + @Check(name = "开始时间") private Date startTime; - /** - * 结束时间 - */ + @Check(name = "结束时间") private Date finishTime; - /** - * 消息备注;失败原因/备注 - */ + @Check(name = "消息备注;失败原因/备注") private String message; - /** - * 发生错误的详细原因,通常是堆栈信息。 - */ + @Check(name = "发生错误的详细原因,通常是堆栈信息。") private String errorCause; - /** - * task执行顺序 - */ + @Check(name = "task执行顺序") private Integer position; private Integer spend; /** * 参与方 - * */ + */ private List members; public JSONObject getResult() { @@ -247,6 +216,6 @@ public List getMembers() { public void setMembers(List members) { this.members = members; } - + //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/AbstractJobForGatewayModelingConfigOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/AbstractJobForGatewayModelingConfigOutputModel.java deleted file mode 100644 index a3bb1b284..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/AbstractJobForGatewayModelingConfigOutputModel.java +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.dto.entity.job.gateway; - -import com.welab.wefe.common.enums.FederatedLearningType; - -import javax.persistence.MappedSuperclass; - -/** - * @author seven.zeng - */ -@MappedSuperclass -public class AbstractJobForGatewayModelingConfigOutputModel { - - /** - * 联邦学习模式 - */ - private FederatedLearningType flType; - - //region getter/setter - - public FederatedLearningType getFlType() { - return flType; - } - - public void setFlType(FederatedLearningType flType) { - this.flType = flType; - } - - - //endregion - -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/FlowInfoForGatewayOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/FlowInfoForGatewayOutputModel.java deleted file mode 100644 index 4da1f32d9..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/FlowInfoForGatewayOutputModel.java +++ /dev/null @@ -1,134 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.dto.entity.job.gateway; - -import com.welab.wefe.common.enums.JobMemberRole; - -import javax.persistence.EnumType; -import javax.persistence.Enumerated; - -/** - * @author zane.luo - */ -public class FlowInfoForGatewayOutputModel { - - /** - * 项目id 非主键 - */ - private String projectId; - - /** - * 流程id 非主键 - */ - private String flowId; - - /** - * 流程名称 - */ - private String flowName; - - /** - * 流程描述 - */ - private String flowDesc; - - /** - * 我方身份;枚举(promoter/provider) - */ - @Enumerated(EnumType.STRING) - private JobMemberRole myRole; - - /** - * 流程图默认配置 - */ - private String defaultConfig; - - /** - * 流程图边 - */ - private String edges; - - /** - * 分组列表 - */ - private String combos; - - public String getProjectId() { - return projectId; - } - - public void setProjectId(String projectId) { - this.projectId = projectId; - } - - public String getFlowId() { - return flowId; - } - - public void setFlowId(String flowId) { - this.flowId = flowId; - } - - public String getFlowName() { - return flowName; - } - - public void setFlowName(String flowName) { - this.flowName = flowName; - } - - public String getFlowDesc() { - return flowDesc; - } - - public void setFlowDesc(String flowDesc) { - this.flowDesc = flowDesc; - } - - public JobMemberRole getMyRole() { - return myRole; - } - - public void setMyRole(JobMemberRole myRole) { - this.myRole = myRole; - } - - public String getDefaultConfig() { - return defaultConfig; - } - - public void setDefaultConfig(String defaultConfig) { - this.defaultConfig = defaultConfig; - } - - public String getEdges() { - return edges; - } - - public void setEdges(String edges) { - this.edges = edges; - } - - public String getCombos() { - return combos; - } - - public void setCombos(String combos) { - this.combos = combos; - } - -} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/FlowNodeInputForGatewayOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/FlowNodeInputForGatewayOutputModel.java deleted file mode 100644 index d46b289bd..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/FlowNodeInputForGatewayOutputModel.java +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.dto.entity.job.gateway; - -/** - * @author zane.luo - */ -public class FlowNodeInputForGatewayOutputModel { - - /** - * 节点ID - */ - private String nodeId; - - /** - * 组件ID - */ - private String componentId; - - /** - * 入参 - */ - private String input; - - public String getNodeId() { - return nodeId; - } - - public void setNodeId(String nodeId) { - this.nodeId = nodeId; - } - - public String getComponentId() { - return componentId; - } - - public void setComponentId(String componentId) { - this.componentId = componentId; - } - - public String getInput() { - return input; - } - - public void setInput(String input) { - this.input = input; - } - -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/JobForGatewayMemberOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/JobForGatewayMemberOutputModel.java deleted file mode 100644 index c461fc502..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/JobForGatewayMemberOutputModel.java +++ /dev/null @@ -1,188 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.dto.entity.job.gateway; - -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.JobMemberRole; - -/** - * @author seven.zeng - */ -public class JobForGatewayMemberOutputModel { - /** - * 成员 Id - */ - private String memberId; - /** - * 成员名称 - */ - private String memberName; - /** - * 在任务中的角色;枚举(promoter/provider/arbiter) - */ - private JobMemberRole jobRole; - /** - * 数据集 Id - */ - private String dataSetId; - /** - * 数据集名称 - */ - private String dataSetName; - /** - * 数据量 - */ - private Long dataSetRows; - /** - * 特征列 - */ - private String featureColumnList; - /** - * 字段列表 - */ - private String columnNameList; - /** - * 主键列 - */ - private String primaryKeyColumn; - /** - * 审核结果;枚举值(adopt/disagree) - */ - private AuditStatus auditResult; - /** - * 审核意见 - */ - private String auditComment; - - /** - * 关联的任务id - */ - private String sourceJobId; - - /** - * 是否包含 Y 值 - */ - private Boolean containsY; - - // region getting/setting - - public String getMemberId() { - return memberId; - } - - public void setMemberId(String memberId) { - this.memberId = memberId; - } - - public String getMemberName() { - return memberName; - } - - public void setMemberName(String memberName) { - this.memberName = memberName; - } - - public JobMemberRole getJobRole() { - return jobRole; - } - - public void setJobRole(JobMemberRole jobRole) { - this.jobRole = jobRole; - } - - public String getDataSetId() { - return dataSetId; - } - - public void setDataSetId(String dataSetId) { - this.dataSetId = dataSetId; - } - - public String getDataSetName() { - return dataSetName; - } - - public void setDataSetName(String dataSetName) { - this.dataSetName = dataSetName; - } - - public Long getDataSetRows() { - return dataSetRows; - } - - public void setDataSetRows(Long dataSetRows) { - this.dataSetRows = dataSetRows; - } - - public String getFeatureColumnList() { - return featureColumnList; - } - - public void 
setFeatureColumnList(String featureColumnList) { - this.featureColumnList = featureColumnList; - } - - public String getColumnNameList() { - return columnNameList; - } - - public void setColumnNameList(String columnNameList) { - this.columnNameList = columnNameList; - } - - public String getPrimaryKeyColumn() { - return primaryKeyColumn; - } - - public void setPrimaryKeyColumn(String primaryKeyColumn) { - this.primaryKeyColumn = primaryKeyColumn; - } - - public AuditStatus getAuditResult() { - return auditResult; - } - - public void setAuditResult(AuditStatus auditResult) { - this.auditResult = auditResult; - } - - public String getAuditComment() { - return auditComment; - } - - public void setAuditComment(String auditComment) { - this.auditComment = auditComment; - } - - public String getSourceJobId() { - return sourceJobId; - } - - public void setSourceJobId(String sourceJobId) { - this.sourceJobId = sourceJobId; - } - - public Boolean getContainsY() { - return containsY; - } - - public void setContainsY(Boolean containsY) { - this.containsY = containsY; - } - - // endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayDataSetOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayDataSetOutputModel.java index 9abd3800d..9ee939b98 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayDataSetOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayDataSetOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,7 +16,9 @@ package com.welab.wefe.board.service.dto.entity.job.gateway; -import com.welab.wefe.common.enums.AuditStatus; + +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.AuditStatus; import javax.persistence.EnumType; import javax.persistence.Enumerated; @@ -27,46 +29,28 @@ * @author zane.luo */ public class ProjectForGatewayDataSetOutputModel { - /** - * 成员id - */ + @Check(name = "成员id") private String memberId; - /** - * 成员id - */ + @Check(name = "成员id") private String memberName; - /** - * 数据集 Id - */ + @Check(name = "数据集 Id") private String dataSetId; - /** - * 数据集名称 - */ + @Check(name = "数据集名称") private String dataSetName; - /** - * 数据量 - */ + @Check(name = "数据量") private Long dataSetRows; - /** - * 关键字 - */ + @Check(name = "关键字") private String dataSetKeys; - /** - * 数据集列数 - */ + @Check(name = "数据集列数") private long dataSetColumnNum; - /** - * 是否包含 Y 值 - */ + @Check(name = "是否包含 Y 值") private boolean containsY; - /** - * 特征列 - */ + @Check(name = "特征列") private String featureColumnList; /** @@ -74,9 +58,7 @@ public class ProjectForGatewayDataSetOutputModel { */ @Enumerated(EnumType.STRING) private AuditStatus dataSetStatus; - /** - * 状态更新时间 - */ + @Check(name = "状态更新时间") private Date statusUpdatedTime; public String getMemberId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayMemberOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayMemberOutputModel.java index 6fe64a9da..84948aef8 100644 --- 
a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayMemberOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayMemberOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,8 +16,10 @@ package com.welab.wefe.board.service.dto.entity.job.gateway; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.JobMemberRole; + +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import javax.persistence.EnumType; import javax.persistence.Enumerated; @@ -27,13 +29,9 @@ */ public class ProjectForGatewayMemberOutputModel { - /** - * 成员 Id - */ + @Check(name = "成员 Id") private String memberId; - /** - * 成员名称 - */ + @Check(name = "成员名称") private String memberName; /** @@ -47,9 +45,7 @@ public class ProjectForGatewayMemberOutputModel { */ @Enumerated(EnumType.STRING) private AuditStatus auditResult; - /** - * 审核意见 - */ + @Check(name = "审核意见") private String auditComment; public String getMemberId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayOutputModel.java deleted file mode 
100644 index f55163e8b..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/job/gateway/ProjectForGatewayOutputModel.java +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.dto.entity.job.gateway; - -import com.welab.wefe.common.enums.JobMemberRole; - -import java.util.List; - -/** - * @author zane.luo - */ -public class ProjectForGatewayOutputModel { - - /** - * 项目ID,非主键 - */ - private String projectId; - /** - * 名称 - */ - private String name; - - /** - * 描述 - */ - private String projectDesc; - - /** - * 合作方 - */ - private List memberList; - - /** - * 合作方数据集 - */ - private List dataSetList; - - /** - * 角色 - */ - private JobMemberRole myRole; - - public String getProjectId() { - return projectId; - } - - public void setProjectId(String projectId) { - this.projectId = projectId; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public String getProjectDesc() { - return projectDesc; - } - - public void setProjectDesc(String projectDesc) { - this.projectDesc = projectDesc; - } - - public List getMemberList() { - return memberList; - } - - public void setMemberList(List memberList) { - this.memberList = memberList; - } - - public JobMemberRole getMyRole() { - return myRole; - } - - public void setMyRole(JobMemberRole 
myRole) { - this.myRole = myRole; - } - - public List getDataSetList() { - return dataSetList; - } - - public void setDataSetList(List dataSetList) { - this.dataSetList = dataSetList; - } - -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/AbstractModelingConfigOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/AbstractModelingConfigOutputModel.java index bca3dbaed..a22487e1b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/AbstractModelingConfigOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/AbstractModelingConfigOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,25 +17,20 @@ package com.welab.wefe.board.service.dto.entity.modeling_config; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.FederatedLearningType; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; /** * @author zane.luo */ public class AbstractModelingConfigOutputModel extends AbstractOutputModel { - /** - * 配置名称 - */ + @Check(name = "配置名称") private String name; - /** - * 联邦学习模式 - */ + @Check(name = "联邦学习模式") private FederatedLearningType flType; - /** - * 是否已删除 - */ + @Check(name = "是否已删除") private Boolean deleted = false; //region getter/setter diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingConfigLogisticRegressionOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingConfigLogisticRegressionOutputModel.java index 419cce0a8..0a889725f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingConfigLogisticRegressionOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingConfigLogisticRegressionOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,87 +16,51 @@ package com.welab.wefe.board.service.dto.entity.modeling_config; +import com.welab.wefe.common.fieldvalidate.annotation.Check; + /** * 模型配置·逻辑回归 * * @author Zane */ public class ModelingConfigLogisticRegressionOutputModel extends AbstractModelingConfigOutputModel { - /** - * 模型初始化方式 - */ + @Check(name = "模型初始化方式") private String initParam_InitMethod; - /** - * 是否需要偏置系数 - */ + @Check(name = "是否需要偏置系数") private Boolean initParam_FitIntercept; - /** - * 惩罚方式 - */ + @Check(name = "惩罚方式") private String penalty; - /** - * 收敛容忍度 - */ + @Check(name = "收敛容忍度") private Double tol; - /** - * 惩罚项系数 - */ + @Check(name = "惩罚项系数") private Double alpha; - /** - * 优化算法 - */ + @Check(name = "优化算法") private String optimizer; - /** - * 批量大小 - */ + @Check(name = "批量大小") private Integer batchSize; - /** - * 学习率 - */ + @Check(name = "学习率") private Double learningRate; - /** - * 最大迭代次数 - */ + @Check(name = "最大迭代次数") private Integer maxIter; - /** - * 判断收敛性与否的方法 - */ + @Check(name = "判断收敛性与否的方法") private String earlyStop; - /** - * 同态加密方法 - */ + @Check(name = "同态加密方法") private String encryptParam_Method; - /** - * 在KFold中使用分割符大小 - */ + @Check(name = "在KFold中使用分割符大小") private Integer cvParam_NSplits; - /** - * 在KFold之前是否进行洗牌 - */ + @Check(name = "在KFold之前是否进行洗牌") private Boolean cvParam_Shuffle; - /** - * 是否需要进行此模块 - */ + @Check(name = "是否需要进行此模块") private Boolean cvParam_NeedCv; - /** - * 学习速率的衰减率 - */ + @Check(name = "学习速率的衰减率") private Double decay; - /** - * 衰减率是否开平方 - */ + @Check(name = "衰减率是否开平方") private Boolean decaySqrt; - /** - * 多分类策略;枚举(ovr/ovo) - */ + @Check(name = "多分类策略;枚举(ovr/ovo)") private String multiClass; - /** - * 验证频次 - */ + @Check(name = "验证频次") private Integer validationFreqs; - /** - * 提前结束的迭代次数 - */ + @Check(name = "提前结束的迭代次数") private Integer earlyStoppingRounds; //region getter/setter diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingConfigXGBoostOutputModel.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingConfigXGBoostOutputModel.java deleted file mode 100644 index 7e96138de..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingConfigXGBoostOutputModel.java +++ /dev/null @@ -1,293 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.dto.entity.modeling_config; - -/** - * 模型配置·逻辑回归 - * - * @author Zane - */ -public class ModelingConfigXGBoostOutputModel extends AbstractModelingConfigOutputModel { - /** - * 任务类型;枚举(分类/回归) - */ - private String taskType; - /** - * 学习率 - */ - private Double learningRate; - /** - * 最大树数量 - */ - private Integer numTrees; - /** - * 特征随机采样比率 - */ - private Double subsampleFeatureRate; - /** - * n次迭代没变化是否停止 - */ - private Boolean nIterNoChange; - /** - * 收敛阀值 - */ - private Double tol; - /** - * 最大桶数量 - */ - private Integer binNum; - /** - * 标准函数;默认xgboost - */ - private String treeParam_CriterionMethod; - /** - * 标准参数 - */ - private String treeParam_CriterionParams; - /** - * 树的最大深度 - */ - private Integer treeParam_MaxDepth; - /** - * 分裂一个内部节点(非叶子节点)需要的最小样本;默认2 - */ - private Integer treeParam_MinSampleSplit; - /** - * 每个叶子节点包含的最小样本数 - */ - private Integer treeParam_MinLeafNode; - /** - * 单个拆分的要达到的最小增益 - */ - private Double treeParam_MinImpuritySplit; - /** - * 
可拆分的最大并行数量 - */ - private Integer treeParam_MaxSplitNodes; - /** - * 目标函数 - */ - private String objectiveParam_Objective; - /** - * 学习目标参数 - */ - private String objectiveParam_Params; - /** - * 加密算法 - */ - private String encryptParam_Method; - - /** - * 在KFold中使用分割符大小 - */ - private Integer cvParam_NSplits; - /** - * 在KFold之前是否进行洗牌 - */ - private Boolean cvParam_Shuffle; - /** - * 是否需要进行此模块 - */ - private Boolean cvParam_NeedCv; - /** - * 验证频次 - */ - private Integer validationFreqs; - /** - * 提前结束的迭代次数 - */ - private Integer earlyStoppingRounds; - - //region getter/setter - - public String getTaskType() { - return taskType; - } - - public void setTaskType(String taskType) { - this.taskType = taskType; - } - - public Double getLearningRate() { - return learningRate; - } - - public void setLearningRate(Double learningRate) { - this.learningRate = learningRate; - } - - public Integer getNumTrees() { - return numTrees; - } - - public void setNumTrees(Integer numTrees) { - this.numTrees = numTrees; - } - - public Double getSubsampleFeatureRate() { - return subsampleFeatureRate; - } - - public void setSubsampleFeatureRate(Double subsampleFeatureRate) { - this.subsampleFeatureRate = subsampleFeatureRate; - } - - public Boolean getnIterNoChange() { - return nIterNoChange; - } - - public void setnIterNoChange(Boolean nIterNoChange) { - this.nIterNoChange = nIterNoChange; - } - - public Double getTol() { - return tol; - } - - public void setTol(Double tol) { - this.tol = tol; - } - - public Integer getBinNum() { - return binNum; - } - - public void setBinNum(Integer binNum) { - this.binNum = binNum; - } - - public String getTreeParam_CriterionMethod() { - return treeParam_CriterionMethod; - } - - public void setTreeParam_CriterionMethod(String treeParam_CriterionMethod) { - this.treeParam_CriterionMethod = treeParam_CriterionMethod; - } - - public String getTreeParam_CriterionParams() { - return treeParam_CriterionParams; - } - - public void 
setTreeParam_CriterionParams(String treeParam_CriterionParams) { - this.treeParam_CriterionParams = treeParam_CriterionParams; - } - - public Integer getTreeParam_MaxDepth() { - return treeParam_MaxDepth; - } - - public void setTreeParam_MaxDepth(Integer treeParam_MaxDepth) { - this.treeParam_MaxDepth = treeParam_MaxDepth; - } - - public Integer getTreeParam_MinSampleSplit() { - return treeParam_MinSampleSplit; - } - - public void setTreeParam_MinSampleSplit(Integer treeParam_MinSampleSplit) { - this.treeParam_MinSampleSplit = treeParam_MinSampleSplit; - } - - public Integer getTreeParam_MinLeafNode() { - return treeParam_MinLeafNode; - } - - public void setTreeParam_MinLeafNode(Integer treeParam_MinLeafNode) { - this.treeParam_MinLeafNode = treeParam_MinLeafNode; - } - - public Double getTreeParam_MinImpuritySplit() { - return treeParam_MinImpuritySplit; - } - - public void setTreeParam_MinImpuritySplit(Double treeParam_MinImpuritySplit) { - this.treeParam_MinImpuritySplit = treeParam_MinImpuritySplit; - } - - public Integer getTreeParam_MaxSplitNodes() { - return treeParam_MaxSplitNodes; - } - - public void setTreeParam_MaxSplitNodes(Integer treeParam_MaxSplitNodes) { - this.treeParam_MaxSplitNodes = treeParam_MaxSplitNodes; - } - - public String getObjectiveParam_Objective() { - return objectiveParam_Objective; - } - - public void setObjectiveParam_Objective(String objectiveParam_Objective) { - this.objectiveParam_Objective = objectiveParam_Objective; - } - - public String getObjectiveParam_Params() { - return objectiveParam_Params; - } - - public void setObjectiveParam_Params(String objectiveParam_Params) { - this.objectiveParam_Params = objectiveParam_Params; - } - - public String getEncryptParam_Method() { - return encryptParam_Method; - } - - public void setEncryptParam_Method(String encryptParam_Method) { - this.encryptParam_Method = encryptParam_Method; - } - - public Integer getCvParam_NSplits() { - return cvParam_NSplits; - } - - public void 
setCvParam_NSplits(Integer cvParam_NSplits) { - this.cvParam_NSplits = cvParam_NSplits; - } - - public Boolean getCvParam_Shuffle() { - return cvParam_Shuffle; - } - - public void setCvParam_Shuffle(Boolean cvParam_Shuffle) { - this.cvParam_Shuffle = cvParam_Shuffle; - } - - public Boolean getCvParam_NeedCv() { - return cvParam_NeedCv; - } - - public void setCvParam_NeedCv(Boolean cvParam_NeedCv) { - this.cvParam_NeedCv = cvParam_NeedCv; - } - - public Integer getValidationFreqs() { - return validationFreqs; - } - - public void setValidationFreqs(Integer validationFreqs) { - this.validationFreqs = validationFreqs; - } - - public Integer getEarlyStoppingRounds() { - return earlyStoppingRounds; - } - - public void setEarlyStoppingRounds(Integer earlyStoppingRounds) { - this.earlyStoppingRounds = earlyStoppingRounds; - } -//endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingInfoOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingInfoOutputModel.java index bbd4ab968..77a80a2e8 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingInfoOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/modeling_config/ModelingInfoOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -18,61 +18,38 @@ import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; /** * @author lonnie */ public class ModelingInfoOutputModel extends AbstractOutputModel { - /** - * 任务Id - */ + @Check(name = "任务Id") private String jobId; - /** - * 流程Id - */ + @Check(name = "流程Id") private String flowId; - /** - * 流程节点Id - */ + @Check(name = "流程节点Id") private String flowNodeId; - /** - * 子任务Id - */ + @Check(name = "子任务Id") private String taskId; - /** - * 流程名称 - */ + @Check(name = "流程名称") private String flowName; - /** - * 任务名称,例如:vert_lr_0 - */ + @Check(name = "任务名称,例如:vert_lr_0") private String name; - /** - * 组件类型 - */ + @Check(name = "组件类型") private ComponentType componentType; - /** - * 组件类型中文名 - */ + @Check(name = "组件类型中文名") private String componentName; - /** - * 成员角色 - */ + @Check(name = "成员角色") private JobMemberRole role; - /** - * 类型,一个 task 会有多行不同类型的 result - */ + @Check(name = "类型,一个 task 会有多行不同类型的 result") private String type; - /** - * 执行结果 - */ + @Check(name = "执行结果") private JSONObject result; - /** - * 是否是可以导出到 serving 的模型 - */ + @Check(name = "是否是可以导出到 serving 的模型") private boolean servingModel; public String getJobId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectDataSetOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectDataSetOutputModel.java deleted file mode 100644 index 8a7bc9e69..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectDataSetOutputModel.java +++ /dev/null @@ -1,151 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.dto.entity.project; - -import com.welab.wefe.board.service.dto.entity.data_set.DataSetOutputModel; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.JobMemberRole; - -import java.util.Date; - -/** - * @author zane.luo - */ -public class ProjectDataSetOutputModel extends DataSetOutputModel { - - /** - * 项目 Id 项目主键 - */ - private String projectId; - /** - * 成员id - */ - private String memberId; - /** - * 成员角色 - *

- * 由于存在自己和自己建模的情况,所以需要用角色区分数据集归属。 - */ - private JobMemberRole memberRole; - /** - * 数据集 Id - */ - private String dataSetId; - /** - * 状态 - */ - private AuditStatus auditStatus; - /** - * 审核意见 - */ - private String auditComment; - /** - * 状态更新时间 - */ - private Date statusUpdatedTime; - - /** - * 是否包含 Y 值 - */ - private Boolean containsY; - - /** - * 数据集是否已删除 - */ - private boolean deleted; - - - //region getter/setter - - - public JobMemberRole getMemberRole() { - return memberRole; - } - - public void setMemberRole(JobMemberRole memberRole) { - this.memberRole = memberRole; - } - - public boolean isDeleted() { - return deleted; - } - - public void setDeleted(boolean deleted) { - this.deleted = deleted; - } - - public String getAuditComment() { - return auditComment; - } - - public void setAuditComment(String auditComment) { - this.auditComment = auditComment; - } - - public String getProjectId() { - return projectId; - } - - public void setProjectId(String projectId) { - this.projectId = projectId; - } - - public String getMemberId() { - return memberId; - } - - public void setMemberId(String memberId) { - this.memberId = memberId; - } - - public String getDataSetId() { - return dataSetId; - } - - public void setDataSetId(String dataSetId) { - this.dataSetId = dataSetId; - } - - public AuditStatus getAuditStatus() { - return auditStatus; - } - - public void setAuditStatus(AuditStatus auditStatus) { - this.auditStatus = auditStatus; - } - - public Date getStatusUpdatedTime() { - return statusUpdatedTime; - } - - public void setStatusUpdatedTime(Date statusUpdatedTime) { - this.statusUpdatedTime = statusUpdatedTime; - } - - @Override - public Boolean getContainsY() { - return containsY; - } - - @Override - public void setContainsY(Boolean containsY) { - this.containsY = containsY; - } - - //endregion - -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectDetailMemberOutputModel.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectDetailMemberOutputModel.java index 04e2376a1..3e4ebe861 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectDetailMemberOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectDetailMemberOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,22 +16,24 @@ package com.welab.wefe.board.service.dto.entity.project; +import com.welab.wefe.board.service.dto.entity.project.data_set.ProjectDataResourceOutputModel; + import java.util.List; /** * @author zane.luo */ public class ProjectDetailMemberOutputModel extends ProjectMemberOutputModel { - private List dataSetList; + private List dataResourceList; //region getter/setter - public List getDataSetList() { - return dataSetList; + public List getDataResourceList() { + return dataResourceList; } - public void setDataSetList(List dataSetList) { - this.dataSetList = dataSetList; + public void setDataResourceList(List dataResourceList) { + this.dataResourceList = dataResourceList; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowDetailOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowDetailOutputModel.java index d5265b7a7..16405f9b1 100644 --- 
a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowDetailOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowDetailOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,6 +17,7 @@ package com.welab.wefe.board.service.dto.entity.project; import com.welab.wefe.board.service.dto.entity.job.ProjectFlowNodeOutputModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import java.util.List; @@ -30,13 +31,9 @@ public class ProjectFlowDetailOutputModel extends ProjectFlowOutputModel { private boolean isCreator; - /** - * 被oot的任务ID - */ + @Check(name = "被oot的任务ID") private String ootJobId; - /** - * 被oot的模型id - */ + @Check(name = "被oot的模型id") private String ootModelFlowNodeId; public ProjectOutputModel getProject() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowListOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowListOutputModel.java index 893099053..de1fc7a88 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowListOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowListOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,9 +18,11 @@ import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.ProjectFlowStatus; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.DeepLearningJobType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectFlowStatus; import java.util.Date; @@ -28,52 +30,32 @@ * @author zane.luo */ public class ProjectFlowListOutputModel extends AbstractOutputModel { - /** - * 是否已被删除 - */ + @Check(name = "是否已被删除") private Boolean deleted; - /** - * 联邦任务类型(横向/纵向) - */ + @Check(name = "联邦任务类型(横向/纵向)") private FederatedLearningType federatedLearningType; - /** - * 项目ID - */ + @Check(name = "深度学习任务类型(目标检测、图像分类)") + private DeepLearningJobType deepLearningJobType; + @Check(name = "项目ID") private String projectId; - /** - * 流程ID - */ + @Check(name = "流程ID") private String flowId; - /** - * 流程名称 - */ + @Check(name = "流程名称") private String flowName; - /** - * 流程描述 - */ + @Check(name = "流程描述") private String flowDesc; - /** - * 流程的状态 - */ + @Check(name = "流程的状态") private ProjectFlowStatus flowStatus; private Date statusUpdatedTime; private String message; - /** - * 我方角色 - */ + @Check(name = "我方角色") private 
JobMemberRole myRole; - /** - * 是否是创建者 - */ + @Check(name = "是否是创建者") private boolean isCreator; - /** - * 任务进度 - */ + @Check(name = "任务进度") private Integer jobProgress; - /** - * 创建此流程的成员的ID - */ + @Check(name = "创建此流程的成员的ID") private String creatorMemberId; public String getCreatorMemberName() { @@ -99,6 +81,14 @@ public void setFederatedLearningType(FederatedLearningType federatedLearningType this.federatedLearningType = federatedLearningType; } + public DeepLearningJobType getDeepLearningJobType() { + return deepLearningJobType; + } + + public void setDeepLearningJobType(DeepLearningJobType deepLearningJobType) { + this.deepLearningJobType = deepLearningJobType; + } + public String getMessage() { return message; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowOutputModel.java index a7a12d4ed..01713f3a5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -20,9 +20,11 @@ import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.ProjectFlowStatus; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.DeepLearningJobType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectFlowStatus; import java.util.Date; @@ -30,48 +32,30 @@ * @author zane.luo */ public class ProjectFlowOutputModel extends AbstractOutputModel { - /** - * 是否已被删除 - */ + @Check(name = "是否已被删除") private Boolean deleted; - /** - * 联邦任务类型(横向/纵向) - */ + @Check(name = "联邦任务类型(横向/纵向)") private FederatedLearningType federatedLearningType; - /** - * 项目ID - */ + @Check(name = "深度学习任务类型(目标检测、图像分类)") + private DeepLearningJobType deepLearningJobType; + @Check(name = "项目ID") private String projectId; - /** - * 流程ID - */ + @Check(name = "流程ID") private String flowId; - /** - * 流程名称 - */ + @Check(name = "流程名称") private String flowName; - /** - * 流程描述 - */ + @Check(name = "流程描述") private String flowDesc; - /** - * 画布中编辑的图 - */ + @Check(name = "画布中编辑的图") private JSONObject graph; - /** - * 创建此流程的成员的ID - */ + @Check(name = "创建此流程的成员的ID") private String creatorMemberId; - /** - * 流程的状态 - */ + @Check(name = "流程的状态") private ProjectFlowStatus flowStatus; private Date statusUpdatedTime; private String message; - /** - * 我方角色 - */ + @Check(name = "我方角色") private JobMemberRole myRole; private ProjectModelingOutputModel projectModelingOutputModel; @@ -107,6 +91,14 @@ public void setFederatedLearningType(FederatedLearningType federatedLearningType this.federatedLearningType = federatedLearningType; } + public DeepLearningJobType getDeepLearningJobType() { + 
return deepLearningJobType; + } + + public void setDeepLearningJobType(DeepLearningJobType deepLearningJobType) { + this.deepLearningJobType = deepLearningJobType; + } + public String getProjectId() { return projectId; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowProgressOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowProgressOutputModel.java index 070dabbb2..70442be03 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowProgressOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectFlowProgressOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,7 +17,8 @@ package com.welab.wefe.board.service.dto.entity.project; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.ProjectFlowStatus; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ProjectFlowStatus; import java.util.Date; @@ -25,24 +26,15 @@ * @author zane.luo */ public class ProjectFlowProgressOutputModel extends AbstractOutputModel { - /** - * 项目ID - */ + @Check(name = "项目ID") private String projectId; - /** - * 流程ID - */ + @Check(name = "流程ID") private String flowId; - - /** - * 流程的状态 - */ + @Check(name = "流程的状态") private ProjectFlowStatus flowStatus; private Date statusUpdatedTime; private String message; - /** - * 任务进度 - */ + @Check(name = "任务进度") private Integer jobProgress; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectMemberOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectMemberOutputModel.java index cd60ff45f..6a6bd8bfe 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectMemberOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectMemberOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -18,63 +18,42 @@ import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.JobMemberRole; /** * @author zane.luo */ public class ProjectMemberOutputModel extends AbstractOutputModel { - /** - * 邀请方成员Id - */ + @Check(name = "邀请方成员Id") private String inviterId; - /** - * 邀请方成员名称 - */ + @Check(name = "邀请方成员名称") private String inviterName; - /** - * 是否是初始化项目时添加进来的(关系到审核流程不同) - */ + @Check(name = "是否是初始化项目时添加进来的(关系到审核流程不同)") private boolean fromCreateProject; - /** - * 所属项目 Id 项目主键 - */ + @Check(name = "所属项目 Id 项目主键") private String projectId; - /** - * 成员 Id - */ + @Check(name = "成员 Id") private String memberId; - /** - * 在任务中的角色;枚举(promoter/provider/arbiter) - */ + @Check(name = "在任务中的角色;枚举(promoter/provider/arbiter)") private JobMemberRole memberRole; - /** - * 综合的审核结果 - */ + @Check(name = "综合的审核结果") private AuditStatus auditStatus; - /** - * 自己是否同意 - */ + @Check(name = "自己是否同意") private AuditStatus auditStatusFromMyself; - /** - * 其他人是否同意 - */ + @Check(name = "其他人是否同意") private AuditStatus auditStatusFromOthers; - /** - * 审核意见 - */ + @Check(name = "审核意见") private String auditComment; - /** - * 是否已退出 - */ + @Check(name = "是否已退出") private boolean exited = false; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectModelingOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectModelingOutputModel.java index bbdbae0c6..a933754ed 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectModelingOutputModel.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectModelingOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,36 +17,27 @@ package com.welab.wefe.board.service.dto.entity.project; import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; -import com.welab.wefe.common.enums.ComponentType; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ComponentType; /** * @author lonnie */ public class ProjectModelingOutputModel extends AbstractOutputModel { - /** - * 流程id - */ + @Check(name = "流程id") private String flowId; - /** - * job_id - */ + @Check(name = "job_id") private String jobId; - /** - * job名字 - */ + @Check(name = "job名字") private String jobName; - /** - * 模型评估任务id - */ + @Check(name = "模型评估任务id") private String evaluationTaskId; - /** - * 模型类型 - */ + @Check(name = "模型类型") private ComponentType modelingType; public String getFlowId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectOutputModel.java index de26e5e0d..a660f036a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectOutputModel.java @@ 
-1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,9 +18,11 @@ import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectType; import java.util.Date; import java.util.List; @@ -30,85 +32,55 @@ */ public class ProjectOutputModel extends AbstractOutputModel { - /** - * 项目ID - */ + @Check(name = "项目ID") private String projectId; - /** - * 名称 - */ + @Check(name = "名称") private String name; - /** - * 项目描述 - */ + @Check(name = "项目描述") private String projectDesc; private AuditStatus auditStatus; - /** - * 自己是否同意 - */ + @Check(name = "自己是否同意") private AuditStatus auditStatusFromMyself; - /** - * 其他人是否同意 - */ + @Check(name = "其他人是否同意") private AuditStatus auditStatusFromOthers; - /** - * 审核意见 - */ + @Check(name = "审核意见") private String auditComment; - /** - * 我方身份;枚举(promoter/provider) - */ + @Check(name = "我方身份;枚举(promoter/provider)") private JobMemberRole myRole; - /** - * 是否是创建者 - */ + @Check(name = "是否是创建者") private boolean isCreator; - /** - * 我方成员ID - */ + @Check(name = "我方成员ID") private String memberId; - /** - * 
消息备注 失败原因/备注 - */ + @Check(name = "消息备注 失败原因/备注") private String message; private ProjectDetailMemberOutputModel promoter; private List providerList; private List promoterList; - /** - * 退出项目的操作者 - */ + @Check(name = "退出项目的操作者") private String exitedBy; - /** - * 退出时间 - */ + @Check(name = "退出时间") private Date exitedTime; - /** - * 当前成员是否退出了项目 - */ + @Check(name = "当前成员是否退出了项目") private boolean isExited; - /** - * 是否已关闭 - */ + @Check(name = "是否已关闭") private boolean closed = false; - /** - * 关闭项目的操作者 - */ + @Check(name = "关闭项目的操作者") private String closedBy; - /** - * 关闭时间 - */ + @Check(name = "关闭时间") private Date closedTime; private JObject flowStatusStatistics; + @Check(name = "项目类型") + private ProjectType projectType; public String getExitOperatorNickname() { return CacheObjects.getNickname(exitedBy); @@ -286,4 +258,12 @@ public boolean getIsExited() { public void setIsExited(boolean isExited) { this.isExited = isExited; } + + public ProjectType getProjectType() { + return projectType; + } + + public void setProjectType(ProjectType projectType) { + this.projectType = projectType; + } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectQueryOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectQueryOutputModel.java index 472521381..6490524a1 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectQueryOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectQueryOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,9 +18,11 @@ import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectType; import java.util.Date; import java.util.List; @@ -29,102 +31,64 @@ * @author zane.luo */ public class ProjectQueryOutputModel extends AbstractOutputModel { - /** - * 项目ID - */ + @Check(name = "项目ID") private String projectId; - /** - * 名称 - */ + @Check(name = "名称") private String name; - /** - * 项目描述 - */ + @Check(name = "项目描述") private String projectDesc; private AuditStatus auditStatus; - /** - * 我方身份;枚举(promoter/provider) - */ + @Check(name = "我方身份;枚举(promoter/provider)") private JobMemberRole myRole; - /** - * 我方成员ID - */ + @Check(name = "我方成员ID") private String memberId; - /** - * 状态更新时间 - */ + @Check(name = "状态更新时间") private Date statusUpdatedTime; - /** - * 开始时间 - */ + @Check(name = "开始时间") private Date startTime; - /** - * 结束时间 - */ + @Check(name = "结束时间") private Date finishTime; - /** - * 进度 - */ + @Check(name = "进度") private Integer progress; - /** - * 进度更新时间 - */ + @Check(name = "进度更新时间") private Date progressUpdatedTime; - /** - * 消息备注 失败原因/备注 - */ + @Check(name = "消息备注 失败原因/备注") private String message; private List memberList; - /** - * 发起方ID - */ + @Check(name = "发起方ID") private String 
promoter; - /** - * 发起方name - */ + @Check(name = "发起方name") private String promoterName; - /** - * 退出项目的操作者 - */ + @Check(name = "退出项目的操作者") private String exitedBy; - /** - * 退出时间 - */ + @Check(name = "退出时间") private Date exitedTime; - /** - * 是否已关闭 - */ + @Check(name = "是否已关闭") private boolean closed = false; - /** - * 关闭项目的操作者 - */ + @Check(name = "关闭项目的操作者") private String closedBy; - /** - * 关闭时间 - */ + @Check(name = "关闭时间") private Date closedTime; - /** - * 各流程状态的统计 - */ + @Check(name = "各流程状态的统计") private JObject flowStatusStatistics; - /** - * 待审核数据集数量 - */ + @Check(name = "待审核数据集数量") private int needMeAuditDataSetCount; + @Check(name = "项目类型") + private ProjectType projectType; public String getExitOperatorNickname() { @@ -314,6 +278,14 @@ public void setNeedMeAuditDataSetCount(int needMeAuditDataSetCount) { this.needMeAuditDataSetCount = needMeAuditDataSetCount; } + public ProjectType getProjectType() { + return projectType; + } + + public void setProjectType(ProjectType projectType) { + this.projectType = projectType; + } + //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectUsageDetailOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectUsageDetailOutputModel.java index 30cbbd9a8..392ff0707 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectUsageDetailOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/ProjectUsageDetailOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,7 +16,13 @@ package com.welab.wefe.board.service.dto.entity.project; + +import com.welab.wefe.common.fieldvalidate.annotation.Check; import com.welab.wefe.common.web.dto.AbstractApiOutput; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectType; + +import java.util.Date; /** @@ -24,15 +30,24 @@ */ public class ProjectUsageDetailOutputModel extends AbstractApiOutput { - /** - * 项目ID - */ + @Check(name = "项目ID") private String projectId; - - /** - * 名称 - */ + @Check(name = "名称") private String name; + @Check(name = "项目描述") + private String projectDesc; + @Check(name = "我方角色") + private JobMemberRole myRole; + @Check(name = "该项目的创建者ID") + private String memberId; + @Check(name = "开始时间") + private Date startTime; + @Check(name = "结束时间") + private Date finishTime; + @Check(name = "项目类型") + private ProjectType projectType; + + // region getter/setter public String getProjectId() { return projectId; @@ -49,4 +64,55 @@ public String getName() { public void setName(String name) { this.name = name; } + + public String getProjectDesc() { + return projectDesc; + } + + public void setProjectDesc(String projectDesc) { + this.projectDesc = projectDesc; + } + + public JobMemberRole getMyRole() { + return myRole; + } + + public void setMyRole(JobMemberRole myRole) { + this.myRole = myRole; + } + + public String getMemberId() { + return memberId; + } + + public void setMemberId(String memberId) { + this.memberId = memberId; + } + + public Date getStartTime() { + return startTime; + } + + public void setStartTime(Date startTime) { + this.startTime = startTime; + } + 
+ public Date getFinishTime() { + return finishTime; + } + + public void setFinishTime(Date finishTime) { + this.finishTime = finishTime; + } + + public ProjectType getProjectType() { + return projectType; + } + + public void setProjectType(ProjectType projectType) { + this.projectType = projectType; + } + + + // endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/DerivedProjectDataSetOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/data_set/DerivedProjectDataSetOutputModel.java similarity index 88% rename from board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/DerivedProjectDataSetOutputModel.java rename to board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/data_set/DerivedProjectDataSetOutputModel.java index e80471ea6..d75c0f41b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/DerivedProjectDataSetOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/data_set/DerivedProjectDataSetOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.welab.wefe.board.service.dto.entity.project; +package com.welab.wefe.board.service.dto.entity.project.data_set; import com.welab.wefe.board.service.dto.vo.JobMemberWithDataSetOutputModel; @@ -25,7 +25,7 @@ * * @author zane.luo */ -public class DerivedProjectDataSetOutputModel extends ProjectDataSetOutputModel { +public class DerivedProjectDataSetOutputModel extends ProjectDataResourceOutputModel { private List members; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/data_set/ProjectDataResourceOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/data_set/ProjectDataResourceOutputModel.java new file mode 100644 index 000000000..6cb4e065e --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/entity/project/data_set/ProjectDataResourceOutputModel.java @@ -0,0 +1,128 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.dto.entity.project.data_set; + +import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.DataResourceOutputModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; + +import java.util.Date; + +/** + * @author zane.luo + */ +public class ProjectDataResourceOutputModel extends AbstractOutputModel { + + @Check(name = "项目 Id 项目主键") + private String projectId; + @Check(name = "成员id") + private String memberId; + @Check(name = "成员角色", desc = "由于存在自己和自己建模的情况,所以需要用角色区分数据集归属。") + private JobMemberRole memberRole; + @Check(name = "数据集 Id") + private String dataSetId; + @Check(name = "状态") + private AuditStatus auditStatus; + @Check(name = "审核意见") + private String auditComment; + @Check(name = "状态更新时间") + private Date statusUpdatedTime; + @Check(name = "数据集类型") + private DataResourceType dataResourceType; + @Check(name = "数据集详情") + private DataResourceOutputModel dataResource; + + //region getter/setter + + + public JobMemberRole getMemberRole() { + return memberRole; + } + + public void setMemberRole(JobMemberRole memberRole) { + this.memberRole = memberRole; + } + + public String getAuditComment() { + return auditComment; + } + + public void setAuditComment(String auditComment) { + this.auditComment = auditComment; + } + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public String getMemberId() { + return memberId; + } + + public void setMemberId(String memberId) { + this.memberId = memberId; + } + + public String getDataSetId() { + return dataSetId; + } + + public void setDataSetId(String dataSetId) { + this.dataSetId = dataSetId; + } + + public AuditStatus getAuditStatus() { + 
return auditStatus; + } + + public void setAuditStatus(AuditStatus auditStatus) { + this.auditStatus = auditStatus; + } + + public Date getStatusUpdatedTime() { + return statusUpdatedTime; + } + + public void setStatusUpdatedTime(Date statusUpdatedTime) { + this.statusUpdatedTime = statusUpdatedTime; + } + + public DataResourceOutputModel getDataResource() { + return dataResource; + } + + public void setDataResource(DataResourceOutputModel dataResource) { + this.dataResource = dataResource; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + //endregion + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/BloomFilterColumnInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/BloomFilterColumnInputModel.java new file mode 100644 index 000000000..999812d2f --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/BloomFilterColumnInputModel.java @@ -0,0 +1,77 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.dto.fusion; + +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ColumnDataType; + +/** + * @author jacky.jiang + */ +public class BloomFilterColumnInputModel extends AbstractCheckModel { + + @Check(name = "字段名称") + private String name; + @Check(name = "数据类型") + private ColumnDataType dataType; + /** + * 注释 + */ + @Check(regex = "^.{0,250}$", messageOnInvalid = "注释太长啦~") + private String comment; + + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + + if (getDataType() == null) { + throw new StatusCodeWithException("请给字段【" + getName() + "】设置数据类型", StatusCode.PARAMETER_VALUE_INVALID); + } + } + + //region getter/setter + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public ColumnDataType getDataType() { + return dataType; + } + + public void setDataType(ColumnDataType dataType) { + this.dataType = dataType; + } + + public String getComment() { + return comment; + } + + public void setComment(String comment) { + this.comment = comment; + } + + + //endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/BloomFilterColumnOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/BloomFilterColumnOutputModel.java new file mode 100644 index 000000000..37b4d5ca9 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/BloomFilterColumnOutputModel.java @@ -0,0 +1,110 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.dto.fusion; + +import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ColumnDataType; + +import java.util.Map; + +/** + * @author Zane + */ +public class BloomFilterColumnOutputModel extends AbstractOutputModel { + + @Check(name = "过滤器Id") + private String bloomfilterId; + @Check(name = "字段序号") + private Integer index; + @Check(name = "字段名称") + private String name; + @Check(name = "数据类型") + private ColumnDataType dataType; + @Check(name = "注释") + private String comment; + @Check(name = "空值数据行数") + private Long emptyRows; + /** + * 数值分布 + */ + private Map valueDistribution; + + + //region getter/setter + + + public String getBloomfilterId() { + return bloomfilterId; + } + + public void setBloomfilterId(String bloomfilterId) { + this.bloomfilterId = bloomfilterId; + } + + public Integer getIndex() { + return index; + } + + public void setIndex(Integer index) { + this.index = index; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public ColumnDataType getDataType() { + return dataType; + } + + public void setDataType(ColumnDataType dataType) { + this.dataType = dataType; + } + + public String getComment() { + return comment; + } + + public void setComment(String comment) { + this.comment = comment; + } + + public Long getEmptyRows() { + return emptyRows; + } + + public void setEmptyRows(Long emptyRows) { + this.emptyRows = 
emptyRows; + } + + public Map getValueDistribution() { + return valueDistribution; + } + + public void setValueDistribution(Map valueDistribution) { + this.valueDistribution = valueDistribution; + } + + + //endregion + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/BloomFilterTaskOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/BloomFilterTaskOutputModel.java new file mode 100644 index 000000000..9c95f6731 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/BloomFilterTaskOutputModel.java @@ -0,0 +1,112 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.dto.fusion; + +import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; + +/** + * @author jacky.jiang + */ +public class BloomFilterTaskOutputModel extends AbstractOutputModel { + + @Check(name = "过滤器名") + private String bloomFilterName; + @Check(name = "过滤器id") + private String bloomFilterId; + @Check(name = "总数据行数") + private long totalRowCount; + @Check(name = "已写入数据行数") + private long addedRowCount; + @Check(name = "任务进度百分比") + private int progress; + @Check(name = "预计剩余耗时") + private long estimateTime; + @Check(name = "主键重复条数") + private long repeatIdRowCount; + @Check(name = "错误消息") + private String errorMessage; + + // region getter/setter + + + public String getBloomFilterName() { + return bloomFilterName; + } + + public void setBloomFilterName(String bloomFilterName) { + this.bloomFilterName = bloomFilterName; + } + + public String getBloomFilterId() { + return bloomFilterId; + } + + public void setBloomFilterId(String bloomFilterId) { + this.bloomFilterId = bloomFilterId; + } + + public long getTotalRowCount() { + return totalRowCount; + } + + public void setTotalRowCount(long totalRowCount) { + this.totalRowCount = totalRowCount; + } + + public long getAddedRowCount() { + return addedRowCount; + } + + public void setAddedRowCount(long addedRowCount) { + this.addedRowCount = addedRowCount; + } + + public int getProgress() { + return progress; + } + + public void setProgress(int progress) { + this.progress = progress; + } + + public long getEstimateTime() { + return estimateTime; + } + + public void setEstimateTime(long estimateTime) { + this.estimateTime = estimateTime; + } + + public long getRepeatIdRowCount() { + return repeatIdRowCount; + } + + public void setRepeatIdRowCount(long repeatIdRowCount) { + this.repeatIdRowCount = repeatIdRowCount; + } + + public String getErrorMessage() { + return errorMessage; + } + + public void 
setErrorMessage(String errorMessage) { + this.errorMessage = errorMessage; + } + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/FusionMemberInfo.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/FusionMemberInfo.java new file mode 100644 index 000000000..4303e9b67 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/FusionMemberInfo.java @@ -0,0 +1,128 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.fusion; + + + +import com.welab.wefe.board.service.util.primarykey.FieldInfo; +import com.welab.wefe.board.service.util.primarykey.PrimaryKeyUtils; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; + +import java.util.List; + +/** + * @author hunter.zhao + */ +public class FusionMemberInfo { + String memberId; + String memberName; + JobMemberRole role; + + String dataResourceId; + String dataResourceName; + DataResourceType dataResourceType; + Long rowCount; + String hashFunction; + List fieldInfoList; + + String columnNameList; + + + public String getMemberId() { + return memberId; + } + + public void setMemberId(String memberId) { + this.memberId = memberId; + } + + public String getMemberName() { + return memberName; + } + + public void setMemberName(String memberName) { + this.memberName = memberName; + } + + public JobMemberRole getRole() { + return role; + } + + public void setRole(JobMemberRole role) { + this.role = role; + } + + public String getDataResourceId() { + return dataResourceId; + } + + public void setDataResourceId(String dataResourceId) { + this.dataResourceId = dataResourceId; + } + + public String getDataResourceName() { + return dataResourceName; + } + + public void setDataResourceName(String dataResourceName) { + this.dataResourceName = dataResourceName; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public Long getRowCount() { + return rowCount; + } + + public void setRowCount(Long rowCount) { + this.rowCount = rowCount; + } + + public String getHashFunction() { + return hashFunction; + } + + public void setHashFunction(String hashFunction) { + this.hashFunction = hashFunction; + } + + public void setHashFunction(List fieldInfos) { + this.hashFunction = 
PrimaryKeyUtils.hashFunction(fieldInfos); + } + + public String getColumnNameList() { + return columnNameList; + } + + public void setColumnNameList(String columnNameList) { + this.columnNameList = columnNameList; + } + + public List getFieldInfoList() { + return fieldInfoList; + } + + public void setFieldInfoList(List fieldInfoList) { + this.fieldInfoList = fieldInfoList; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/FusionResultExportProgress.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/FusionResultExportProgress.java new file mode 100644 index 000000000..00d99e8eb --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/FusionResultExportProgress.java @@ -0,0 +1,116 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.fusion; + + + +import com.welab.wefe.board.service.fusion.enums.ExportStatus; + +/** + * @author hunter.zhao + */ +public class FusionResultExportProgress { + String businessId; + + String tableName; + + int progress; + + int totalDataCount; + + int processedCount; + + ExportStatus status; + + long finishTime; + + public FusionResultExportProgress() { + } + + public FusionResultExportProgress(String businessId, String tableName, int totalDataCount) { + this.businessId = businessId; + this.totalDataCount = totalDataCount; + this.tableName = tableName; + this.status = ExportStatus.exporting; + } + + public int getProgress() { + return Double.valueOf( + Double.valueOf(processedCount) / Double.valueOf(totalDataCount) * 100 + ).intValue(); + } + + public void setProgress(int progress) { + this.progress = progress; + } + + public int getTotalDataCount() { + return totalDataCount; + } + + public void setTotalDataCount(int totalDataCount) { + this.totalDataCount = totalDataCount; + } + + public int getProcessedCount() { + return processedCount; + } + + public void setProcessedCount(int processedCount) { + this.processedCount = processedCount; + } + + public synchronized void increment() { + processedCount++; + + if (processedCount == totalDataCount) { + this.finishTime = System.currentTimeMillis(); + this.status = ExportStatus.success; + } + } + + public String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public String getTableName() { + return tableName; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + public ExportStatus getStatus() { + return status; + } + + public void setStatus(ExportStatus status) { + this.status = status; + } + + public long getFinishTime() { + return finishTime; + } + + public void setFinishTime(long finishTime) { + this.finishTime = finishTime; + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/FusionTaskOutput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/FusionTaskOutput.java new file mode 100644 index 000000000..0c3285acf --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/FusionTaskOutput.java @@ -0,0 +1,323 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.dto.fusion; + +import com.welab.wefe.board.service.dto.entity.AbstractOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.BloomFilterOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.TableDataSetOutputModel; +import com.welab.wefe.board.service.fusion.enums.ExportStatus; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.fusion.core.enums.AlgorithmType; +import com.welab.wefe.fusion.core.enums.FusionTaskStatus; +import com.welab.wefe.fusion.core.enums.PSIActuatorRole; + +import javax.persistence.Column; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; +import java.util.List; + +; + +/** + * @author hunter.zhao + */ +public class FusionTaskOutput extends AbstractOutputModel { + + private String businessId; + + String name; + + FusionTaskStatus status; + + String error; + + FusionMemberInfo promoter; + + FusionMemberInfo provider; + + JobMemberRole myRole; + + String dstMemberId; + + String dataResourceId; + + String dataResourceName; + + DataResourceType dataResourceType; + + String hashFunction; + + + @Check(name = "Number of rows of data resources") + Long rowCount; + + String partnerDataResourceId; + + String partnerDataResourceName; + + DataResourceType partnerDataResourceType; + + String partnerHashFunction; + + @Check(name = "Number of rows of data resources") + public Long partnerRowCount; + + @Check(name = "Whether the trace") + public boolean isTrace; + + @Check(name = "Traces the field") + public String traceColumn; + + PSIActuatorRole psiActuatorRole; + + AlgorithmType algorithm; + + @Check(name = "Number of fusion") + public int fusionCount; + + public long spend; + + + private String description; + + + public String comment; + + public ExportStatus ExportStatus; + + + public 
String getBusinessId() { + return businessId; + } + + public void setBusinessId(String businessId) { + this.businessId = businessId; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public FusionTaskStatus getStatus() { + return status; + } + + public void setStatus(FusionTaskStatus status) { + this.status = status; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + + public String getDataResourceId() { + return dataResourceId; + } + + public void setDataResourceId(String dataResourceId) { + this.dataResourceId = dataResourceId; + } + + public String getDataResourceName() { + return dataResourceName; + } + + public void setDataResourceName(String dataResourceName) { + this.dataResourceName = dataResourceName; + } + + public DataResourceType getDataResourceType() { + return dataResourceType; + } + + public void setDataResourceType(DataResourceType dataResourceType) { + this.dataResourceType = dataResourceType; + } + + public Long getRowCount() { + return rowCount; + } + + public void setRowCount(Long rowCount) { + this.rowCount = rowCount; + } + + public Long getPartnerRowCount() { + return partnerRowCount; + } + + public void setPartnerRowCount(Long partnerRowCount) { + this.partnerRowCount = partnerRowCount; + } + + public PSIActuatorRole getPsiActuatorRole() { + return psiActuatorRole; + } + + public void setPsiActuatorRole(PSIActuatorRole psiActuatorRole) { + this.psiActuatorRole = psiActuatorRole; + } + + public AlgorithmType getAlgorithm() { + return algorithm; + } + + public void setAlgorithm(AlgorithmType algorithm) { + this.algorithm = algorithm; + } + + + public int getFusionCount() { + return fusionCount; + } + + public void setFusionCount(int fusionCount) { + this.fusionCount = fusionCount; + } + + public long getSpend() { + return spend; + } + + public void setSpend(long spend) { + this.spend = spend; + } + + 
+ public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public String getComment() { + return comment; + } + + public void setComment(String comment) { + this.comment = comment; + } + + public boolean isTrace() { + return isTrace; + } + + public void setTrace(boolean trace) { + isTrace = trace; + } + + public String getTraceColumn() { + return traceColumn; + } + + public void setTraceColumn(String traceColumn) { + this.traceColumn = traceColumn; + } + + public FusionMemberInfo getPromoter() { + return promoter; + } + + public void setPromoter(FusionMemberInfo promoter) { + this.promoter = promoter; + } + + public FusionMemberInfo getProvider() { + return provider; + } + + public void setProvider(FusionMemberInfo provider) { + this.provider = provider; + } + + public JobMemberRole getMyRole() { + return myRole; + } + + public void setMyRole(JobMemberRole myRole) { + this.myRole = myRole; + } + + public String getPartnerDataResourceId() { + return partnerDataResourceId; + } + + public void setPartnerDataResourceId(String partnerDataResourceId) { + this.partnerDataResourceId = partnerDataResourceId; + } + + public String getPartnerDataResourceName() { + return partnerDataResourceName; + } + + public void setPartnerDataResourceName(String partnerDataResourceName) { + this.partnerDataResourceName = partnerDataResourceName; + } + + public DataResourceType getPartnerDataResourceType() { + return partnerDataResourceType; + } + + public void setPartnerDataResourceType(DataResourceType partnerDataResourceType) { + this.partnerDataResourceType = partnerDataResourceType; + } + + public String getDstMemberId() { + return dstMemberId; + } + + public void setDstMemberId(String dstMemberId) { + this.dstMemberId = dstMemberId; + } + + public String getHashFunction() { + return hashFunction; + } + + public void setHashFunction(String hashFunction) { + this.hashFunction = hashFunction; 
+ } + + public String getPartnerHashFunction() { + return partnerHashFunction; + } + + public void setPartnerHashFunction(String partnerHashFunction) { + this.partnerHashFunction = partnerHashFunction; + } + + public com.welab.wefe.board.service.fusion.enums.ExportStatus getExportStatus() { + return ExportStatus; + } + + public void setExportStatus(com.welab.wefe.board.service.fusion.enums.ExportStatus exportStatus) { + ExportStatus = exportStatus; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/PsiMeta.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/PsiMeta.java new file mode 100644 index 000000000..21a3f86bd --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/fusion/PsiMeta.java @@ -0,0 +1,49 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.fusion; + + + + +import com.welab.wefe.common.util.Base64Util; +import org.apache.commons.compress.utils.Lists; + +import java.util.List; + +/** + * @author hunter.zhao + */ +public class PsiMeta { + List bs; + + public List getBs() { + return bs; + } + + public void setBs(List bs) { + this.bs = bs; + } + + public static PsiMeta of(byte[][] bs) { + PsiMeta psiMeta = new PsiMeta(); + List bitStr = Lists.newArrayList(); + for (int i = 0; i < bs.length; i++) { + bitStr.add(Base64Util.encode(bs[i])); + } + psiMeta.bs = bitStr; + return psiMeta; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/AlertConfigModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/AlertConfigModel.java index d3089de1c..585ed7497 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/AlertConfigModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/AlertConfigModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/BoardConfigModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/BoardConfigModel.java index 298919cdb..7131f68c3 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/BoardConfigModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/BoardConfigModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/CalculationEngineConfigModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/CalculationEngineConfigModel.java new file mode 100644 index 000000000..0814b8790 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/CalculationEngineConfigModel.java @@ -0,0 +1,29 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.globalconfig; + +/** + * 计算引擎相关配置 + * + * @author zane + * @date 2021/12/3 + */ +public class CalculationEngineConfigModel { + /** + * SPARK、FC + */ + public String backend = "SPARK"; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/DeepLearningConfigModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/DeepLearningConfigModel.java new file mode 100644 index 000000000..dae8ce550 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/DeepLearningConfigModel.java @@ -0,0 +1,25 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.globalconfig; + +/** + * @author zane + * @date 2021/10/29 + */ +public class DeepLearningConfigModel { + public String device = "cpu"; + public String paddleVisualDlBaseUrl; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/FlowConfigModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/FlowConfigModel.java index 6fb3f2922..5c9725ade 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/FlowConfigModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/FlowConfigModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,10 +16,6 @@ package com.welab.wefe.board.service.dto.globalconfig; -/** - * @author Zane - */ - /** * @author zane.luo */ diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/FunctionComputeConfigModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/FunctionComputeConfigModel.java new file mode 100644 index 000000000..4dd9853b7 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/FunctionComputeConfigModel.java @@ -0,0 +1,25 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.globalconfig; + +/** + * @author zane + * @date 2021/10/29 + */ +public class FunctionComputeConfigModel { + public int maxCostInDay = 500; + public int maxCostInMonth = 1000; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/GatewayConfigModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/GatewayConfigModel.java index cfd28c0e9..b99b7b312 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/GatewayConfigModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/GatewayConfigModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,10 +16,6 @@ package com.welab.wefe.board.service.dto.globalconfig; -/** - * @author Zane - */ - /** * @author zane.luo */ diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/GlobalConfigInput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/GlobalConfigInput.java index b09b1542d..f65abbf26 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/GlobalConfigInput.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/GlobalConfigInput.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/MailServerModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/MailServerModel.java index a5748e36c..eca80d337 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/MailServerModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/MailServerModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,26 +20,20 @@ * @author Zane */ +import com.welab.wefe.common.fieldvalidate.annotation.Check; + /** * @author zane.luo */ public class MailServerModel { - /** - * 邮件服务器地址 - */ + @Check(name = "邮件服务器地址") private String mailHost; - /** - * 邮件服务器端口 - */ + @Check(name = "邮件服务器端口") private Integer mailPort; - /** - * 邮件用户名 - */ + @Check(name = "邮件用户名") private String mailUsername; - /** - * 邮件密码 - */ + @Check(name = "邮件密码") private String mailPassword; // region getter/setter diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/MemberInfoModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/MemberInfoModel.java index a1ac67100..d2198bc23 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/MemberInfoModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/MemberInfoModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,60 +16,41 @@ package com.welab.wefe.board.service.dto.globalconfig; -/** - * @author Zane - */ +import com.welab.wefe.common.constant.SecretKeyType; +import com.welab.wefe.common.fieldvalidate.annotation.Check; /** * @author zane.luo */ public class MemberInfoModel { - /** - * 联邦成员 Id; - * 全局唯一,默认为uuid。 - */ + @Check(name = "联邦成员 Id", desc = "全局唯一,默认为uuid。") private String memberId; - /** - * 联邦成员名称 - */ + @Check(name = "联邦成员名称") private String memberName; - /** - * 联邦成员邮箱 - */ + @Check(name = "联邦成员邮箱") private String memberEmail; - /** - * 联邦成员电话 - */ + @Check(name = "联邦成员电话") private String memberMobile; - /** - * 联邦成员网关访问地址 - */ + @Check(name = "联邦成员网关访问地址") private String memberGatewayUri; - /** - * 是否允许对外公开数据集基础信息 - */ + @Check(name = "是否允许对外公开数据集基础信息") private Boolean memberAllowPublicDataSet; - /** - * 私钥 - */ + @Check(name = "私钥") private String rsaPrivateKey; - /** - * 公钥 - */ + @Check(name = "公钥") private String rsaPublicKey; - /** - * 成员头像 - */ + @Check(name = "成员头像") private String memberLogo; - /** - * 成员隐身状态 - */ + @Check(name = "成员隐身状态") private Boolean memberHidden; + @Check(name = "密钥类型") + private SecretKeyType secretKeyType = SecretKeyType.rsa; + //region getter/setter public String getMemberId() { @@ -152,5 +133,12 @@ public void setMemberHidden(Boolean memberHidden) { this.memberHidden = memberHidden; } + public SecretKeyType getSecretKeyType() { + return secretKeyType; + } + + public void setSecretKeyType(SecretKeyType secretKeyType) { + this.secretKeyType = secretKeyType; + } //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/ServingConfigModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/ServingConfigModel.java index 60f4a9b11..ea79f9a98 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/ServingConfigModel.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/globalconfig/ServingConfigModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,10 +16,6 @@ package com.welab.wefe.board.service.dto.globalconfig; -/** - * @author Zane - */ - /** * @author zane.luo */ diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/Env.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/Env.java deleted file mode 100644 index ffd7dd7e7..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/Env.java +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.dto.kernel; - -import com.welab.wefe.common.data.storage.common.DBType; -import com.welab.wefe.common.enums.JobBackendType; -import com.welab.wefe.common.enums.env.EnvName; - -/** - * @author zane.luo - */ -public class Env { - private DBType dbType; - private JobBackendType backend; - private int workMode; - private EnvName name; - - - //region getter/setter - - public DBType getDbType() { - return dbType; - } - - public void setDbType(DBType dbType) { - this.dbType = dbType; - } - - public JobBackendType getBackend() { - return backend; - } - - public void setBackend(JobBackendType backend) { - this.backend = backend; - } - - public int getWorkMode() { - return workMode; - } - - public void setWorkMode(int workMode) { - this.workMode = workMode; - } - - public EnvName getName() { - return name; - } - - public void setName(EnvName name) { - this.name = name; - } - - - //endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/KernelJob.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/KernelJob.java deleted file mode 100644 index 34e16965e..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/KernelJob.java +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.dto.kernel; - -import com.welab.wefe.common.enums.FederatedLearningModel; -import com.welab.wefe.common.enums.FederatedLearningType; - -import java.util.List; - -/** - * @author zane.luo - */ -public class KernelJob { - private FederatedLearningType federatedLearningType; - private Project project; - private Env env; - private List members; - private List dataSets; - /** - * Mixed Federation promoter_id - */ - private String mixPromoterMemberId; - private FederatedLearningModel federatedLearningMode; - - //region getter/setter - - - public FederatedLearningType getFederatedLearningType() { - return federatedLearningType; - } - - public void setFederatedLearningType(FederatedLearningType federatedLearningType) { - this.federatedLearningType = federatedLearningType; - } - - public Project getProject() { - return project; - } - - public void setProject(Project project) { - this.project = project; - } - - public Env getEnv() { - return env; - } - - public void setEnv(Env env) { - this.env = env; - } - - public List getMembers() { - return members; - } - - public void setMembers(List members) { - this.members = members; - } - - public List getDataSets() { - return dataSets; - } - - public void setDataSets(List dataSets) { - this.dataSets = dataSets; - } - - public String getMixPromoterMemberId() { - return mixPromoterMemberId; - } - - public void setMixPromoterMemberId(String mixPromoterMemberId) { - this.mixPromoterMemberId = mixPromoterMemberId; - } - - public FederatedLearningModel getFederatedLearningMode() { - return federatedLearningMode; - } - - public void setFederatedLearningMode(FederatedLearningModel federatedLearningMode) { - this.federatedLearningMode = federatedLearningMode; - } - //endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/Member.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/Member.java index 3bd6eb27e..d42e613b8 100644 
--- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/Member.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/Member.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,9 +16,20 @@ package com.welab.wefe.board.service.dto.kernel; +import com.welab.wefe.board.service.api.member.GetMemberMachineLearningEnvApi; +import com.welab.wefe.board.service.component.DataIOComponent; import com.welab.wefe.board.service.database.entity.job.JobMemberMySqlModel; +import com.welab.wefe.board.service.dto.kernel.machine_learning.Env; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.board.service.service.GatewayService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.wefe.enums.JobBackendType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; + +import java.util.ArrayList; +import java.util.List; + /** * @author zane.luo @@ -28,15 +39,94 @@ public class Member { private String memberId; private String memberName; private JobMemberRole memberRole; + private JobBackendType backend; public Member() { } - public Member(JobMemberMySqlModel member) { - this.memberId = member.getMemberId(); + /** + * 创建用于深度学习的 Member 对象 + */ + public static List forDeepLearning(List members) { + List list = new ArrayList<>(); + if 
(members == null) { + return list; + } + + for (JobMemberMySqlModel member : members) { + Member m = new Member(); + m.memberId = member.getMemberId(); + m.memberName = CacheObjects.getMemberName(member.getMemberId()); + m.memberRole = member.getJobRole(); + m.backend = null; + list.add(m); + } + return list; + } + + public static List forMachineLearning(List members) { + List list = new ArrayList<>(); + if (members == null) { + return list; + } + + for (JobMemberMySqlModel member : members) { + list.add(forMachineLearning(member)); + } + return list; + } + + /** + * 创建一个新的 Member 对象,用于 Machine Learning 的 Job。 + */ + public static Member forMachineLearning(JobMemberMySqlModel member) { + Member m = new Member(); + m.memberId = member.getMemberId(); + m.memberName = CacheObjects.getMemberName(member.getMemberId()); + m.memberRole = member.getJobRole(); + m.backend = getMemberJobBackendType(member.getMemberId()); + return m; + } + + /** + * 创建一个新的 Member 对象,用于 Machine Learning 的 Job。 + */ + public static Member forMachineLearning(DataIOComponent.DataSetItem dataSetItem) { + Member member = new Member(); + member.setMemberId(dataSetItem.getMemberId()); + member.setMemberName(CacheObjects.getMemberName(dataSetItem.getMemberId())); + member.setMemberRole(dataSetItem.getMemberRole()); + member.backend = getMemberJobBackendType(dataSetItem.getMemberId()); + return member; + } + + /** + * 创建一个新的 Member 对象,用于 Machine Learning 的 Job。 + */ + public static Member forMachineLearning(String memberId, JobMemberRole role) { + Member member = new Member(); + member.setMemberId(memberId); + member.setMemberName(CacheObjects.getMemberName(memberId)); + member.setMemberRole(role); + member.backend = getMemberJobBackendType(memberId); + return member; + } + - this.memberName = CacheObjects.getMemberName(member.getMemberId()); - this.memberRole = member.getJobRole(); + private static JobBackendType getMemberJobBackendType(String memberId) { + // 自己,从本地取。 + if 
(CacheObjects.isCurrentMember(memberId)) { + return Env.get().getBackend(); + } + + GatewayService gatewayService = Launcher.getBean(GatewayService.class); + Env env = null; + try { + env = gatewayService.callOtherMemberBoard(memberId, GetMemberMachineLearningEnvApi.class, Env.class); + } catch (StatusCodeWithException e) { + return null; + } + return env.getBackend(); } //region getter/setter @@ -65,6 +155,12 @@ public void setMemberRole(JobMemberRole memberRole) { this.memberRole = memberRole; } + public JobBackendType getBackend() { + return backend; + } + public void setBackend(JobBackendType backend) { + this.backend = backend; + } //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/deep_learning/Env.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/deep_learning/Env.java new file mode 100644 index 000000000..0a3ac2ce2 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/deep_learning/Env.java @@ -0,0 +1,135 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.kernel.deep_learning; + +import com.welab.wefe.board.service.component.deep_learning.ImageDataIOComponent; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.common.Convert; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Comparator; +import java.util.LinkedHashMap; + +/** + * @author zane + * @date 2021/11/22 + */ +public class Env { + protected final Logger LOG = LoggerFactory.getLogger(this.getClass()); + /** + * 本方 worker 个数 + *

+ * 计算逻辑: + * 以所有样本集中最小样本数为基数 + * 各成员的 worker 数为自己样本数除以基数,四舍五入取整。 + * 为避免极端情况,最大不超过5。 + */ + public int localWorkerNum; + /** + * worker 总个数,即各方 worker 数之和。 + */ + public int workerNum; + /** + * 本方 worker 索引,多方之间不能重复。 + * e.g: [0,1] + */ + public int[] localTrainerIndexs; + /** + * 设备 cpu/gpu + */ + public String device = "cpu"; + /** + * 是否使用 visualdl 可视化 + */ + public boolean useVdl = true; + /** + * 是否基于上次执行一半的任务继续执行 + */ + public boolean resume = false; + + public Env() { + } + + public Env(ImageDataIOComponent.Params imageDataIoParam) throws StatusCodeWithException { + imageDataIoParam.fillDataSetDetail(); + /** + * 1. 前端应该不允许使用标注量为0的样本 + * 2. ImageDataIO 中阻止了标注量为 0 的样本 + * + * 所以正常情况下不应会让这里出现0 + */ + for (ImageDataIOComponent.DataSetItem dataSetItem : imageDataIoParam.dataSetList) { + if (dataSetItem.dataResource.getLabeledCount() < 1) { + StatusCode + .PARAMETER_VALUE_INVALID + .throwException( + "成员【" + CacheObjects.getMemberName(dataSetItem.memberId) + "】的数据集(" + + dataSetItem.dataResource.getName() + ")已标注样本量为 0," + + "请检查各成员的数据集在 union 中的标注量是否正确。" + ); + } + } + + // 以所有样本集中最小样本数为基数,用于计算各成员需要的 worker 数。 + double min = imageDataIoParam.dataSetList + .stream() + .mapToLong(x -> x.dataResource.getLabeledCount()) + .min() + .orElse(0); + + // 对成员按 member_id 排序,使各成员生成的 worker 顺序一致。 + imageDataIoParam.dataSetList.sort(Comparator.comparing(x -> x.getMemberId())); + + // 计算各方的 worker 数 + LinkedHashMap workerCountMap = new LinkedHashMap<>(); + for (ImageDataIOComponent.DataSetItem dataSetItem : imageDataIoParam.dataSetList) { + int workerCount = Convert.toInt( + Math.round( + dataSetItem.dataResource.getLabeledCount() / min + ) + ); + + // VisualFL暂时不支持多 worker,暂时强制指定1个。 + workerCount = 1; + + // 限制上限 + if (workerCount > 10) { + workerCount = 10; + } + + // is me + if (CacheObjects.getMemberId().equals(dataSetItem.getMemberId())) { + this.localWorkerNum = workerCount; + int startIndex = workerCountMap.values().stream().mapToInt(x -> x).sum(); + int endIndex = 
startIndex + this.localWorkerNum - 1; + if (startIndex == endIndex) { + this.localTrainerIndexs = new int[]{startIndex}; + } else { + this.localTrainerIndexs = new int[]{startIndex, endIndex}; + } + + } + + workerCountMap.put(dataSetItem.getMemberId(), workerCount); + } + + this.workerNum = workerCountMap.values().stream().mapToInt(x -> x).sum(); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/deep_learning/KernelJob.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/deep_learning/KernelJob.java new file mode 100644 index 000000000..e10ed0ddd --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/deep_learning/KernelJob.java @@ -0,0 +1,36 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.kernel.deep_learning; + +import com.welab.wefe.board.service.dto.kernel.Member; +import com.welab.wefe.common.wefe.enums.JobMemberRole; + +import java.util.List; + +/** + * @author zane + * @date 2021/11/22 + */ +public class KernelJob { + public String projectId; + public String jobId; + public String taskId; + public String jobType = "paddle_fl"; + public JobMemberRole role; + public String memberId; + public Env env; + public List members; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/Env.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/Env.java new file mode 100644 index 000000000..76ecddb3c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/Env.java @@ -0,0 +1,93 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.dto.kernel.machine_learning; + +import com.alibaba.fastjson.annotation.JSONField; +import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.board.service.dto.globalconfig.CalculationEngineConfigModel; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.common.data.storage.common.DBType; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.wefe.enums.JobBackendType; +import com.welab.wefe.common.wefe.enums.env.EnvName; + + +/** + * @author zane.luo + */ +public class Env { + private DBType dbType; + private JobBackendType backend; + private int workMode; + private EnvName name; + + + @JSONField(serialize = false) + public static Env get() { + Env env = new Env(); + CalculationEngineConfigModel calculationEngineConfig = Launcher.getBean(GlobalConfigService.class).getCalculationEngineConfig(); + if (StringUtil.isEmpty(calculationEngineConfig.backend)) { + throw new RuntimeException("计算环境未选择,请在[全局设置][计算引擎设置]中指定计算环境。"); + } + + Config config = Launcher.getBean(Config.class); + + env.setBackend(JobBackendType.valueOf(calculationEngineConfig.backend)); + env.setDbType(config.getDbType()); + env.setWorkMode(config.getWorkMode()); + env.setName(config.getEnvName()); + return env; + } + + //region getter/setter + + public DBType getDbType() { + return dbType; + } + + public void setDbType(DBType dbType) { + this.dbType = dbType; + } + + public JobBackendType getBackend() { + return backend; + } + + public void setBackend(JobBackendType backend) { + this.backend = backend; + } + + public int getWorkMode() { + return workMode; + } + + public void setWorkMode(int workMode) { + this.workMode = workMode; + } + + public EnvName getName() { + return name; + } + + public void setName(EnvName name) { + this.name = name; + } + + + //endregion +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/JobDataSet.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/JobDataSet.java similarity index 80% rename from board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/JobDataSet.java rename to board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/JobDataSet.java index 52ec74d46..3c9c090a2 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/JobDataSet.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/JobDataSet.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,10 +14,11 @@ * limitations under the License. 
*/ -package com.welab.wefe.board.service.dto.kernel; +package com.welab.wefe.board.service.dto.kernel.machine_learning; + -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import java.util.List; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/KernelJob.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/KernelJob.java new file mode 100644 index 000000000..562d6f354 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/KernelJob.java @@ -0,0 +1,98 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.dto.kernel.machine_learning; + +import com.welab.wefe.board.service.dto.kernel.Member; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.FederatedLearningModel; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; + +import java.util.List; + +/** + * @author zane.luo + */ +public class KernelJob { + private FederatedLearningType federatedLearningType; + private Project project; + private Env env; + private List members; + private List dataSets; + @Check(name = "Mixed Federation promoter_id") + private String mixPromoterMemberId; + private FederatedLearningModel federatedLearningMode; + + //region getter/setter + + + public FederatedLearningType getFederatedLearningType() { + return federatedLearningType; + } + + public void setFederatedLearningType(FederatedLearningType federatedLearningType) { + this.federatedLearningType = federatedLearningType; + } + + public Project getProject() { + return project; + } + + public void setProject(Project project) { + this.project = project; + } + + public Env getEnv() { + return env; + } + + public void setEnv(Env env) { + this.env = env; + } + + public List getMembers() { + return members; + } + + public void setMembers(List members) { + this.members = members; + } + + public List getDataSets() { + return dataSets; + } + + public void setDataSets(List dataSets) { + this.dataSets = dataSets; + } + + public String getMixPromoterMemberId() { + return mixPromoterMemberId; + } + + public void setMixPromoterMemberId(String mixPromoterMemberId) { + this.mixPromoterMemberId = mixPromoterMemberId; + } + + public FederatedLearningModel getFederatedLearningMode() { + return federatedLearningMode; + } + + public void setFederatedLearningMode(FederatedLearningModel federatedLearningMode) { + this.federatedLearningMode = federatedLearningMode; + } + //endregion +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/KernelTask.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/KernelTask.java similarity index 80% rename from board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/KernelTask.java rename to board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/KernelTask.java index 53dd4f9b6..be23fbed2 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/KernelTask.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/KernelTask.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,7 +14,10 @@ * limitations under the License. 
*/ -package com.welab.wefe.board.service.dto.kernel; +package com.welab.wefe.board.service.dto.kernel.machine_learning; + +import com.welab.wefe.board.service.dto.kernel.Member; +import com.welab.wefe.common.fieldvalidate.annotation.Check; import java.util.List; @@ -32,26 +35,16 @@ public KernelTask(List members) { private List members; - /** - * Mixed Federation promoter_id - */ + @Check(name = "Mixed Federation promoter_id") private String mixPromoterMemberId; - /** - * Whether it is the main node of the current provider - */ + @Check(name = "Whether it is the main node of the current provider") private boolean providerMaster; - /** - * The id of the current provider - */ + @Check(name = "The id of the current provider") private String providerInnerId; - /** - * The primary node id of the current provider - */ + @Check(name = "The primary node id of the current provider") private String providerMasterInnerId; - /** - * Other id of the current provider, not including itself - */ + @Check(name = "Other id of the current provider, not including itself") private List providerOtherInnerId; public List getMembers() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/Project.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/Project.java similarity index 87% rename from board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/Project.java rename to board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/Project.java index 392e4d645..da6dd40c2 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/Project.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/Project.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.welab.wefe.board.service.dto.kernel; +package com.welab.wefe.board.service.dto.kernel.machine_learning; /** * @author zane.luo diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/TaskConfig.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/TaskConfig.java similarity index 84% rename from board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/TaskConfig.java rename to board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/TaskConfig.java index 04f8c1026..c78154f7b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/TaskConfig.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/kernel/machine_learning/TaskConfig.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -14,9 +14,11 @@ * limitations under the License. */ -package com.welab.wefe.board.service.dto.kernel; +package com.welab.wefe.board.service.dto.kernel.machine_learning; + -import com.welab.wefe.common.enums.ComponentType; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.ComponentType; import java.util.Map; @@ -27,13 +29,9 @@ public class TaskConfig { private KernelJob job; private ComponentType module; - /** - * 组件的输入相关信息 - */ + @Check(name = "组件的输入相关信息") private Map input; - /** - * 组件的输出相关信息 - */ + @Check(name = "组件的输出相关信息") private Map output; /** diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/union/UnionDataSetOutput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/union/UnionDataSetOutput.java index 4ebc93416..555c636bf 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/union/UnionDataSetOutput.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/union/UnionDataSetOutput.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/AccountInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/AccountInputModel.java index 1626deba6..210c6dca9 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/AccountInputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/AccountInputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/AuditStatusCounts.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/AuditStatusCounts.java index e0789aba4..44477ea24 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/AuditStatusCounts.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/AuditStatusCounts.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,7 +16,8 @@ package com.welab.wefe.board.service.dto.vo; -import com.welab.wefe.common.enums.AuditStatus; + +import com.welab.wefe.common.wefe.enums.AuditStatus; /** * @author zane.luo diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/DataSetAddInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/DataSetAddInputModel.java deleted file mode 100644 index 5609fe1ee..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/DataSetAddInputModel.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.dto.vo; - -import com.welab.wefe.board.service.constant.Config; -import com.welab.wefe.board.service.constant.DataSetAddMethod; -import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.Launcher; -import org.apache.commons.lang3.StringUtils; - -import java.util.List; - -/** - * @author zane.luo - */ -public class DataSetAddInputModel extends DataSetBaseInputModel { - @Check(name = "数据集名称", require = true, regex = "^.{4,30}$", messageOnInvalid = "数据集名称长度不能少于4,不能大于30") - private String name; - - @Check(name = "关键词", require = true, regex = "^.{1,128}$", messageOnInvalid = "关键词太多了啦~") - private List tags; - - @Check(name = "描述", regex = "^.{0,3072}$", messageOnInvalid = "你写的描述太多了~") - private String description; - - @Check(messageOnEmpty = "请指定数据集文件") - private String filename; - - @Check(require = true) - private DataSetAddMethod dataSetAddMethod; - - @Check(require = true, name = "是否需要去重") - private boolean deduplication; - - @Check(name = "数据源id") - private String dataSourceId; - - @Check(name = "sql脚本") - private String sql; - - public DataSetAddInputModel() { - } - - public DataSetAddInputModel(String dataSourceId, String sql) { - this.dataSourceId = dataSourceId; - this.sql = sql; - } - - @Override - public void checkAndStandardize() throws StatusCodeWithException { - super.checkAndStandardize(); - - // 如果来源是数据库,则要求dataSourceId、sql不能为空 - if (DataSetAddMethod.Database.equals(dataSetAddMethod)) { - if (StringUtils.isEmpty(dataSourceId)) { - throw new StatusCodeWithException("dataSourceId在数据库不存在", StatusCode.DATA_NOT_FOUND); - } - - if (StringUtils.isEmpty(sql)) { - throw new StatusCodeWithException("请填入sql查询语句", StatusCode.PARAMETER_CAN_NOT_BE_EMPTY); - } - } else { - if (StringUtils.isEmpty(filename)) { - throw new StatusCodeWithException("请指定数据集文件", 
StatusCode.PARAMETER_CAN_NOT_BE_EMPTY); - } - - // 如果是指定服务器上的本地文件,则必须指定配置文件配置的目录下的文件。 - if (DataSetAddMethod.LocalFile.equals(dataSetAddMethod)) { - Config config = Launcher.CONTEXT.getBean(Config.class); - - if (!filename.startsWith(config.getFileUploadDir())) { - StatusCode - .PARAMETER_VALUE_INVALID - .throwException("您指定的文件路径必须以 " + config.getFileUploadDir() + " 开头,请手动将数据集文件拷贝到该目录后重试。"); - } - } - - } - } - - //region getter/setter - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public List getTags() { - return tags; - } - - public void setTags(List tags) { - this.tags = tags; - } - - public String getDescription() { - return description; - } - - public void setDescription(String description) { - this.description = description; - } - - public String getFilename() { - return filename; - } - - public void setFilename(String filename) { - this.filename = filename; - } - - public DataSetAddMethod getDataSetAddMethod() { - return dataSetAddMethod; - } - - public void setDataSetAddMethod(DataSetAddMethod dataSetAddMethod) { - this.dataSetAddMethod = dataSetAddMethod; - } - - public boolean isDeduplication() { - return deduplication; - } - - public void setDeduplication(boolean deduplication) { - this.deduplication = deduplication; - } - - public String getDataSourceId() { - return dataSourceId; - } - - public void setDataSourceId(String dataSourceId) { - this.dataSourceId = dataSourceId; - } - - public String getSql() { - return sql; - } - - public void setSql(String sql) { - this.sql = sql; - } - - //endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/DataSetAddOutput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/DataSetAddOutput.java deleted file mode 100644 index ce4abc1a0..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/DataSetAddOutput.java +++ /dev/null @@ -1,55 +0,0 @@ -/** - * 
Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.dto.vo; - -/** - * @author Zane - */ -public class DataSetAddOutput { - private String id; - private long repeatDataCount; - - - public DataSetAddOutput() { - } - - public DataSetAddOutput(String id, int repeatDataCount) { - this.id = id; - this.repeatDataCount = repeatDataCount; - } - - - //region getter/setter - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - public long getRepeatDataCount() { - return repeatDataCount; - } - - public void setRepeatDataCount(long repeatDataCount) { - this.repeatDataCount = repeatDataCount; - } - - //endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/DataSetBaseInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/DataSetBaseInputModel.java deleted file mode 100644 index cbe90a6cf..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/DataSetBaseInputModel.java +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.dto.vo; - -import com.welab.wefe.board.service.dto.entity.data_set.DataSetColumnInputModel; -import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.ColumnDataType; -import com.welab.wefe.common.enums.DataSetPublicLevel; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.fieldvalidate.annotation.Check; -import com.welab.wefe.common.web.dto.AbstractApiInput; -import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.lang3.StringUtils; - -import java.util.List; - -/** - * @author zane.luo - */ -public class DataSetBaseInputModel extends AbstractApiInput { - @Check(name = "可见级别", require = true) - private DataSetPublicLevel publicLevel; - @Check( - name = "可见成员列表", - desc = "只有在列表中的联邦成员才可以看到该数据集的基本信息", - regex = "^.{0,3072}$", - messageOnInvalid = "你选择的 member 太多了~" - ) - private String publicMemberList; - @Check(require = true) - private List metadataList; - - @Override - public void checkAndStandardize() throws StatusCodeWithException { - super.checkAndStandardize(); - - if (publicLevel == DataSetPublicLevel.PublicWithMemberList && StringUtils.isEmpty(publicMemberList)) { - throw new StatusCodeWithException("请指定可见成员", StatusCode.PARAMETER_VALUE_INVALID); - } - - if (CollectionUtils.isEmpty(metadataList)) { - throw new StatusCodeWithException("请设置该数据集的元数据", StatusCode.PARAMETER_VALUE_INVALID); - } - - for (DataSetColumnInputModel item : metadataList) { - if (item.getDataType() == null) { - throw new 
StatusCodeWithException("请给字段【" + item.getName() + "】设置数据类型", StatusCode.PARAMETER_VALUE_INVALID); - } - } - } - - - // region getter/setter - - public DataSetPublicLevel getPublicLevel() { - return publicLevel; - } - - public void setPublicLevel(DataSetPublicLevel publicLevel) { - this.publicLevel = publicLevel; - } - - public String getPublicMemberList() { - return publicMemberList; - } - - public void setPublicMemberList(String publicMemberList) { - this.publicMemberList = publicMemberList; - } - - public List getMetadataList() { - return metadataList; - } - - public void setMetadataList(List metadataList) { - this.metadataList = metadataList; - } - - - // endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobArbiterInfo.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobArbiterInfo.java index 046eec2cc..3b2ee5fc5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobArbiterInfo.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobArbiterInfo.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,20 +16,18 @@ package com.welab.wefe.board.service.dto.vo; +import com.welab.wefe.common.fieldvalidate.annotation.Check; + /** * Save the arbiter information of the current member in a process * * @author winter.zou */ public class JobArbiterInfo { - /** - * Whether there is an arbiter - */ + @Check(name = "Whether there is an arbiter") private boolean hasArbiter; - /** - * member_id of arbiter - */ + @Check(name = "member_id of arbiter") private String arbiterMemberId; public boolean isHasArbiter() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobMemberWithDataSetOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobMemberWithDataSetOutputModel.java index f868cd359..75196886a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobMemberWithDataSetOutputModel.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobMemberWithDataSetOutputModel.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,15 +17,14 @@ package com.welab.wefe.board.service.dto.vo; import com.welab.wefe.board.service.dto.entity.job.JobMemberOutputModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; /** * @author zane.luo */ public class JobMemberWithDataSetOutputModel extends JobMemberOutputModel { private String featureNameList; - /** - * 特征数量 - */ + @Check(name = "特征数量") private Integer featureCount; //region getter/setter diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobProgressOutput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobProgressOutput.java index f90dfa7ed..5ab01c62a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobProgressOutput.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/JobProgressOutput.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -20,9 +20,10 @@ import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.dto.entity.job.JobMemberOutputModel; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.JobStatus; -import com.welab.wefe.common.enums.TaskStatus; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.JobStatus; +import com.welab.wefe.common.wefe.enums.TaskStatus; + /** * 任务执行进度 diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/MemberServiceStatusOutput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/MemberServiceStatusOutput.java deleted file mode 100644 index ad3d313ef..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/MemberServiceStatusOutput.java +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.dto.vo; - -import com.welab.wefe.common.enums.MemberService; - -/** - * @author zane - */ -public class MemberServiceStatusOutput { - private MemberService service; - private String value; - private boolean success; - private String message; - private String name; - private Long spend; - - public MemberServiceStatusOutput() { - } - - public MemberServiceStatusOutput(MemberService service) { - this.service = service; - } - - // region getter/setter - - public MemberService getService() { - return service; - } - - public void setService(MemberService service) { - this.service = service; - } - - public String getValue() { - return value; - } - - public void setValue(String value) { - this.value = value; - } - - public boolean isSuccess() { - return success; - } - - public void setSuccess(boolean success) { - this.success = success; - } - - public String getMessage() { - return message; - } - - public void setMessage(String message) { - this.message = message; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public Long getSpend() { - return spend; - } - - public void setSpend(Long spend) { - this.spend = spend; - } - // endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/OnlineAccountOutput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/OnlineAccountOutput.java index 508b09a41..28f44bc5b 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/OnlineAccountOutput.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/OnlineAccountOutput.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,15 +16,15 @@ package com.welab.wefe.board.service.dto.vo; +import com.welab.wefe.common.fieldvalidate.annotation.Check; + /** * 在线账号 * * @author aaron.li **/ public class OnlineAccountOutput { - /** - * 账号ID - */ + @Check(name = "账号ID") private String accountId; public String getAccountId() { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/RoleCounts.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/RoleCounts.java index 23f23bf3d..77b4ac1cd 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/RoleCounts.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/RoleCounts.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,7 +16,8 @@ package com.welab.wefe.board.service.dto.vo; -import com.welab.wefe.common.enums.JobMemberRole; + +import com.welab.wefe.common.wefe.enums.JobMemberRole; /** * @author zane.luo diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/ServerCheckPointOutput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/ServerCheckPointOutput.java deleted file mode 100644 index d7d03b6a2..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/ServerCheckPointOutput.java +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.dto.vo; - -/** - * @author zane - */ -public class ServerCheckPointOutput { - private String desc; - private boolean success; - private String message; - private String value; - private Long spend; - - public ServerCheckPointOutput() { - } - - public static ServerCheckPointOutput success(String name, String desc, String value, long spend) { - ServerCheckPointOutput output = new ServerCheckPointOutput(); - output.setDesc(desc); - output.setSuccess(false); - output.setMessage("success"); - output.setValue(value); - output.setSpend(spend); - return output; - } - - public static ServerCheckPointOutput fail(String name, String desc, String value, long spend, Exception e) { - ServerCheckPointOutput output = new ServerCheckPointOutput(); - output.setDesc(desc); - output.setSuccess(false); - output.setMessage(e.getMessage()); - output.setValue(value); - output.setSpend(spend); - return output; - } - - // region getter/setter - - public String getDesc() { - return desc; - } - - public void setDesc(String desc) { - this.desc = desc; - } - - public boolean isSuccess() { - return success; - } - - public void setSuccess(boolean success) { - this.success = success; - } - - public String getMessage() { - return message; - } - - public void setMessage(String message) { - this.message = message; - } - - public String getValue() { - return value; - } - - public void setValue(String value) { - this.value = value; - } - - public Long getSpend() { - return spend; - } - - public void setSpend(Long spend) { - this.spend = spend; - } - - - // endregion -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/ServiceAvailableOutput.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/ServiceAvailableOutput.java deleted file mode 100644 index 16b4c8a0a..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/ServiceAvailableOutput.java +++ /dev/null @@ -1,79 +0,0 
@@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.dto.vo; - -import com.welab.wefe.common.enums.MemberService; - -import java.util.List; - -/** - * 服务可用性 - * - * @author aaron.li - **/ -public class ServiceAvailableOutput { - /** - * 服务名 - */ - private MemberService service; - - /** - * 是否成功(当其下的所有服务列表为true时该值才为true,否则为false) - */ - private boolean success; - /** - * 描述 - */ - private String message; - - /** - * 相应服务列表 - */ - private List memberServiceStatusOutputList; - - public MemberService getService() { - return service; - } - - public void setService(MemberService service) { - this.service = service; - } - - public boolean isSuccess() { - return success; - } - - public void setSuccess(boolean success) { - this.success = success; - } - - public List getMemberServiceStatusOutputList() { - return memberServiceStatusOutputList; - } - - public void setMemberServiceStatusOutputList(List memberServiceStatusOutputList) { - this.memberServiceStatusOutputList = memberServiceStatusOutputList; - } - - public String getMessage() { - return message; - } - - public void setMessage(String message) { - this.message = message; - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/AbstractDataResourceUpdateInputModel.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/AbstractDataResourceUpdateInputModel.java new file mode 100644 index 000000000..159b42251 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/AbstractDataResourceUpdateInputModel.java @@ -0,0 +1,155 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.dto.vo.data_resource; + +import com.welab.wefe.board.service.database.repository.data_resource.DataResourceRepository; +import com.welab.wefe.board.service.dto.globalconfig.MemberInfoModel; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.wefe.enums.DataResourcePublicLevel; +import org.apache.commons.lang3.StringUtils; + +import java.util.List; + +/** + * @author zane.luo + */ +public class AbstractDataResourceUpdateInputModel extends AbstractApiInput { + private String id; + + @Check(name = "数据集名称", require = true, regex = "^.{4,30}$", messageOnInvalid = "数据集名称长度不能少于4,不能大于30") + private String name; + 
@Check(name = "描述", regex = "^.{0,3072}$", messageOnInvalid = "你写的描述太多了~") + private String description; + @Check(name = "关键词", require = true, regex = "^.{1,128}$", messageOnInvalid = "关键词太多了啦~") + private List tags; + + + @Check(name = "可见级别", require = true) + private DataResourcePublicLevel publicLevel; + @Check( + name = "可见成员列表", + desc = "只有在列表中的联邦成员才可以看到该数据集的基本信息", + regex = "^.{0,3072}$", + messageOnInvalid = "你选择的 member 太多了~" + ) + private String publicMemberList; + + + public AbstractDataResourceUpdateInputModel() { + } + + public AbstractDataResourceUpdateInputModel(String name, List tags, String description) { + this.name = name; + this.tags = tags; + this.description = description; + } + + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + + // 当全局拒绝暴露时,禁止选择暴露资源。 + MemberInfoModel member = Launcher.getBean(GlobalConfigService.class).getMemberInfo(); + if (publicLevel != DataResourcePublicLevel.OnlyMyself) { + if (!member.getMemberAllowPublicDataSet()) { + StatusCode.PARAMETER_VALUE_INVALID.throwException("当前联邦成员不允许资源对外可见,请在[全局设置][成员设置]中开启。"); + } + + if (member.getMemberHidden()) { + StatusCode.PARAMETER_VALUE_INVALID.throwException("当前联邦成员已被管理员隐身,隐身状态下不允许资源可见。"); + } + } + + + if (publicLevel == DataResourcePublicLevel.PublicWithMemberList && StringUtils.isEmpty(publicMemberList)) { + throw new StatusCodeWithException("请指定可见成员", StatusCode.PARAMETER_VALUE_INVALID); + } + + int countByName = 0; + DataResourceRepository repository = Launcher.getBean(DataResourceRepository.class); + if (StringUtil.isEmpty(id)) { + countByName = repository.countByName(name); + } else { + countByName = repository.countByName(name, id); + } + + if (countByName > 0) { + throw new StatusCodeWithException("此资源名称已存在,请换一个名称", StatusCode.PARAMETER_VALUE_INVALID); + } + } + + + // region getter/setter + + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String 
getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public List getTags() { + return tags; + } + + public void setTags(List tags) { + this.tags = tags; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public DataResourcePublicLevel getPublicLevel() { + return publicLevel; + } + + public void setPublicLevel(DataResourcePublicLevel publicLevel) { + this.publicLevel = publicLevel; + } + + public String getPublicMemberList() { + return publicMemberList; + } + + public void setPublicMemberList(String publicMemberList) { + this.publicMemberList = publicMemberList; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/BloomFilterAddInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/BloomFilterAddInputModel.java new file mode 100644 index 000000000..f752937bd --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/BloomFilterAddInputModel.java @@ -0,0 +1,146 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.dto.vo.data_resource; + +import com.welab.wefe.board.service.constant.BloomfilterAddMethod; +import com.welab.wefe.board.service.constant.DataSetAddMethod; +import com.welab.wefe.board.service.util.primarykey.FieldInfo; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; + +import java.util.List; + +/** + * @author jacky.jiang + */ +public class BloomFilterAddInputModel extends BloomFilterUpdateInputModel { + + @Check(messageOnEmpty = "请指定过滤器文件") + private String filename; + @Check(require = true) + private BloomfilterAddMethod bloomfilterAddMethod; + + @Check(require = true, name = "是否需要去重") + private boolean deduplication; + + @Check(name = "数据源id") + private String dataSourceId; + + @Check(name = "sql脚本") + private String sql; + + @Check(name = "选择的id特征列") + private String hashFunction; + + @Check(name = "主键处理") + private List fieldInfoList; + + public BloomFilterAddInputModel() { + } + + public BloomFilterAddInputModel(String dataSourceId, String sql) { + this.dataSourceId = dataSourceId; + this.sql = sql; + } + + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + + if(CollectionUtils.isEmpty(fieldInfoList)){ + throw new StatusCodeWithException("请设置主键!", StatusCode.PARAMETER_VALUE_INVALID); + } + + // 如果来源是数据库,则要求dataSourceId、sql不能为空 + if (DataSetAddMethod.Database.equals(bloomfilterAddMethod)) { + if (StringUtils.isEmpty(dataSourceId)) { + throw new StatusCodeWithException("dataSourceId在数据库不存在", StatusCode.DATA_NOT_FOUND); + } + + if (StringUtils.isEmpty(sql)) { + throw new StatusCodeWithException("请填入sql查询语句", StatusCode.PARAMETER_CAN_NOT_BE_EMPTY); + } + } else { + if (StringUtils.isEmpty(filename)) { + throw new 
StatusCodeWithException("请指定数据集文件", StatusCode.PARAMETER_CAN_NOT_BE_EMPTY); + } + } + } + + //region getter/setter + + + public String getFilename() { + return filename; + } + + public void setFilename(String filename) { + this.filename = filename; + } + + public BloomfilterAddMethod getBloomfilterAddMethod() { + return bloomfilterAddMethod; + } + + public void setBloomfilterAddMethod(BloomfilterAddMethod bloomfilterAddMethod) { + this.bloomfilterAddMethod = bloomfilterAddMethod; + } + + public boolean isDeduplication() { + return deduplication; + } + + public void setDeduplication(boolean deduplication) { + this.deduplication = deduplication; + } + + public String getDataSourceId() { + return dataSourceId; + } + + public void setDataSourceId(String dataSourceId) { + this.dataSourceId = dataSourceId; + } + + public String getSql() { + return sql; + } + + public void setSql(String sql) { + this.sql = sql; + } + + public String getHashFunction() { + return hashFunction; + } + + public void setHashFunction(String hashFunction) { + this.hashFunction = hashFunction; + } + + public List getFieldInfoList() { + return fieldInfoList; + } + + public void setFieldInfoList(List fieldInfoList) { + this.fieldInfoList = fieldInfoList; + } +//endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/BloomFilterUpdateInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/BloomFilterUpdateInputModel.java new file mode 100644 index 000000000..f9a336015 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/BloomFilterUpdateInputModel.java @@ -0,0 +1,41 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_resource; + +import com.welab.wefe.board.service.dto.fusion.BloomFilterColumnInputModel; + +import java.util.List; + +/** + * @author jacky.jiang + * @date 2021/12/2 + */ +public class BloomFilterUpdateInputModel extends AbstractDataResourceUpdateInputModel { + private List metadataList; + + // region getter/setter + + public List getMetadataList() { + return metadataList; + } + + public void setMetadataList(List metadataList) { + this.metadataList = metadataList; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/DataResourceAddOutputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/DataResourceAddOutputModel.java new file mode 100644 index 000000000..6893a1008 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/DataResourceAddOutputModel.java @@ -0,0 +1,35 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_resource; + +import com.welab.wefe.board.service.service.AbstractService; + +/** + * @author zane + * @date 2021/12/3 + */ +public class DataResourceAddOutputModel extends AbstractService { + public String dataResourceId; + public String uploadTaskId; + + public DataResourceAddOutputModel() { + } + + public DataResourceAddOutputModel(String dataResourceId, String uploadTaskId) { + this.dataResourceId = dataResourceId; + this.uploadTaskId = uploadTaskId; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/ImageDataSetAddInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/ImageDataSetAddInputModel.java new file mode 100644 index 000000000..5a8bcac55 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/ImageDataSetAddInputModel.java @@ -0,0 +1,71 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.dto.vo.data_resource; + + +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DeepLearningJobType; + +import java.io.File; + +/** + * @author zane.luo + */ +public class ImageDataSetAddInputModel extends ImageDataSetUpdateInputModel { + @Check(require = true, messageOnEmpty = "请指定数据集文件") + public String filename; + @Check(name = "数据集应用的任务类型", require = true) + public DeepLearningJobType forJobType; + + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + + File file = WeFeFileSystem.getFilePath(DataResourceType.ImageDataSet, filename).toFile(); + + if (!file.exists()) { + StatusCode + .FILE_IO_ERROR + .throwException("未找到文件:" + filename + ",请重试刷新页面后重新上传。"); + } + } + + // region getter/setter + + public String getFilename() { + return filename; + } + + public void setFilename(String filename) { + this.filename = filename; + } + + public DeepLearningJobType getForJobType() { + return forJobType; + } + + public void setForJobType(DeepLearningJobType forJobType) { + this.forJobType = forJobType; + } + + // endregion + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/ImageDataSetUpdateInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/ImageDataSetUpdateInputModel.java new file mode 100644 index 000000000..b0fce49d5 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/ImageDataSetUpdateInputModel.java @@ -0,0 +1,24 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_resource; + +/** + * @author zane + * @date 2021/11/8 + */ +public class ImageDataSetUpdateInputModel extends AbstractDataResourceUpdateInputModel { + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/TableDataSetAddInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/TableDataSetAddInputModel.java new file mode 100644 index 000000000..9a54bbddb --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/TableDataSetAddInputModel.java @@ -0,0 +1,130 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.dto.vo.data_resource; + +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.board.service.constant.DataSetAddMethod; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import org.apache.commons.lang3.StringUtils; + +/** + * @author zane.luo + */ +public class TableDataSetAddInputModel extends TableDataSetUpdateInputModel { + @Check(messageOnEmpty = "请指定数据集文件") + private String filename; + @Check(require = true) + private DataSetAddMethod dataSetAddMethod; + + @Check(require = true, name = "是否需要去重") + private boolean deduplication; + + @Check(name = "数据源id") + private String dataSourceId; + + @Check(name = "sql脚本") + private String sql; + + public TableDataSetAddInputModel() { + } + + public TableDataSetAddInputModel(String dataSourceId, String sql) { + this.dataSourceId = dataSourceId; + this.sql = sql; + } + + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + + switch (dataSetAddMethod) { + case Database: + if (StringUtils.isEmpty(dataSourceId)) { + throw new StatusCodeWithException("dataSourceId在数据库不存在", StatusCode.DATA_NOT_FOUND); + } + + if (StringUtils.isEmpty(sql)) { + throw new StatusCodeWithException("请填入sql查询语句", StatusCode.PARAMETER_CAN_NOT_BE_EMPTY); + } + break; + case HttpUpload: + case LocalFile: + if (StringUtils.isEmpty(filename)) { + throw new StatusCodeWithException("请指定数据集文件", StatusCode.PARAMETER_CAN_NOT_BE_EMPTY); + } + break; + default: + } + + // 如果是指定服务器上的本地文件,则必须指定配置文件配置的目录下的文件。 + if (DataSetAddMethod.LocalFile.equals(dataSetAddMethod)) { + String rootDir = WeFeFileSystem.getRootDir().toAbsolutePath().toString(); + if (!filename.startsWith(rootDir)) { + StatusCode + .PARAMETER_VALUE_INVALID + .throwException("您指定的文件路径必须以 " + rootDir + " 开头,请手动将数据集文件拷贝到该目录后重试。"); + } + } + + } + 
+ //region getter/setter + + public DataSetAddMethod getDataSetAddMethod() { + return dataSetAddMethod; + } + + public void setDataSetAddMethod(DataSetAddMethod dataSetAddMethod) { + this.dataSetAddMethod = dataSetAddMethod; + } + + public boolean isDeduplication() { + return deduplication; + } + + public void setDeduplication(boolean deduplication) { + this.deduplication = deduplication; + } + + public String getDataSourceId() { + return dataSourceId; + } + + public void setDataSourceId(String dataSourceId) { + this.dataSourceId = dataSourceId; + } + + public String getSql() { + return sql; + } + + public void setSql(String sql) { + this.sql = sql; + } + + public String getFilename() { + return filename; + } + + public void setFilename(String filename) { + this.filename = filename; + } + + //endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/TableDataSetUpdateInputModel.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/TableDataSetUpdateInputModel.java new file mode 100644 index 000000000..c720b2347 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/TableDataSetUpdateInputModel.java @@ -0,0 +1,56 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.vo.data_resource; + + +import com.welab.wefe.board.service.dto.entity.data_set.DataSetColumnInputModel; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import org.apache.commons.collections4.CollectionUtils; + +import java.util.List; + +/** + * @author zane + * @date 2021/11/8 + */ +public class TableDataSetUpdateInputModel extends AbstractDataResourceUpdateInputModel { + @Check(require = true) + private List metadataList; + + @Override + public void checkAndStandardize() throws StatusCodeWithException { + super.checkAndStandardize(); + + if (CollectionUtils.isEmpty(metadataList)) { + throw new StatusCodeWithException("请设置该数据集的元数据", StatusCode.PARAMETER_VALUE_INVALID); + } + } + + // region getter/setter + + public List getMetadataList() { + return metadataList; + } + + public void setMetadataList(List metadataList) { + this.metadataList = metadataList; + } + + + // endregion +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Annotation.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Annotation.java new file mode 100644 index 000000000..66ded2299 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Annotation.java @@ -0,0 +1,79 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_resource.image_data_set; + +import com.thoughtworks.xstream.annotations.XStreamAlias; +import com.thoughtworks.xstream.annotations.XStreamImplicit; +import com.welab.wefe.common.util.StringUtil; + +import java.util.ArrayList; +import java.util.List; + +/** + * https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.3/docs/tutorials/PrepareDataSet.md + * + * @author zane + * @date 2021/11/8 + */ +@XStreamAlias("annotation") +public class Annotation { + public String folder; + public String filename; + public String path; + public Source source; + public Size size; + /** + * 暂时没用到,先使用默认值。 + */ + public int segmented = 0; + @XStreamImplicit + public List objectList; + + public LabelInfo toLabelInfo() { + LabelInfo labelInfo = new LabelInfo(); + + if (objectList != null) { + for (Object object : objectList) { + LabelInfo.Item item = new LabelInfo.Item( + object.name, + object.bndbox.xmin, + object.bndbox.ymin, + object.bndbox.xmax, + object.bndbox.ymax + ); + + labelInfo.objects.add(item); + } + + } + + return labelInfo; + } + + public List getLabelList() { + List list = new ArrayList<>(); + if (objectList == null) { + return list; + } + + for (Object object : objectList) { + if (StringUtil.isNotEmpty(object.name)) { + list.add(object.name); + } + } + return list; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Bndbox.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Bndbox.java new file mode 100644 index 000000000..d61e59645 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Bndbox.java @@ -0,0 +1,27 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_resource.image_data_set; + +/** + * @author zane + * @date 2021/11/8 + */ +public class Bndbox { + public int xmin; + public int ymin; + public int xmax; + public int ymax; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/LabelInfo.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/LabelInfo.java new file mode 100644 index 000000000..0f5d2af40 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/LabelInfo.java @@ -0,0 +1,95 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_resource.image_data_set; + +import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +/** + * @author zane + * @date 2021/11/12 + */ +public class LabelInfo extends AbstractCheckModel { + @Check(name = "图片中标记的对象列表") + public List objects = new ArrayList<>(); + + public List labelList() { + List list = new ArrayList<>(); + if (objects == null || objects.isEmpty()) { + return list; + } + + list = objects + .stream() + .filter(x -> StringUtil.isNotEmpty(x.label)) + .map(x -> x.label) + .collect(Collectors.toList()); + + return list; + } + + /** + * 是否包含标注信息 + */ + public boolean isLabeled() { + if (objects == null || objects.isEmpty()) { + return false; + } + return objects.stream().anyMatch(x -> StringUtil.isNotEmpty(x.label)); + } + + public static class Item extends AbstractCheckModel { + + public String label; + /** + * 是否:难以识别的物体 + */ + public boolean difficult = false; + /** + * 是否:遮挡超过15-20% + */ + public boolean truncated = false; + public List points; + + public Item() { + } + + public Item(String label, int minX, int minY, int maxX, int maxY) { + this.label = label; + this.points = new ArrayList<>(); + this.points.add(new Point(minX, minY)); + this.points.add(new Point(maxX, maxY)); + } + } + + public static class Point extends AbstractCheckModel { + public int x; + public int y; + 
+ public Point() { + } + + public Point(int x, int y) { + this.x = x; + this.y = y; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Object.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Object.java new file mode 100644 index 000000000..6ec358720 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Object.java @@ -0,0 +1,40 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.vo.data_resource.image_data_set; + +import com.thoughtworks.xstream.annotations.XStreamAlias; + +/** + * @author zane + * @date 2021/11/8 + */ +@XStreamAlias("object") +public class Object { + public String name; + /** + * 关于目标物体姿态描述(非必须字段) + */ + public String pose = "Unspecified"; + /** + * 如果物体的遮挡超过15-20%并且位于边界框之外,请标记为truncated(非必须字段) + */ + public int truncated; + /** + * 难以识别的物体标记为difficult(非必须字段) + */ + public int difficult; + public Bndbox bndbox; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Size.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Size.java new file mode 100644 index 000000000..541615cd2 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Size.java @@ -0,0 +1,26 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.vo.data_resource.image_data_set; + +/** + * @author zane + * @date 2021/11/8 + */ +public class Size { + public int width; + public int height; + public int depth; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Source.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Source.java new file mode 100644 index 000000000..57bf61911 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_resource/image_data_set/Source.java @@ -0,0 +1,24 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_resource.image_data_set; + +/** + * @author zane + * @date 2021/11/8 + */ +public class Source { + public String database = "Unknown"; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Annotation.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Annotation.java new file mode 100644 index 000000000..ee29c9845 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Annotation.java @@ -0,0 +1,82 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_set.image_data_set; + +import com.thoughtworks.xstream.annotations.XStreamAlias; +import com.thoughtworks.xstream.annotations.XStreamImplicit; +import com.welab.wefe.common.util.StringUtil; + +import java.util.ArrayList; +import java.util.List; + +/** + * https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.3/docs/tutorials/PrepareDataSet.md + * + * @author zane + * @date 2021/11/8 + */ +@XStreamAlias("annotation") +public class Annotation { + /** + * folder的取值为:train、test + */ + public String folder; + public String filename; + public String path; + public Source source; + public Size size; + /** + * 暂时没用到,先使用默认值。 + */ + public int segmented = 0; + @XStreamImplicit + public List objectList; + + public LabelInfo toLabelInfo() { + LabelInfo labelInfo = new LabelInfo(); + + if (objectList != null) { + for (Object object : objectList) { + LabelInfo.Item item = new LabelInfo.Item( + object.name, + object.bndbox.xmin, + object.bndbox.ymin, + object.bndbox.xmax, + object.bndbox.ymax + ); + + labelInfo.objects.add(item); + } + + } + + return labelInfo; + } + + public List getLabelList() { + List list = new ArrayList<>(); + if (objectList == null) { + return list; + } + + for (Object object : objectList) { + if (StringUtil.isNotEmpty(object.name)) { + list.add(object.name); + } + } + return list; + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Bndbox.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Bndbox.java new file mode 100644 index 000000000..dcb5f92ec --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Bndbox.java @@ -0,0 +1,37 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_set.image_data_set; + +/** + * @author zane + * @date 2021/11/8 + */ +public class Bndbox { + public int xmin; + public int xmax; + public int ymin; + public int ymax; + + public Bndbox() { + } + + public Bndbox(int xmin, int xmax, int ymin, int ymax) { + this.xmin = xmin; + this.xmax = xmax; + this.ymin = ymin; + this.ymax = ymax; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/LabelInfo.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/LabelInfo.java new file mode 100644 index 000000000..043b2d02a --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/LabelInfo.java @@ -0,0 +1,106 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_set.image_data_set; + +import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; +import com.welab.wefe.common.fieldvalidate.annotation.Check; +import com.welab.wefe.common.util.StringUtil; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +/** + * @author zane + * @date 2021/11/12 + */ +public class LabelInfo extends AbstractCheckModel { + @Check(name = "图片中标记的对象列表") + public List objects = new ArrayList<>(); + + public List labelList() { + List list = new ArrayList<>(); + if (objects == null || objects.isEmpty()) { + return list; + } + + list = objects + .stream() + .filter(x -> StringUtil.isNotEmpty(x.label)) + .map(x -> x.label) + .collect(Collectors.toList()); + + return list; + } + + /** + * 是否包含标注信息 + */ + public boolean isLabeled() { + if (objects == null || objects.isEmpty()) { + return false; + } + return objects.stream().anyMatch(x -> StringUtil.isNotEmpty(x.label)); + } + + public static class Item extends AbstractCheckModel { + + public String label; + /** + * 是否:难以识别的物体 + */ + public boolean difficult = false; + /** + * 是否:遮挡超过15-20% + */ + public boolean truncated = false; + public List points; + + public Item() { + } + + public Item(String label, int minX, int minY, int maxX, int maxY) { + this.label = label; + this.points = new ArrayList<>(); + this.points.add(new Point(minX, minY)); + 
this.points.add(new Point(maxX, maxY)); + } + + public Object toLabelObject() { + Object object = new Object(); + LabelInfo.Point point1 = points.get(0); + LabelInfo.Point point2 = points.get(1); + object.bndbox = new Bndbox(point1.x, point1.y, point2.x, point2.y); + object.name = label; + object.difficult = difficult ? 1 : 0; + object.truncated = truncated ? 1 : 0; + return object; + } + } + + public static class Point extends AbstractCheckModel { + public int x; + public int y; + + public Point() { + } + + public Point(int x, int y) { + this.x = x; + this.y = y; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Object.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Object.java new file mode 100644 index 000000000..52920219f --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Object.java @@ -0,0 +1,40 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.vo.data_set.image_data_set; + +import com.thoughtworks.xstream.annotations.XStreamAlias; + +/** + * @author zane + * @date 2021/11/8 + */ +@XStreamAlias("object") +public class Object { + public String name; + /** + * 关于目标物体姿态描述(非必须字段) + */ + public String pose = "Unspecified"; + /** + * 如果物体的遮挡超过15-20%并且位于边界框之外,请标记为truncated(非必须字段) + */ + public int truncated; + /** + * 难以识别的物体标记为difficult(非必须字段) + */ + public int difficult; + public Bndbox bndbox; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Size.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Size.java new file mode 100644 index 000000000..9d65a5018 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Size.java @@ -0,0 +1,26 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.dto.vo.data_set.image_data_set; + +/** + * @author zane + * @date 2021/11/8 + */ +public class Size { + public int width; + public int height; + public int depth; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Source.java b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Source.java new file mode 100644 index 000000000..5082f41cc --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/dto/vo/data_set/image_data_set/Source.java @@ -0,0 +1,24 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.dto.vo.data_set.image_data_set; + +/** + * @author zane + * @date 2021/11/8 + */ +public class Source { + public String database = "Unknown"; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/exception/FlowNodeException.java b/board/board-service/src/main/java/com/welab/wefe/board/service/exception/FlowNodeException.java index 74f23dd3a..0c6218943 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/exception/FlowNodeException.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/exception/FlowNodeException.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/exception/MemberGatewayException.java b/board/board-service/src/main/java/com/welab/wefe/board/service/exception/MemberGatewayException.java index ed5744e7c..31bc6d3a0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/exception/MemberGatewayException.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/exception/MemberGatewayException.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/actuator/ClientActuator.java b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/actuator/ClientActuator.java new file mode 100644 index 000000000..1e273484b --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/actuator/ClientActuator.java @@ -0,0 +1,330 @@ +/* + * Copyright 2021 Tianmian Tech. 
All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.fusion.actuator; + + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.welab.wefe.board.service.api.project.fusion.actuator.psi.*; +import com.welab.wefe.board.service.dto.fusion.PsiMeta; +import com.welab.wefe.board.service.fusion.manager.ActuatorManager; +import com.welab.wefe.board.service.service.DataSetStorageService; +import com.welab.wefe.board.service.service.GatewayService; +import com.welab.wefe.board.service.service.fusion.FieldInfoService; +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.board.service.util.primarykey.FieldInfo; +import com.welab.wefe.board.service.util.primarykey.PrimaryKeyUtils; +import com.welab.wefe.common.data.storage.common.Constant; +import com.welab.wefe.common.data.storage.model.DataItemModel; +import com.welab.wefe.common.data.storage.model.PageInputModel; +import com.welab.wefe.common.data.storage.model.PageOutputModel; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.util.Base64Util; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.fusion.core.actuator.psi.AbstractPsiClientActuator; +import com.welab.wefe.fusion.core.dto.PsiActuatorMeta; +import com.welab.wefe.fusion.core.enums.FusionTaskStatus; +import 
com.welab.wefe.fusion.core.enums.PSIActuatorStatus; + +import java.util.List; +import java.util.concurrent.locks.ReentrantLock; + +/** + * @author hunter.zhao + */ +@SuppressWarnings("SynchronizeOnNonFinalField") +public class ClientActuator extends AbstractPsiClientActuator { + public List columnList; + + /** + * Fragment size, default 10000 + */ + public int shardSize = 10000; + public Integer currentIndex = 0; + public List fieldInfoList; + public String dstMemberId; + DataSetStorageService dataSetStorageService; + GatewayService gatewayService = Launcher.getBean(GatewayService.class); + + private String[] headers; + public Boolean serverIsReady = false; + private final ReentrantLock lock = new ReentrantLock(true); + + public ClientActuator(String businessId, String dataSetId, Boolean isTrace, String traceColumn, String dstMemberId, Long dataCount) { + super(businessId, dataSetId, isTrace, traceColumn, dataCount); + this.dstMemberId = dstMemberId; + } + + @Override + public void init() throws StatusCodeWithException { + FieldInfoService service = Launcher.getBean(FieldInfoService.class); + + columnList = service.columnList(businessId); + + + /** + * Calculate the fragment size based on the number of fields + */ + shardSize = shardSize / columnList.size(); + + /** + * Supplementary trace field + */ + if (isTrace) { + columnList.add(traceColumn); + } + + /** + * Find primary key composition fields + */ + fieldInfoList = service.fieldInfoList(businessId); + + /** + * Initialize dataset header + */ + dataSetStorageService = Launcher.CONTEXT.getBean(DataSetStorageService.class); + DataItemModel model = dataSetStorageService.getByKey( + Constant.DBName.WEFE_DATA, + dataSetStorageService.createRawDataSetTableName(dataSetId) + ".meta", + "header" + ); + headers = model.getV().toString().replace("\"", "").split(","); + } + + + @Override + public void close() throws Exception { + + //remove Actuator + ActuatorManager.remove(businessId); + + //update task status + 
FusionTaskService fusionTaskService = Launcher.CONTEXT.getBean(FusionTaskService.class); + switch (status) { + case success: + fusionTaskService.updateByBusinessId( + businessId, + FusionTaskStatus.Success, + dataCount, + fusionCount.longValue(), + processedCount.longValue(), + getSpend() + ); + break; + case falsify: + case running: + fusionTaskService.updateErrorByBusinessId( + businessId, + FusionTaskStatus.Interrupt, + dataCount, + fusionCount.longValue(), + processedCount.longValue(), + getSpend(), + error + ); + break; + default: + fusionTaskService.updateErrorByBusinessId( + businessId, + FusionTaskStatus.Failure, + dataCount, + fusionCount.longValue(), + processedCount.longValue(), + getSpend(), + error + ); + break; + } + } + + @Override + public void notifyServerClose() { + //notify the server that the task has ended + try { + gatewayService.callOtherMemberBoard( + dstMemberId, + ServerCloseApi.class, + new ServerCloseApi.Input(businessId, status.name(), error), + JSONObject.class); + } catch (Exception e) { + e.printStackTrace(); + } + } + + @Override + public List next() { + try { + lock.lock(); + long start = System.currentTimeMillis(); + + PageOutputModel model = dataSetStorageService.getListByPage( + Constant.DBName.WEFE_DATA, + dataSetStorageService.createRawDataSetTableName(dataSetId), + new PageInputModel(currentIndex, shardSize) + ); + + List list = model.getData(); + + List curList = Lists.newArrayList(); + list.forEach(x -> { + String[] values = x.getV().toString().split(","); + JObject jObject = JObject.create(); + for (int i = 0; i < headers.length; i++) { + if (columnList.contains(headers[i])) { + jObject.put(headers[i], values[i]); + } + } + curList.add(jObject); + }); + + + LOG.info("cursor {} spend: {} curList {} list {}", currentIndex, System.currentTimeMillis() - start, curList.size(), list.size()); + + currentIndex++; + + return curList; + + } finally { + lock.unlock(); + } + + } + + @Override + public void dump(List fruit) { + 
LOG.info("fruit insert ready..."); + + PsiDumpHelper.dump(businessId, columnList, fruit); + + LOG.info("fruit insert end..."); + } + + @Override + public Boolean hasNext() { + try { + lock.lock(); + PageOutputModel model = dataSetStorageService.getListByPage( + Constant.DBName.WEFE_DATA, + dataSetStorageService.createRawDataSetTableName(dataSetId), + new PageInputModel(currentIndex, shardSize) + ); + + LOG.info("currentIndex {} mode data size {}", currentIndex, model.getData().size()); + return model.getData().size() > 0; + } finally { + lock.unlock(); + } + + } + + @Override + public Integer sliceNumber() { + return dataCount.intValue() % shardSize == 0 ? dataCount.intValue() / shardSize + : dataCount.intValue() / shardSize + 1; + } + + @Override + public PsiActuatorMeta downloadBloomFilter() throws StatusCodeWithException { + + LOG.info("downloadBloomFilter start"); + + while (true) { + if (serverIsReady) { + break; + } + + try { + JSONObject result = gatewayService.callOtherMemberBoard( + dstMemberId, + ServerSynStatusApi.class, + new ServerSynStatusApi.Input(businessId), + JSONObject.class + ); + serverIsReady = result.getBoolean("ready"); + } catch (Exception e) { + LOG.error("请求合作方失败!错误原因: {}", e.getMessage()); + status = PSIActuatorStatus.exception; + } + } + + //调用gateway + JSONObject result = gatewayService.callOtherMemberBoard( + dstMemberId, + DownloadBFApi.class, + new DownloadBFApi.Input(businessId), + JSONObject.class + ); + + LOG.info("downloadBloomFilter end {} ", result); + + PsiActuatorMeta meta = JObject.toJavaObject(result, PsiActuatorMeta.class); + meta.setBfByDto(meta.getBfDto()); + return meta; + } + + @Override + public byte[][] queryFusionData(byte[][] bs) throws StatusCodeWithException { + + LOG.info("queryFusionData start"); + + //调用gateway + List stringList = Lists.newArrayList(); + for (int i = 0; i < bs.length; i++) { + stringList.add(Base64Util.encode(bs[i])); + } + + PsiMeta result = gatewayService.callOtherMemberBoard(dstMemberId, + 
PsiCryptoApi.class, + new PsiCryptoApi.Input(businessId, stringList), + PsiMeta.class + ); + + + List list = result.getBs(); + + byte[][] ss = new byte[list.size()][]; + for (int i = 0; i < list.size(); i++) { + ss[i] = Base64Util.base64ToByteArray(list.get(i)); + } + return ss; + } + + @Override + public void sendFusionData(List rs) { + List stringList = Lists.newArrayList(); + for (int i = 0; i < rs.size(); i++) { + stringList.add(Base64Util.encode(rs.get(i))); + } + + try { + gatewayService.callOtherMemberBoard( + dstMemberId, + ReceiveResultApi.class, + new ReceiveResultApi.Input(businessId, stringList) + ); + } catch (Exception e) { + LOG.info("sendFusionData error: ", e); + e.printStackTrace(); + } + } + + + @Override + public String hashValue(JObject value) { + return PrimaryKeyUtils.create(value, fieldInfoList); + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/actuator/PsiDumpHelper.java b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/actuator/PsiDumpHelper.java new file mode 100644 index 000000000..a064b010a --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/actuator/PsiDumpHelper.java @@ -0,0 +1,73 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.fusion.actuator; + + + +import com.alibaba.fastjson.JSON; +import com.google.common.collect.Lists; +import com.welab.wefe.board.service.service.fusion.FusionResultStorageService; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.Launcher; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * @author hunter.zhao + */ +public class PsiDumpHelper { + + private static final FusionResultStorageService fusionResultStorageService; + + static { + fusionResultStorageService = Launcher.CONTEXT.getBean(FusionResultStorageService.class); + } + + private static void dumpHeaders(String businessId, List headers) { + //saveHeaderRow + if (fusionResultStorageService.isExists(fusionResultStorageService.createRawDataSetHeaderTableName(businessId))) { + return; + } + + fusionResultStorageService.saveHeaderRow(businessId, headers); + } + + public static void dump(String businessId, List headers, List fruit) { + + if (fruit.isEmpty()) { + return; + } + + dumpHeaders(businessId, headers); + + /** + * Fruit Standard formatting + */ + List> fruits = fruit. + stream(). + map(x -> { + List obj = Lists.newArrayList(); + for (Map.Entry column : x.entrySet()) { + obj.add(column.getValue()); + } + return obj; + }).collect(Collectors.toList()); + + fusionResultStorageService.saveDataRows(businessId, fruits); + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/actuator/psi/ServerActuator.java b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/actuator/psi/ServerActuator.java new file mode 100644 index 000000000..e652fcea6 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/actuator/psi/ServerActuator.java @@ -0,0 +1,105 @@ +package com.welab.wefe.board.service.fusion.actuator.psi; + +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +import com.alibaba.fastjson.JSON; +import com.google.common.collect.Lists; +import com.welab.wefe.board.service.fusion.actuator.PsiDumpHelper; +import com.welab.wefe.board.service.fusion.manager.ActuatorManager; +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.fusion.core.actuator.psi.AbstractPsiServerActuator; +import com.welab.wefe.fusion.core.enums.FusionTaskStatus; +import com.welab.wefe.fusion.core.utils.bf.BloomFilters; + +import java.math.BigInteger; +import java.util.List; + +/** + * @author hunter.zhao + */ +public class ServerActuator extends AbstractPsiServerActuator { + public ServerActuator(String businessId, BloomFilters bloomFilters, BigInteger n, BigInteger e, BigInteger d, BigInteger p, BigInteger q, Long dataCount) { + super(businessId, bloomFilters, n, e, d, p, q, dataCount); + } + + @Override + public void dump(List fruit) { + LOG.info("fruit insert ready..."); + + List headers = Lists.newArrayList(); + if (fruit.isEmpty()) { + return; + } + + for (String header : fruit.get(0).keySet()) { + headers.add(header); + } + + PsiDumpHelper.dump(businessId, headers, fruit); + + LOG.info("fruit insert end..."); + + System.out.println("测试结果:" + JSON.toJSONString(fruit)); + } + + @Override + public void close() throws Exception { + //remove Actuator + 
ActuatorManager.remove(businessId); + + //update task status + FusionTaskService fusionTaskService = Launcher.CONTEXT.getBean(FusionTaskService.class); + switch (status) { + case success: + fusionTaskService.updateByBusinessId( + businessId, + FusionTaskStatus.Success, + dataCount, + fusionCount.longValue(), + processedCount.longValue(), + getSpend() + ); + break; + case falsify: + case running: + fusionTaskService.updateErrorByBusinessId( + businessId, + FusionTaskStatus.Interrupt, + dataCount, + fusionCount.longValue(), + processedCount.longValue(), + getSpend(), + error + ); + break; + default: + fusionTaskService.updateErrorByBusinessId( + businessId, + FusionTaskStatus.Failure, + dataCount, + fusionCount.longValue(), + processedCount.longValue(), + getSpend(), + error + ); + break; + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/enums/ExportStatus.java b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/enums/ExportStatus.java new file mode 100644 index 000000000..773f46153 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/enums/ExportStatus.java @@ -0,0 +1,38 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.fusion.enums; + + +/** + * @author hunter.zhao + */ +public enum ExportStatus { + + /** + * 导出失败 + */ + failure, + + /** + * 导出成功 + */ + success, + + /** + * 正在导出 + */ + exporting; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/manager/ActuatorManager.java b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/manager/ActuatorManager.java new file mode 100644 index 000000000..bc96b9d4f --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/manager/ActuatorManager.java @@ -0,0 +1,184 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.fusion.manager; + +import com.google.common.collect.Lists; +import com.welab.wefe.board.service.database.entity.fusion.FusionActuatorInfoMySqlModel; +import com.welab.wefe.board.service.database.entity.fusion.FusionTaskMySqlModel; +import com.welab.wefe.board.service.database.repository.fusion.FusionActuatorInfoRepository; +import com.welab.wefe.board.service.fusion.actuator.ClientActuator; +import com.welab.wefe.board.service.fusion.actuator.psi.ServerActuator; +import com.welab.wefe.board.service.service.fusion.FusionTaskService; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.fusion.core.actuator.AbstractActuator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +/** + * @author hunter + */ +public class ActuatorManager { + public static final Logger LOG = LoggerFactory.getLogger(ActuatorManager.class); + + /** + * taskId : task + */ + private static final ConcurrentHashMap ACTUATORS = new ConcurrentHashMap<>(); + + public static AbstractActuator get(String businessId) { + + + return ACTUATORS.get(businessId); + } + + private static final FusionActuatorInfoRepository fusionActuatorInfoRepository; + private static final FusionTaskService fusionTaskService; + + static { + fusionActuatorInfoRepository = Launcher.CONTEXT.getBean(FusionActuatorInfoRepository.class); + fusionTaskService = Launcher.CONTEXT.getBean(FusionTaskService.class); + } + + public static void set(AbstractActuator task) { + + String businessId = task.getBusinessId(); + if (ACTUATORS.containsKey(businessId)) { + throw new RuntimeException(businessId + " This actuator already exists"); + } + + LOG.info("Set actuator successfully, businessId is {}", businessId); + ACTUATORS.put(businessId, task); + + } + + public synchronized static void 
remove(String businessId) { + + ACTUATORS.remove(businessId); + } + + public synchronized static List dashboard() throws StatusCodeWithException { + + List list = Lists.newArrayList(); + Object[] keys = ACTUATORS.keySet().toArray(); + + for (Object key : keys) { + + JObject info = getTaskInfo(key.toString()); + + if (info != null) { + list.add(info); + } + } + + return list; + } + + public static JObject getTaskInfo(String businessId) throws StatusCodeWithException { + AbstractActuator actuator = ACTUATORS.get(businessId); + if (actuator != null) { + return JObject + .create() + .append("business_id", businessId) + .append("fusion_count", actuator.getFusionCount()) + .append("processed_count", actuator.getProcessedCount()) + .append("data_count", actuator.getDataCount()) + .append("spend", actuator.getSpend()) + .append("status", "Running") + .append("estimated_spend", actuator.getEstimatedSpend()) + .append("progress", actuator.progress()); + } + + FusionTaskMySqlModel model = fusionTaskService.findByBusinessId(businessId); + if (model != null) { + return JObject + .create() + .append("business_id", businessId) + .append("fusion_count", model.getFusionCount()) + .append("processed_count", model.getProcessedCount()) + .append("data_count", model.getDataCount()) + .append("spend", model.getSpend()) + .append("status", model.getStatus()) + .append("progress", + Double.valueOf( + model.getProcessedCount().doubleValue() / model.getDataCount() * 100 + ).intValue() + ); + } + + return null; + } + + /** + * Number of tasks + * + * @return Number of tasks + */ + public static int size() { + return ACTUATORS.size(); + } + + public static void refresh(AbstractActuator actuator) { + if (actuator instanceof ClientActuator) { + FusionActuatorInfoMySqlModel info = new FusionActuatorInfoMySqlModel(); + info.setType(actuator.getClass().getSimpleName()); + info.setBusinessId(actuator.getBusinessId()); + info.setProgress(((ClientActuator) actuator).currentIndex); + 
fusionActuatorInfoRepository.save(info); + } else if (actuator instanceof ServerActuator) { + FusionActuatorInfoMySqlModel info = new FusionActuatorInfoMySqlModel(); + info.setType(actuator.getClass().getSimpleName()); + info.setBusinessId(actuator.getBusinessId()); + fusionActuatorInfoRepository.save(info); + } + } + + public static void main(String[] args) { + AbstractActuator actuator = new AbstractActuator("1") { + @Override + public void close() throws Exception { + + } + + @Override + public boolean isFinish() { + return false; + } + + @Override + public void init() throws StatusCodeWithException { + + } + + @Override + public void fusion() throws StatusCodeWithException { + + } + + @Override + public void dump(List fruit) { + + } + }; + + + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/manager/ExportManager.java b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/manager/ExportManager.java new file mode 100644 index 000000000..1f4c4c1a7 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/fusion/manager/ExportManager.java @@ -0,0 +1,81 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.fusion.manager; + + + +import com.welab.wefe.board.service.database.entity.fusion.ExportProgressMySqlModel; +import com.welab.wefe.board.service.dto.fusion.FusionResultExportProgress; +import com.welab.wefe.board.service.service.fusion.ExportProgressService; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.web.util.ModelMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.ConcurrentHashMap; + +/** + * @author hunter.zhao + */ +public class ExportManager { + public static final Logger LOG = LoggerFactory.getLogger(ActuatorManager.class); + + /** + * businessId : JSON + */ + private static final ConcurrentHashMap EXPORT_TASK = new ConcurrentHashMap<>(); + + private static final ExportProgressService exportProgressService; + + static { + exportProgressService = Launcher.CONTEXT.getBean(ExportProgressService.class); + } + + public static FusionResultExportProgress get(String businessId) { + + if (EXPORT_TASK.get(businessId) == null) { + //直接查表 + ExportProgressMySqlModel model = exportProgressService.findLastByBusinessId(businessId); + return ModelMapper.map(model, FusionResultExportProgress.class); + } + + FusionResultExportProgress progress = EXPORT_TASK.get(businessId); + if (progress.getProgress() == 100) { + //remove; + romove(progress); + } + + return progress; + +// return EXPORT_TASK.get(businessId); + } + + public static void set(String businessId, FusionResultExportProgress progress) { + + if (EXPORT_TASK.containsKey(businessId)) { + throw new RuntimeException(" There are fusion tasks being exported"); + } + + LOG.info("Set FusionResultExportProgress successfully, businessId is {}", businessId); + EXPORT_TASK.put(businessId, progress); + } + + public static void romove(FusionResultExportProgress progress) { + exportProgressService.add(progress); + + EXPORT_TASK.remove(progress.getBusinessId()); + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/listener/AppListener.java b/board/board-service/src/main/java/com/welab/wefe/board/service/listener/AppListener.java index bbd0ec973..fcd477771 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/listener/AppListener.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/listener/AppListener.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/listener/ApplicationReadyListener.java b/board/board-service/src/main/java/com/welab/wefe/board/service/listener/ApplicationReadyListener.java index c78db3eed..754b08fa6 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/listener/ApplicationReadyListener.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/listener/ApplicationReadyListener.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,9 +16,11 @@ package com.welab.wefe.board.service.listener; +import com.welab.wefe.board.service.dto.globalconfig.GatewayConfigModel; import com.welab.wefe.board.service.service.GatewayService; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; import com.welab.wefe.common.util.HostUtil; +import com.welab.wefe.common.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -48,20 +50,26 @@ public void onApplicationEvent(ApplicationReadyEvent event) { } private void appendIpAddressToGatewayWhiteList() { + GatewayConfigModel gatewayConfig = globalConfigService.getGatewayConfig(); + if (gatewayConfig == null || StringUtil.isEmpty(gatewayConfig.intranetBaseUri)) { + LOG.error("gateway 内网地址尚未配置,board-service IP未登记到白名单。"); + return; + } + try { // Intranet IP String localIP = HostUtil.getLocalIp(); globalConfigService.appendIpToWhiteList( localIP, - "board 内网IP地址,由 board 自主上报。", + "board 内网IP地址,由 board 自主登记。", true ); - LOG.info("上报IP地址完成."); + LOG.info("登记IP到白名单完成."); // Notify the gateway to update the IP whitelist cache - gatewayService.refreshIpWhiteListCache(); + //gatewayService.refreshIpWhiteListCache(); } catch (Exception e) { - LOG.error("IP地址上报异常:", e); + LOG.error("IP地址登记到白名单异常:", e); } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/listener/ApplicationStartedListener.java b/board/board-service/src/main/java/com/welab/wefe/board/service/listener/ApplicationStartedListener.java index 6b6462f59..ab117c30d 100644 --- 
a/board/board-service/src/main/java/com/welab/wefe/board/service/listener/ApplicationStartedListener.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/listener/ApplicationStartedListener.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/listener/AutoEncryptPhoneNumberListener.java b/board/board-service/src/main/java/com/welab/wefe/board/service/listener/AutoEncryptPhoneNumberListener.java new file mode 100644 index 000000000..3b4904843 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/listener/AutoEncryptPhoneNumberListener.java @@ -0,0 +1,68 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.listener; + +import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.board.service.service.EncryptPhoneNumberService; +import com.welab.wefe.common.util.CommentedProperties; +import com.welab.wefe.common.util.StringUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.context.event.ApplicationStartedEvent; +import org.springframework.context.ApplicationListener; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.stereotype.Component; + +/** + * Auto encrypt mobile phone number + */ +@Component +public class AutoEncryptPhoneNumberListener implements ApplicationListener { + private static final Logger LOG = LoggerFactory.getLogger(AutoEncryptPhoneNumberListener.class); + + @Autowired + private ConfigurableEnvironment configurableEnvironment; + @Autowired + private EncryptPhoneNumberService encryptPhoneNumberService; + + @Autowired + private Config config; + + @Override + public void onApplicationEvent(ApplicationStartedEvent applicationStartedEvent) { + String configPath = configurableEnvironment.getProperty("config.path"); + if (StringUtil.isEmpty(configPath) || !config.isEncryptPhoneNumberOpen()) { + return; + } + String key = "has.auto.encrypt.phone.number"; + try { + CommentedProperties properties = new CommentedProperties(); + properties.load(configPath); + if (properties.containsKey(key)) { + return; + } + LOG.info("Start auto encrypt phone number........"); + encryptPhoneNumberService.encrypt(); + properties.append(key, "true", "Whether the mobile phone number has been automatically encrypted. 
The presence of this field indicates that it has been encrypted"); + properties.store(configPath); + LOG.info("End auto encrypt phone number!!!"); + } catch (Exception e) { + LOG.error("Auto encrypt phone number exception: ", e); + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/model/BaseFlowGraph.java b/board/board-service/src/main/java/com/welab/wefe/board/service/model/BaseFlowGraph.java index de1bbbf91..366ceb0ae 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/model/BaseFlowGraph.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/model/BaseFlowGraph.java @@ -1,4 +1,4 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,11 +22,11 @@ import com.welab.wefe.board.service.database.entity.job.JobMySqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectFlowNodeMySqlModel; import com.welab.wefe.board.service.exception.FlowNodeException; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.FederatedLearningType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; import org.apache.commons.collections4.CollectionUtils; import java.util.ArrayList; @@ -61,12 +61,15 @@ public abstract class BaseFlowGraph { * Nodes that will be executed in the current graph */ protected List jobSteps = new ArrayList<>(); + + protected String creatorMemberId; - public BaseFlowGraph(JobMySqlModel job, JobMySqlModel lastJob, List members, List mysqlNodes) throws StatusCodeWithException { + public BaseFlowGraph(JobMySqlModel job, JobMySqlModel lastJob, List members, List mysqlNodes, String creatorMemberId) throws StatusCodeWithException { 
this(job.getFederatedLearningType(), lastJob, mysqlNodes); this.job = job; this.members = members; + this.creatorMemberId = creatorMemberId; } @@ -383,5 +386,12 @@ public FederatedLearningType getFederatedLearningType() { return federatedLearningType; } + public String getCreatorMemberId() { + return creatorMemberId; + } + + public void setCreatorMemberId(String creatorMemberId) { + this.creatorMemberId = creatorMemberId; + } //endregion } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/model/FlowGraph.java b/board/board-service/src/main/java/com/welab/wefe/board/service/model/FlowGraph.java index f56ae3caa..b3d0703e2 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/model/FlowGraph.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/model/FlowGraph.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -24,9 +24,9 @@ import com.welab.wefe.board.service.database.entity.job.ProjectFlowNodeMySqlModel; import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.FederatedLearningType; import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; import org.apache.commons.collections4.CollectionUtils; import java.util.ArrayList; @@ -39,8 +39,8 @@ public class FlowGraph extends BaseFlowGraph { - public FlowGraph(JobMySqlModel job, JobMySqlModel lastJob, List members, List mysqlNodes) throws StatusCodeWithException { - super(job, lastJob, members, mysqlNodes); + public FlowGraph(JobMySqlModel job, JobMySqlModel lastJob, List members, List mysqlNodes, String creatorMemberId) throws StatusCodeWithException { + super(job, lastJob, members, mysqlNodes, creatorMemberId); } public FlowGraph(FederatedLearningType federatedLearningType, List mysqlNodes) throws StatusCodeWithException { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/model/FlowGraphNode.java b/board/board-service/src/main/java/com/welab/wefe/board/service/model/FlowGraphNode.java index 0ab4dae93..c186c0a11 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/model/FlowGraphNode.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/model/FlowGraphNode.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,10 +20,10 @@ import com.welab.wefe.board.service.component.base.AbstractComponent; import com.welab.wefe.board.service.database.entity.job.JobMySqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectFlowNodeMySqlModel; -import com.welab.wefe.board.service.exception.FlowNodeException; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.fieldvalidate.AbstractCheckModel; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.apache.commons.lang3.StringUtils; import java.util.ArrayList; @@ -155,8 +155,8 @@ public AbstractCheckModel getParamsModel() { try { paramsModel = Components .get(super.getComponentType()) - .deserializationParam(this, super.getParams()); - } catch (FlowNodeException e) { + .deserializationParam(super.getParams()); + } catch (StatusCodeWithException e) { return null; } return paramsModel; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/onlinedemo/OnlineDemoBranchStrategy.java b/board/board-service/src/main/java/com/welab/wefe/board/service/onlinedemo/OnlineDemoBranchStrategy.java index 704ce5e15..1ed780c7d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/onlinedemo/OnlineDemoBranchStrategy.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/onlinedemo/OnlineDemoBranchStrategy.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -33,7 +33,7 @@ public class OnlineDemoBranchStrategy { public static void hackOnDelete(AbstractApiInput input, AbstractBaseMySqlModel model, String message) throws StatusCodeWithException { - Config config = Launcher.CONTEXT.getBean(Config.class); + Config config = Launcher.getBean(Config.class); if (!config.isOnlineDemo()) { return; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/onlinedemo/TianmiantechService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/onlinedemo/TianmiantechService.java index 7da6dce14..13d56cb4a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/onlinedemo/TianmiantechService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/onlinedemo/TianmiantechService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/operation/BoardApiLogger.java b/board/board-service/src/main/java/com/welab/wefe/board/service/operation/BoardApiLogger.java new file mode 100644 index 000000000..dba52de69 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/operation/BoardApiLogger.java @@ -0,0 +1,84 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.operation; + +import com.welab.wefe.board.service.api.account.CaptchaApi; +import com.welab.wefe.board.service.api.data_resource.upload_task.DataResourceUploadTaskDetailApi; +import com.welab.wefe.board.service.api.data_resource.upload_task.DataResourceUploadTaskQueryApi; +import com.welab.wefe.board.service.api.file.FileUploadApi; +import com.welab.wefe.board.service.api.file.MergeApi; +import com.welab.wefe.board.service.api.member.MemberAvailableCheckApi; +import com.welab.wefe.board.service.api.project.flow.FlowQueryApi; +import com.welab.wefe.board.service.api.project.job.GetJobProgressApi; +import com.welab.wefe.board.service.api.project.job.task.TaskProgressDetailApi; +import com.welab.wefe.board.service.api.project.member.audit.ProjectMemberAuditListApi; +import com.welab.wefe.board.service.api.service.ServiceAvailableApi; +import com.welab.wefe.board.service.database.entity.OperationLogMysqlModel; +import com.welab.wefe.board.service.database.repository.AccountRepository; +import com.welab.wefe.board.service.database.repository.OperationLogRepository; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.web.api.base.AbstractApi; +import com.welab.wefe.common.web.delegate.api_log.AbstractApiLogger; +import com.welab.wefe.common.web.delegate.api_log.ApiLog; +import org.springframework.stereotype.Component; + +import java.util.Arrays; +import java.util.List; + +/** + * @author zane + **/ +@Component +public class BoardApiLogger extends AbstractApiLogger { + + @Override + protected List> getIgnoreLogApiList() { + return Arrays.asList( + CaptchaApi.class, + FlowQueryApi.class, + ProjectMemberAuditListApi.class, + GetJobProgressApi.class, + ServiceAvailableApi.class, + MemberAvailableCheckApi.class, + TaskProgressDetailApi.class, + DataResourceUploadTaskQueryApi.class, + DataResourceUploadTaskDetailApi.class, + FileUploadApi.class, + MergeApi.class + ); + } + + @Override + protected void 
save(ApiLog apiLog) throws Exception { + OperationLogMysqlModel model = new OperationLogMysqlModel(); + model.setRequestTime(apiLog.getRequestTime()); + model.setRequestIp(apiLog.getCallerIp()); + model.setOperatorId(apiLog.getCallerId()); + model.setSpend(apiLog.getSpend()); + model.setLogInterface(apiLog.getApiName()); + model.setResultCode(apiLog.getResponseCode()); + model.setResultMessage(apiLog.getResponseMessage()); + + Launcher.getBean(OperationLogRepository.class).save(model); + } + + @Override + protected void updateAccountLastActionTime(String accountId) throws Exception { + Launcher.getBean(AccountRepository.class).updateLastActionTime(accountId); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/operation/OperationLogAfterApiExecute.java b/board/board-service/src/main/java/com/welab/wefe/board/service/operation/OperationLogAfterApiExecute.java deleted file mode 100644 index 1297743bb..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/operation/OperationLogAfterApiExecute.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.operation; - -import com.alibaba.fastjson.JSONObject; -import com.welab.wefe.board.service.database.entity.OperationLogMysqlModel; -import com.welab.wefe.board.service.database.repository.OperationLogRepository; -import com.welab.wefe.common.CommonThreadPool; -import com.welab.wefe.common.util.StringUtil; -import com.welab.wefe.common.web.CurrentAccount; -import com.welab.wefe.common.web.CurrentAccount.Info; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.api.base.Api; -import com.welab.wefe.common.web.dto.ApiResult; -import com.welab.wefe.common.web.function.AfterApiExecuteFunction; -import com.welab.wefe.common.web.util.HttpServletRequestUtil; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Component; - -import javax.servlet.http.HttpServletRequest; -import java.util.Arrays; -import java.util.Date; - -/** - * @author eval - **/ -@Component -public class OperationLogAfterApiExecute implements AfterApiExecuteFunction { - - @Autowired - OperationLogRepository mOperationLogRepository; - - @Override - public void action(HttpServletRequest httpServletRequest, long start, AbstractApi api, JSONObject params, ApiResult result) { - final Info info = CurrentAccount.get(); - CommonThreadPool.run( - () -> log(httpServletRequest, start, api, result, info) - ); - } - - private static final String[] IGNORE_LOG_APIS = { - "project/flow/query", - "project/member/add/audit/list", - "flow/job/get_progress", - "member/service_status_check", - "task/progress/detail", - "data_set_task/query", - "data_set_task/detail", - "file/upload" - }; - - /** - * Check whether the request is ignored - */ - private boolean ignore(HttpServletRequest httpServletRequest, Api annotation) { - // Automatically refresh from the front end without writing logs. 
- if (httpServletRequest.getQueryString() != null) { - String value = httpServletRequest.getParameter("request-from-refresh"); - if (StringUtil.isNotEmpty(value) && "true".equals(value)) { - return true; - } - } - - // Blacklist, do not write logs. - String api = StringUtil.trim(annotation.path().toLowerCase(), '/', ' '); - return Arrays.asList(IGNORE_LOG_APIS).contains(api); - } - - private void log(HttpServletRequest httpServletRequest, long start, AbstractApi api, ApiResult result, Info info) { - if (info == null) { - return; - } - Api annotation = api.getClass().getAnnotation(Api.class); - - if (ignore(httpServletRequest, annotation)) { - return; - } - - String token = httpServletRequest.getHeader("token"); - String ip = HttpServletRequestUtil.getClientIp(httpServletRequest); - - OperationLogMysqlModel model = new OperationLogMysqlModel(); - model.setRequestTime(new Date(start)); - model.setToken(token); - model.setRequestIp(ip); - model.setOperatorId(info.getId()); - model.setOperatorPhone(info.getPhoneNumber()); - model.setSpend(result.spend); - - - String path = annotation.path(); - String name = annotation.name(); - model.setLogInterface(path); - if (StringUtil.isEmpty(name)) { - model.setInterfaceName(path); - } else { - model.setInterfaceName(name); - } - String action = path; - if (path.lastIndexOf("/") >= 0) { - action = path.substring(path.lastIndexOf("/") + 1); - } - model.setLogAction(action); - model.setResultCode(result.code); - model.setResultMessage(result.message); - - mOperationLogRepository.save(model); - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/operation/OperationLogBeforeApiExecute.java b/board/board-service/src/main/java/com/welab/wefe/board/service/operation/OperationLogBeforeApiExecute.java deleted file mode 100644 index b781c1da3..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/operation/OperationLogBeforeApiExecute.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * 
Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.operation; - -import com.alibaba.fastjson.JSONObject; -import com.welab.wefe.common.web.api.base.AbstractApi; -import com.welab.wefe.common.web.function.BeforeApiExecuteFunction; - -/** - * User operation log - * - * @author eval - **/ -public class OperationLogBeforeApiExecute implements BeforeApiExecuteFunction { - - @Override - public void action(AbstractApi api, JSONObject params) { - - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/scheduled/AccountScheduledService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/scheduled/AccountScheduledService.java new file mode 100644 index 000000000..1aa4a021b --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/scheduled/AccountScheduledService.java @@ -0,0 +1,52 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.scheduled; + +import com.welab.wefe.board.service.database.repository.AccountRepository; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Lazy; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Component; + +/** + * 对 account 表的定时任务 + * + * @author zane + * @date 2022/03/16 + */ +@Component +@Lazy(false) +public class AccountScheduledService { + protected final Logger LOG = LoggerFactory.getLogger(this.getClass()); + + @Autowired + private AccountRepository accountRepository; + + @Scheduled(fixedDelay = 600_000, initialDelay = 10_000) + //@Scheduled(fixedDelay = 5_000, initialDelay = 1_000) + public void run() { + + LOG.info("begin disableAccountWithoutAction90Days..."); + int count = accountRepository.disableAccountWithoutAction90Days(); + LOG.info("end disableAccountWithoutAction90Days:" + count); + + LOG.info("begin cancelAccountWithoutAction180Days..."); + count = accountRepository.cancelAccountWithoutAction180Days(); + LOG.info("end cancelAccountWithoutAction180Days:" + count); + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/scheduled/OnlineDemoScheduledService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/scheduled/OnlineDemoScheduledService.java index b539b772f..edbf01935 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/scheduled/OnlineDemoScheduledService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/scheduled/OnlineDemoScheduledService.java @@ -18,7 +18,8 @@ import com.welab.wefe.board.service.constant.Config; import com.welab.wefe.board.service.database.entity.OperationLogMysqlModel; import 
com.welab.wefe.board.service.database.entity.base.AbstractMySqlModel; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.ImageDataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; import com.welab.wefe.board.service.database.entity.job.JobMySqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectDataSetMySqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectFlowMySqlModel; @@ -135,11 +136,21 @@ public void clean() { ); /** - * 清理 data_set + * 清理 table_data_set * 1. 无项目引用的删掉 */ delete( - DataSetMysqlModel.class, + TableDataSetMysqlModel.class, + commonWhere + + "and id not in (select data_set_id from project_data_set)" + ); + + /** + * 清理 image_data_set + * 1. 无项目引用的删掉 + */ + delete( + ImageDataSetMysqlModel.class, commonWhere + "and id not in (select data_set_id from project_data_set)" ); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/AbstractUnionService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/AbstractUnionService.java new file mode 100644 index 000000000..55ee57d25 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/AbstractUnionService.java @@ -0,0 +1,381 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.sdk; + + +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONException; +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.api.union.MemberListApi; +import com.welab.wefe.board.service.api.union.member_auth.MemberRealnameAuthApi; +import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.board.service.dto.globalconfig.MemberInfoModel; +import com.welab.wefe.board.service.service.AbstractService; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.constant.SecretKeyType; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.http.HttpContentType; +import com.welab.wefe.common.http.HttpRequest; +import com.welab.wefe.common.http.HttpResponse; +import com.welab.wefe.common.util.*; +import net.jodah.expiringmap.ExpiringMap; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.mime.content.InputStreamBody; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MultipartFile; + +import java.io.IOException; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.TimeUnit; + + +/** + * @author Zane + */ +public abstract class AbstractUnionService extends AbstractService { + + /** + * cache + */ + protected static final ExpiringMap CACHE_MAP = ExpiringMap + .builder() + .expiration(60, TimeUnit.SECONDS) + .maxSize(500) + .build(); + + @Autowired + protected Config config; + + @Autowired + protected GlobalConfigService globalConfigService; + + /** + * initialize wefe system + */ + public void initializeSystem(MemberInfoModel model) throws StatusCodeWithException { + JObject params = JObject + .create() + .put("id", model.getMemberId()) 
+ .put("member_id", model.getMemberId()) + .put("name", model.getMemberName()) + .put("mobile", model.getMemberMobile()) + .put("allow_open_data_set", model.getMemberAllowPublicDataSet()) + .put("public_key", model.getRsaPublicKey()) + .put("email", model.getMemberEmail()) + .put("gateway_uri", model.getMemberGatewayUri()) + .put("logo", model.getMemberLogo()) + .put("hidden", model.getMemberHidden()) + .put("secret_key_type", null == model.getSecretKeyType() ? SecretKeyType.rsa.name() : model.getSecretKeyType().name()); + + request("member/add", params, false); + } + + /** + * Report member information + */ + public void uploadMemberInfo(MemberInfoModel model) throws StatusCodeWithException { + JObject params = JObject + .create() + .put("id", model.getMemberId()) + .put("name", model.getMemberName()) + .put("mobile", model.getMemberMobile()) + .put("allow_open_data_set", model.getMemberAllowPublicDataSet()) + .put("public_key", model.getRsaPublicKey()) + .put("email", model.getMemberEmail()) + .put("gateway_uri", model.getMemberGatewayUri()) + .put("logo", model.getMemberLogo()) + .put("hidden", model.getMemberHidden()); + + request("member/update", params); + } + + + /** + * Reset key + */ + public void resetPublicKey(MemberInfoModel model) throws StatusCodeWithException { + JObject params = JObject + .create() + .put("id", model.getMemberId()) + .put("public_key", model.getRsaPublicKey()); + + request("member/update_public_key", params); + } + + /** + * Update member information (not including logo) + */ + public void uploadMemberInfoExcludeLogo(MemberInfoModel model) throws StatusCodeWithException { + JObject params = JObject + .create() + .put("id", model.getMemberId()) + .put("name", model.getMemberName()) + .put("mobile", model.getMemberMobile()) + .put("allow_open_data_set", model.getMemberAllowPublicDataSet()) + .put("public_key", model.getRsaPublicKey()) + .put("email", model.getMemberEmail()) + .put("gateway_uri", model.getMemberGatewayUri()) + 
.put("hidden", model.getMemberHidden()) + .put("secret_key_type", null == model.getSecretKeyType() ? SecretKeyType.rsa.name() : model.getSecretKeyType().name()); + + request("member/update_exclude_logo", params); + } + + /** + * Update member information logo + */ + public void updateMemberLogo(MemberInfoModel model) throws StatusCodeWithException { + JObject params = JObject + .create() + .put("id", model.getMemberId()) + .put("logo", model.getMemberLogo()); + + request("member/update_logo", params); + } + + /** + * Pagination query member + */ + public synchronized JSONObject queryMembers(MemberListApi.Input input) throws StatusCodeWithException { + + String key = "queryMembers" + JSON.toJSONString(input); + if (CACHE_MAP.containsKey(key)) { + return (JSONObject) CACHE_MAP.get(key); + } + + JObject params = JObject + .create() + .put("page_index", input.getPageIndex()) + .put("page_size", input.getPageSize()) + .put("name", input.getName()) + .put("id", input.getId()); + + JSONObject response = request("member/query", params); + CACHE_MAP.put(key, response); + return response; + } + + public JSONObject queryMemberById(String id) throws StatusCodeWithException { + return queryMember(id, ""); + } + + public JSONObject queryMember(String id, String name) throws StatusCodeWithException { + return queryMemberByPage(0, 0, id, name); + } + + public JSONObject queryMember(int pageIndex, int pageSize) throws StatusCodeWithException { + return queryMemberByPage(pageIndex, pageSize, "", ""); + } + + public JSONObject queryMemberByPage(int pageIndex, int pageSize, String id, String name) throws StatusCodeWithException { + JObject params = JObject.create() + .put("page_index", pageIndex) + .put("page_size", pageSize); + + if (StringUtil.isNotEmpty(id)) { + params.put("id", id); + } + + if (StringUtil.isNotEmpty(name)) { + params.put("name", name); + } + + return request("member/query", params); + } + + + public JSONObject request(String api) throws StatusCodeWithException { + 
return request(api, null, true); + } + + public JSONObject request(String api, JSONObject params) throws StatusCodeWithException { + return request(api, params, true); + } + + protected JSONObject request(String api, JSONObject params, boolean needSign) throws StatusCodeWithException { + if (params == null) { + params = new JSONObject(); + } + /** + * Prevent the map from being out of order, causing the verification to fail. + */ + params = new JSONObject(new TreeMap(params)); + + String data = params.toJSONString(); + + // rsa signature + if (needSign) { + String sign = null; + try { + SecretKeyType secretKeyType = CacheObjects.getSecretKeyType(); + // sign = RSAUtil.sign(data, CacheObjects.getRsaPrivateKey(), "UTF-8"); + sign = SignUtil.sign(data, CacheObjects.getRsaPrivateKey(), secretKeyType); + } catch (Exception e) { + throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); + } + + + JSONObject body = new JSONObject(); + body.put("member_id", CacheObjects.getMemberId()); + body.put("sign", sign); + body.put("data", data); + + data = body.toJSONString(); + } + + HttpResponse response = HttpRequest + .create(config.getUnionBaseUrl() + "/" + api) + .setBody(data) + .postJson(); + + if (!response.success()) { + throw new StatusCodeWithException(response.getMessage(), StatusCode.RPC_ERROR); + } + + JSONObject json; + try { + json = response.getBodyAsJson(); + } catch (JSONException e) { + throw new StatusCodeWithException("union 响应失败:" + response.getBodyAsString(), StatusCode.RPC_ERROR); + } + + if (json == null) { + throw new StatusCodeWithException("union 响应失败:" + response.getBodyAsString(), StatusCode.RPC_ERROR); + } + + Integer code = json.getInteger("code"); + if (code == null || !code.equals(0)) { + throw new StatusCodeWithException("union 响应失败(" + code + "):" + json.getString("message"), StatusCode.RPC_ERROR); + } + return json; + } + + + public JSONObject queryMemberAuthTypeList() throws StatusCodeWithException { + return 
request("member/authtype/query", JObject.create(), true); + } + + public JSONObject realnameAuth(MemberRealnameAuthApi.Input input) throws StatusCodeWithException { + return request("member/realname/auth", JObject.create(input), true); + } + + public JSONObject realnameAuthInfoQuery() throws StatusCodeWithException { + return request("member/realname/authInfo/query", JObject.create(), true); + } + + + public JSONObject realnameAuthAgreementTemplateQuery() throws StatusCodeWithException { + return request("realname/auth/agreement/template/query", JObject.create(), true); + } + + public JSONObject uploadFile(MultiValueMap files, JObject params) throws StatusCodeWithException { + + return request("member/file/upload", params, files, true); + } + + private JSONObject request(String api, JSONObject params, MultiValueMap files, boolean needSign) throws StatusCodeWithException { + /** + * Prevent the map from being out of order, causing the verification to fail. + */ + params = new JSONObject(new TreeMap(params)); + + String data = params.toJSONString(); + String sign = null; + // rsa signature + JSONObject body = new JSONObject(); + if (needSign) { + try { + sign = RSAUtil.sign(data, CacheObjects.getRsaPrivateKey(), "UTF-8"); + } catch (Exception e) { + e.printStackTrace(); + throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); + } + + + body.put("member_id", CacheObjects.getMemberId()); + body.put("sign", sign); + body.put("data", data); + + data = body.toJSONString(); + } + HttpResponse response; + String url = config.getUnionBaseUrl() + "/" + api; + // send http request without files + if (files == null) { + response = HttpRequest + .create(url) + .setBody(data) + .postJson(); + } + // send http request with files + else { + url = UrlUtil.appendQueryParameters(url, body); + HttpRequest request = HttpRequest + .create(url) + .setContentType(HttpContentType.MULTIPART); + + for (Map.Entry item : files.toSingleValueMap().entrySet()) { + try { + 
MultipartFile file = item.getValue(); + ContentType contentType = StringUtil.isEmpty(file.getContentType()) + ? ContentType.DEFAULT_BINARY + : ContentType.create(file.getContentType()); + + InputStreamBody streamBody = new InputStreamBody( + file.getInputStream(), + contentType, + file.getOriginalFilename() + ); + + + request.appendParameter(item.getKey(), streamBody); + } catch (IOException e) { + StatusCode.FILE_IO_ERROR.throwException(e); + } + } + + response = request.post(); + } + + + if (!response.success()) { + throw new StatusCodeWithException(response.getMessage(), StatusCode.RPC_ERROR); + } + + JSONObject json; + try { + json = response.getBodyAsJson(); + } catch (JSONException e) { + throw new StatusCodeWithException("union 响应失败:" + response.getBodyAsString(), StatusCode.RPC_ERROR); + } + + if (json == null) { + throw new StatusCodeWithException("union 响应失败:" + response.getBodyAsString(), StatusCode.RPC_ERROR); + } + + Integer code = json.getInteger("code"); + if (code == null || !code.equals(0)) { + throw new StatusCodeWithException("union 响应失败(" + code + "):" + json.getString("message"), StatusCode.RPC_ERROR); + } + return json; + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/FlowService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/FlowService.java index 69668a087..4de7f9b73 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/FlowService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/FlowService.java @@ -1,11 +1,11 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -27,6 +27,7 @@ import com.welab.wefe.common.http.HttpResponse; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.wefe.checkpoint.dto.ServiceAvailableCheckOutput; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -38,8 +39,13 @@ public class FlowService extends AbstractService { @Autowired private GlobalConfigService globalConfigService; - public JObject dashboard() throws StatusCodeWithException { - return request("/flow/dashboard", null); + public ServiceAvailableCheckOutput getAvailable() throws StatusCodeWithException { + return request("/service/available", null) + .toJavaObject(ServiceAvailableCheckOutput.class); + } + + public JObject alive() throws StatusCodeWithException { + return request("/service/alive", null); } private JObject request(String api, JSONObject params) throws StatusCodeWithException { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/PaddleVisualService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/PaddleVisualService.java new file mode 100644 index 000000000..92b9eb344 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/PaddleVisualService.java @@ -0,0 +1,83 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.sdk; + +import com.alibaba.fastjson.JSONException; +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.dto.globalconfig.DeepLearningConfigModel; +import com.welab.wefe.board.service.service.AbstractService; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.http.HttpRequest; +import com.welab.wefe.common.http.HttpResponse; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.util.StringUtil; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +/** + * @author zane + * @date 2022/2/19 + */ +@Service +public class PaddleVisualService extends AbstractService { + @Autowired + private GlobalConfigService globalConfigService; + + public JObject infer(JSONObject params) throws StatusCodeWithException { + return request("/infer", params); + } + + private JObject request(String api, JSONObject params) throws StatusCodeWithException { + DeepLearningConfigModel deepLearningConfig = globalConfigService.getDeepLearningConfig(); + + if (deepLearningConfig == null || StringUtil.isEmpty(deepLearningConfig.paddleVisualDlBaseUrl)) { + StatusCode.RPC_ERROR.throwException("尚未设置VisualFL服务地址,请在[全局设置][计算引擎设置]中设置VisualFL服务地址。"); + } + + if (params == null) { + params = new JSONObject(); + } + String data = params.toJSONString(); + + if (!api.startsWith("/")) { 
+ api = "/" + api; + } + + HttpResponse response = HttpRequest + .create(deepLearningConfig.paddleVisualDlBaseUrl + api) + .setBody(data) + .postJson(); + + if (!response.success()) { + StatusCode.RPC_ERROR.throwException(response.getMessage()); + } + + JObject json; + try { + json = new JObject(response.getBodyAsJson()); + } catch (JSONException e) { + throw new StatusCodeWithException("paddle 响应失败:" + response.getBodyAsString(), StatusCode.RPC_ERROR); + } + + Integer code = json.getInteger("code"); + if (code == null || !code.equals(200)) { + throw new StatusCodeWithException("paddle 响应失败(" + code + "):" + json.getString("message"), StatusCode.RPC_ERROR); + } + return json; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/UnionService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/UnionService.java deleted file mode 100644 index 02773dbf6..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/UnionService.java +++ /dev/null @@ -1,446 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.sdk; - - -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.JSONException; -import com.alibaba.fastjson.JSONObject; -import com.welab.wefe.board.service.api.union.DataSetTagListApi; -import com.welab.wefe.board.service.api.union.MemberListApi; -import com.welab.wefe.board.service.api.union.QueryDataSetApi; -import com.welab.wefe.board.service.api.union.TagListApi; -import com.welab.wefe.board.service.constant.Config; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; -import com.welab.wefe.board.service.dto.entity.data_set.DataSetOutputModel; -import com.welab.wefe.board.service.dto.globalconfig.MemberInfoModel; -import com.welab.wefe.board.service.service.AbstractService; -import com.welab.wefe.board.service.service.CacheObjects; -import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; -import com.welab.wefe.common.CommonThreadPool; -import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.DataSetPublicLevel; -import com.welab.wefe.common.enums.SmsBusinessType; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.http.HttpRequest; -import com.welab.wefe.common.http.HttpResponse; -import com.welab.wefe.common.util.JObject; -import com.welab.wefe.common.util.RSAUtil; -import com.welab.wefe.common.util.StringUtil; -import net.jodah.expiringmap.ExpiringMap; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; - -import java.util.TreeMap; -import java.util.concurrent.TimeUnit; - - -/** - * @author Zane - */ -@Service -public class UnionService extends AbstractService { - - /** - * cache - */ - private static final ExpiringMap CACHE_MAP = ExpiringMap - .builder() - .expiration(60, TimeUnit.SECONDS) - .maxSize(500) - .build(); - - @Autowired - private Config config; - - @Autowired - private GlobalConfigService globalConfigService; - - /** - * 
initialize wefe system - */ - public void initializeSystem(MemberInfoModel model) throws StatusCodeWithException { - JObject params = JObject - .create() - .put("id", model.getMemberId()) - .put("member_id", model.getMemberId()) - .put("name", model.getMemberName()) - .put("mobile", model.getMemberMobile()) - .put("allow_open_data_set", model.getMemberAllowPublicDataSet()) - .put("public_key", model.getRsaPublicKey()) - .put("email", model.getMemberEmail()) - .put("gateway_uri", model.getMemberGatewayUri()) - .put("logo", model.getMemberLogo()) - .put("hidden", model.getMemberHidden()); - - request("member/add", params, false); - } - - /** - * Report member information - */ - public void uploadMemberInfo(MemberInfoModel model) throws StatusCodeWithException { - JObject params = JObject - .create() - .put("id", model.getMemberId()) - .put("name", model.getMemberName()) - .put("mobile", model.getMemberMobile()) - .put("allow_open_data_set", model.getMemberAllowPublicDataSet()) - .put("public_key", model.getRsaPublicKey()) - .put("email", model.getMemberEmail()) - .put("gateway_uri", model.getMemberGatewayUri()) - .put("logo", model.getMemberLogo()) - .put("hidden", model.getMemberHidden()); - - request("member/update", params); - } - - - /** - * Reset key - */ - public void resetPublicKey(MemberInfoModel model) throws StatusCodeWithException { - JObject params = JObject - .create() - .put("id", model.getMemberId()) - .put("public_key", model.getRsaPublicKey()); - - request("member/update_public_key", params); - } - - /** - * Update member information (not including logo) - */ - public void uploadMemberInfoExcludeLogo(MemberInfoModel model) throws StatusCodeWithException { - JObject params = JObject - .create() - .put("id", model.getMemberId()) - .put("name", model.getMemberName()) - .put("mobile", model.getMemberMobile()) - .put("allow_open_data_set", model.getMemberAllowPublicDataSet()) - .put("public_key", model.getRsaPublicKey()) - .put("email", 
model.getMemberEmail()) - .put("gateway_uri", model.getMemberGatewayUri()) - .put("hidden", model.getMemberHidden()); - - request("member/update_exclude_logo", params); - } - - /** - * Update member information logo - */ - public void updateMemberLogo(MemberInfoModel model) throws StatusCodeWithException { - JObject params = JObject - .create() - .put("id", model.getMemberId()) - .put("logo", model.getMemberLogo()); - - request("member/update_logo", params); - } - - /** - * Report data set information - */ - public void uploadDataSet(DataSetMysqlModel model) throws StatusCodeWithException { - MemberInfoModel member = globalConfigService.getMemberInfo(); - // If data exposure is prohibited globally, it will not be reported. - if (!member.getMemberAllowPublicDataSet()) { - return; - } - - // If this data set is not publicly available to anyone - if (model.getPublicLevel() == DataSetPublicLevel.OnlyMyself) { - // Notify union to remove the data set - dontPublicDataSet(model.getId()); - return; - } - - // Push data set information to union - JObject params = JObject - .create() - .put("id", model.getId()) - .put("name", model.getName()) - .put("member_id", CacheObjects.getMemberId()) - .put("contains_y", model.getContainsY()) - .put("row_count", model.getRowCount()) - .put("column_count", model.getColumnCount()) - .put("column_name_list", model.getColumnNameList()) - .put("feature_count", model.getFeatureCount()) - .put("feature_name_list", model.getFeatureNameList()) - .put("public_level", model.getPublicLevel()) - .put("public_member_list", model.getPublicMemberList()) - .put("usage_count_in_job", model.getUsageCountInJob()) - .put("usage_count_in_flow", model.getUsageCountInFlow()) - .put("usage_count_in_project", model.getUsageCountInProject()) - .put("tags", model.getTags()) - .put("description", model.getDescription()); - - CommonThreadPool.run(() -> { - try { - request("data_set/put", params); - } catch (StatusCodeWithException e) { - super.log(e); - } - }); - } 
- - /** - * Hidden data set - */ - public void dontPublicDataSet(String dataSetId) throws StatusCodeWithException { - JObject params = JObject - .create() - .put("id", dataSetId); - - request("data_set/delete", params); - } - - /** - * Pagination query member - */ - public synchronized JSONObject queryMembers(MemberListApi.Input input) throws StatusCodeWithException { - - String key = "queryMembers" + JSON.toJSONString(input); - if (CACHE_MAP.containsKey(key)) { - return (JSONObject) CACHE_MAP.get(key); - } - - JObject params = JObject - .create() - .put("page_index", input.getPageIndex()) - .put("page_size", input.getPageSize()) - .put("name", input.getName()) - .put("id", input.getId()); - - JSONObject response = request("member/query", params); - CACHE_MAP.put(key, response); - return response; - } - - public JSONObject queryMemberById(String id) throws StatusCodeWithException { - return queryMember(id, ""); - } - - public JSONObject queryMember(String id, String name) throws StatusCodeWithException { - return queryMemberByPage(0, 0, id, name); - } - - public JSONObject queryMember(int pageIndex, int pageSize) throws StatusCodeWithException { - return queryMemberByPage(pageIndex, pageSize, "", ""); - } - - public JSONObject queryMemberByPage(int pageIndex, int pageSize, String id, String name) throws StatusCodeWithException { - JObject params = JObject.create() - .put("page_index", pageIndex) - .put("page_size", pageSize); - - if (StringUtil.isNotEmpty(id)) { - params.put("id", id); - } - - if (StringUtil.isNotEmpty(name)) { - params.put("name", name); - } - - return request("member/query", params); - } - - /** - * Paging query data set tag - */ - public JSONObject queryDataSetTags(DataSetTagListApi.Input input) throws StatusCodeWithException { - String key = "queryDataSetTags" + JSON.toJSONString(input); - if (CACHE_MAP.containsKey(key)) { - return (JSONObject) CACHE_MAP.get(key); - } - - JObject params = JObject - .create() - .put("page_index", 
input.getPageIndex()) - .put("page_size", input.getPageSize()) - .put("tag_name", input.getTag()); - - JSONObject response = request("data_set/tags/query", params); - CACHE_MAP.put(key, response); - return response; - } - - /** - * Pagination query default tags - */ - public JSONObject queryTags(TagListApi.Input input) throws StatusCodeWithException { - - String key = "queryTags" + JSON.toJSONString(input); - if (CACHE_MAP.containsKey(key)) { - return (JSONObject) CACHE_MAP.get(key); - } - - JObject params = JObject - .create() - .put("page_index", input.getPageIndex()) - .put("page_size", input.getPageSize()); - - JSONObject response = request("default_tag/query", params); - CACHE_MAP.put(key, response); - return response; - } - - /** - * Paging query data set - */ - public JSONObject queryDataSets(QueryDataSetApi.Input input) throws StatusCodeWithException { - JObject params = JObject - .create() - .put("page_index", input.getPageIndex()) - .put("page_size", input.getPageSize()) - .put("id", input.getId()) - .put("tag", input.getTag()) - .put("name", input.getName()) - .put("contains_y", input.getContainsY()) - .put("member_id", input.getMemberId()); - - return request("data_set/query", params); - } - - /** - * Get details of a single data set - */ - public DataSetOutputModel queryDataSetDetail(String id) throws StatusCodeWithException { - - if (CACHE_MAP.containsKey(id)) { - return (DataSetOutputModel) CACHE_MAP.get(id); - } - - JObject params = JObject - .create() - .put("id", id); - - JSONObject result = request("data_set/detail", params); - - JSONObject data = result.getJSONObject("data"); - - if (data == null || data.isEmpty()) { - return null; - } - - return data.toJavaObject(DataSetOutputModel.class); - } - - public void sendVerificationCode(String mobile, SmsBusinessType smsBusinessType) throws StatusCodeWithException { - if (!StringUtil.checkPhoneNumber(mobile)) { - throw new StatusCodeWithException("非法的手机号", StatusCode.PARAMETER_VALUE_INVALID); - } - 
JObject params = JObject.create() - .append("mobile", mobile) - .append("smsBusinessType", smsBusinessType); - try { - request("sms/send_verification_code", params, true); - } catch (StatusCodeWithException e) { - throw new StatusCodeWithException(getUnionOrigExceptionMsg(e), StatusCode.SYSTEM_ERROR); - } catch (Exception e) { - throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); - } - } - - /** - * Check verification code - */ - public void checkVerificationCode(String mobile, String code, SmsBusinessType smsBusinessType) throws StatusCodeWithException { - JObject params = JObject.create() - .append("mobile", mobile) - .append("code", code) - .append("smsBusinessType", smsBusinessType); - try { - request("sms/check_verification_code", params, true); - } catch (StatusCodeWithException e) { - throw new StatusCodeWithException(getUnionOrigExceptionMsg(e), StatusCode.SYSTEM_ERROR); - } catch (Exception e) { - throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); - } - } - - private String getUnionOrigExceptionMsg(StatusCodeWithException e) { - String errorMsg = e.getMessage(); - if (StringUtil.isNotEmpty(errorMsg)) { - int index = errorMsg.indexOf(":"); - if (index != -1) { - errorMsg = errorMsg.substring(index + 1); - } - } - return errorMsg; - } - - private JSONObject request(String api, JSONObject params) throws StatusCodeWithException { - return request(api, params, true); - } - - private JSONObject request(String api, JSONObject params, boolean needSign) throws StatusCodeWithException { - /** - * Prevent the map from being out of order, causing the verification to fail. 
- */ - params = new JSONObject(new TreeMap(params)); - - String data = params.toJSONString(); - - // rsa signature - if (needSign) { - String sign = null; - try { - sign = RSAUtil.sign(data, CacheObjects.getRsaPrivateKey(), "UTF-8"); - } catch (Exception e) { - e.printStackTrace(); - throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); - } - - - JSONObject body = new JSONObject(); - body.put("member_id", CacheObjects.getMemberId()); - body.put("sign", sign); - body.put("data", data); - - data = body.toJSONString(); - } - - HttpResponse response = HttpRequest - .create(config.getUNION_BASE_URL() + "/" + api) - .setBody(data) - .postJson(); - - if (!response.success()) { - throw new StatusCodeWithException(response.getMessage(), StatusCode.RPC_ERROR); - } - - JSONObject json; - try { - json = response.getBodyAsJson(); - } catch (JSONException e) { - throw new StatusCodeWithException("union 响应失败:" + response.getBodyAsString(), StatusCode.RPC_ERROR); - } - - if (json == null) { - throw new StatusCodeWithException("union 响应失败:" + response.getBodyAsString(), StatusCode.RPC_ERROR); - } - - Integer code = json.getInteger("code"); - if (code == null || !code.equals(0)) { - throw new StatusCodeWithException("union 响应失败(" + code + "):" + json.getString("message"), StatusCode.RPC_ERROR); - } - return json; - } - - -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/union/UnionService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/union/UnionService.java new file mode 100644 index 000000000..af962eb14 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/union/UnionService.java @@ -0,0 +1,189 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.sdk.union; + + +import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.BloomFilterOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.ImageDataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.TableDataSetOutputModel; +import com.welab.wefe.board.service.dto.globalconfig.MemberInfoModel; +import com.welab.wefe.board.service.sdk.AbstractUnionService; +import com.welab.wefe.board.service.sdk.union.dto.MemberBaseInfo; +import com.welab.wefe.common.CommonThreadPool; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.wefe.checkpoint.dto.ServiceAvailableCheckOutput; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DataResourcePublicLevel; +import org.springframework.stereotype.Service; + +import java.util.LinkedHashMap; + + +/** + * @author Zane + */ +@Service +public class UnionService extends AbstractUnionService { + + public LinkedHashMap getMemberMap() throws StatusCodeWithException { + JSONObject result = request("member/map"); + JSONObject data = result.getJSONObject("data"); + + LinkedHashMap map = new LinkedHashMap<>(); + for (String memberId : data.keySet()) { + map.put( + memberId, + 
data.getJSONObject(memberId).toJavaObject(MemberBaseInfo.class) + ); + } + + return map; + } + + public ServiceAvailableCheckOutput getAvailable() throws StatusCodeWithException { + JSONObject result = request("service/available"); + + return result + .getJSONObject("data") + .toJavaObject(ServiceAvailableCheckOutput.class); + } + + /** + * 更新资源信息,使用此接口更新时,数据不会立即更新,有延迟。 + */ + public void lazyUpdateDataResource(DataResourceMysqlModel model) throws StatusCodeWithException { + MemberInfoModel member = globalConfigService.getMemberInfo(); + if (!member.getMemberAllowPublicDataSet() || member.getMemberHidden()) { + return; + } + + CommonThreadPool.run(() -> { + try { + JObject params = JObject + .create(model) + .append("data_resource_id", model.getId()) + // union 目前用的 data_set_id 为主键,但这是不科学的,这里临时迁就。 + .append("data_set_id", model.getId()); + + request("data_resource/lazy_update", params); + } catch (StatusCodeWithException e) { + super.log(e); + } + }); + + } + + public void upsertDataResource(DataResourceMysqlModel model) { + JObject params = JObject.create(model) + .append("data_resource_id", model.getId()); + + MemberInfoModel member = globalConfigService.getMemberInfo(); + // If data exposure is prohibited globally, it will not be reported. 
+ if (!member.getMemberAllowPublicDataSet() || member.getMemberHidden()) { + return; + } + + CommonThreadPool.run(() -> { + try { + // If this data set is not publicly available to anyone + if (model.getPublicLevel() == DataResourcePublicLevel.OnlyMyself) { + // Notify union to remove the data set + hiddenDataResource(model); + return; + } + + request(StringUtil.stringToUnderLineLowerCase(model.getDataResourceType().name()) + "/put", params); + } catch (StatusCodeWithException e) { + super.log(e); + } + }); + + } + + /** + * Hidden data set + */ + public void deleteDataResource(DataResourceMysqlModel model) throws StatusCodeWithException { + JObject params = JObject + .create() + .put("data_resource_id", model.getId()); + + request("data_resource/delete", params); + } + + /** + * Hidden data set + */ + public void hiddenDataResource(DataResourceMysqlModel model) throws StatusCodeWithException { + JObject params = JObject + .create() + .put("data_resource_id", model.getId()); + + request("data_resource/hidden", params); + } + + + public OUT getDataResourceDetail(String dataResourceId, Class outputClass) throws StatusCodeWithException { + DataResourceType type = null; + if (outputClass == ImageDataSetOutputModel.class) { + type = DataResourceType.ImageDataSet; + } else if (outputClass == TableDataSetOutputModel.class) { + type = DataResourceType.TableDataSet; + } else if (outputClass == BloomFilterOutputModel.class) { + type = DataResourceType.BloomFilter; + } + return getDataResourceDetail(dataResourceId, type, outputClass); + } + + /** + * 获取数据资源详情 + */ + public OUT getDataResourceDetail(String dataResourceId, DataResourceType dataResourceType, Class outputClass) throws StatusCodeWithException { + String key = dataResourceId + "getDataResourceDetail" + outputClass.getSimpleName(); + if (CACHE_MAP.containsKey(key)) { + return (OUT) CACHE_MAP.get(key); + } + + JObject params = JObject + .create() + .put("data_resource_id", dataResourceId) + .put("data_resource_type", 
dataResourceType); + JSONObject result = request("data_resource/detail", params); + + JSONObject data = result.getJSONObject("data"); + if (data == null || data.isEmpty()) { + return null; + } + + JSONObject extraData = data.getJSONObject("extra_data"); + if (extraData != null) { + data.putAll(extraData); + data.remove("extra_data"); + } + + data.put("id", data.get("data_resource_id")); + + OUT output = data.toJavaObject(outputClass); + + CACHE_MAP.put(key, output); + return output; + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/union/dto/MemberBaseInfo.java b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/union/dto/MemberBaseInfo.java new file mode 100644 index 000000000..f82435739 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/sdk/union/dto/MemberBaseInfo.java @@ -0,0 +1,28 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.sdk.union.dto; + +/** + * @author zane + * @date 2021/12/27 + */ +public class MemberBaseInfo { + public String memberId; + public String name; + public boolean hidden; + public boolean lostContact; + public boolean freezed; +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/AbstractService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/AbstractService.java index f4095b93a..d95461633 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/AbstractService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/AbstractService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,6 +17,8 @@ package com.welab.wefe.board.service.service; import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.board.service.sdk.union.UnionService; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -29,9 +31,12 @@ public class AbstractService { @Autowired protected GatewayService gatewayService; - + @Autowired + protected UnionService unionService; @Autowired protected Config config; + @Autowired + protected GlobalConfigService globalConfigService; protected void log(Exception e) { LOG.error(e.getClass() + " " + e.getMessage(), e); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/BaseGatewayService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/BaseGatewayService.java index 57bc2e9d6..80da72806 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/BaseGatewayService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/BaseGatewayService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,6 +17,7 @@ package com.welab.wefe.board.service.service; import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.util.JsonFormat; import com.welab.wefe.board.service.dto.globalconfig.GatewayConfigModel; @@ -25,11 +26,10 @@ import com.welab.wefe.board.service.proto.meta.basic.GatewayMetaProto; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.GatewayActionType; -import com.welab.wefe.common.enums.GatewayProcessorType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.StringUtil; -import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.GatewayActionType; +import com.welab.wefe.common.wefe.enums.GatewayProcessorType; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import org.springframework.beans.factory.annotation.Autowired; @@ -49,14 +49,14 @@ public class BaseGatewayService extends AbstractService { /** * Send a message to your own gateway service */ - protected ApiResult sendToMyselfGateway(GatewayActionType action, String data, GatewayProcessorType processorType) { + protected JSONObject sendToMyselfGateway(GatewayActionType action, String data, GatewayProcessorType processorType) throws StatusCodeWithException { return sendToMyselfGateway(null, action, data, processorType); } /** * Send a message to your own gateway service */ - protected ApiResult sendToMyselfGateway(String gatewayUri, GatewayActionType action, String data, GatewayProcessorType processorType) { + protected JSONObject sendToMyselfGateway(String gatewayUri, GatewayActionType action, String data, GatewayProcessorType processorType) throws StatusCodeWithException { if (gatewayUri == null) { GatewayConfigModel gatewayConfig = globalConfigService.getGatewayConfig(); if (gatewayConfig != null) { @@ 
-64,7 +64,7 @@ protected ApiResult sendToMyselfGateway(String gatewayUri, GatewayActionType } } - return sendMessage( + return callGateway( gatewayUri, CacheObjects.getMemberId(), CacheObjects.getMemberName(), @@ -76,14 +76,15 @@ protected ApiResult sendToMyselfGateway(String gatewayUri, GatewayActionType /** * Send message to other party's gateway service */ - protected ApiResult sendToOtherGateway(String dstMemberId, GatewayActionType action, String data, GatewayProcessorType processorType) { - return sendMessage( + protected JSONObject sendToOtherGateway(String dstMemberId, GatewayActionType action, String data, GatewayProcessorType processorType) throws StatusCodeWithException { + return callGateway( globalConfigService.getGatewayConfig().intranetBaseUri, dstMemberId, CacheObjects.getMemberName(dstMemberId), action, data, - processorType); + processorType + ); } /** @@ -94,42 +95,41 @@ protected ApiResult sendToOtherGateway(String dstMemberId, GatewayActionType * @param dstMemberName The member_name of the target member * @param action action of the message * @param data data of the message - * @param processorType enum, see:{@link com.welab.wefe.common.enums.GatewayProcessorType} + * @param processorType enum, see:{@link com.welab.wefe.common.wefe.enums.GatewayProcessorType} */ - private ApiResult sendMessage(String gatewayUri, String dstMemberId, String dstMemberName, GatewayActionType action, String data, GatewayProcessorType processorType) { + private JSONObject callGateway(String gatewayUri, String dstMemberId, String dstMemberName, GatewayActionType action, String data, GatewayProcessorType processorType) throws StatusCodeWithException { if (StringUtil.isEmpty(gatewayUri)) { - ApiResult.ofErrorWithStatusCode(StatusCode.RPC_ERROR, "尚未设置 gateway 内网地址,请在[全局设置][系统设置]中设置 gateway 服务的内网地址。"); + StatusCode.RPC_ERROR.throwException("尚未设置 gateway 内网地址,请在[全局设置][系统设置]中设置 gateway 服务的内网地址。"); } GatewayMetaProto.TransferMeta transferMeta = buildTransferMeta(dstMemberId, 
dstMemberName, action, data, processorType); ManagedChannel grpcChannel = null; - ApiResult result = null; + String message = "[grpc] end to " + dstMemberName; try { grpcChannel = getGrpcChannel(gatewayUri); TransferServiceGrpc.TransferServiceBlockingStub clientStub = TransferServiceGrpc.newBlockingStub(grpcChannel); BasicMetaProto.ReturnStatus returnStatus = clientStub.send(transferMeta); if (returnStatus.getCode() != 0) { - result = ApiResult.ofErrorWithStatusCode(StatusCode.REMOTE_SERVICE_ERROR, returnStatus.getMessage()); - return result; - } - if (StringUtil.isEmpty(returnStatus.getData())) { - result = ApiResult.ofSuccess(null); - return result; + StatusCode.REMOTE_SERVICE_ERROR.throwException(returnStatus.getMessage()); } - result = JSON - .parseObject(returnStatus.getData()) - .toJavaObject(ApiResult.class); + + message += "success request:" + data; + LOG.info(message); + + return JSON.parseObject(returnStatus.getData()); } catch (Exception e) { + message += "fail message:" + e.getMessage() + " request:" + data; + LOG.error(message); + LOG.error("Request gateway exception, message: " + transferMetaToString(transferMeta) + ",exception:" + e.getMessage(), e); + try { checkPermission(e); } catch (StatusCodeWithException ex) { - result = ApiResult.ofErrorWithStatusCode(StatusCode.RPC_ERROR, ex.getMessage()); - return result; + StatusCode.RPC_ERROR.throwException(ex.getMessage()); } - result = ApiResult.ofErrorWithStatusCode(StatusCode.RPC_ERROR, e.getMessage()); - + StatusCode.RPC_ERROR.throwException(e.getMessage()); } finally { if (null != grpcChannel) { try { @@ -138,18 +138,9 @@ private ApiResult sendMessage(String gatewayUri, String dstMemberId, String d LOG.error("Closing gateway connection exception:", e); } } - - String message = "[grpc] end to " + dstMemberName; - message += " " + (result.success() ? 
"success" : "fail message:" + result.getMessage() + " request:" + data); - - if (result.success()) { - LOG.info(message); - } else { - LOG.error(message); - } } - return result; + return null; } @@ -202,6 +193,7 @@ private ManagedChannel getGrpcChannel(String gatewayUri) throws StatusCodeWithEx return ManagedChannelBuilder .forTarget(gatewayUri) .usePlaintext() + .maxInboundMessageSize(2000 * 1024 * 1024) .build(); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/BlacklistService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/BlacklistService.java index c06d23bf6..0fff06bd5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/BlacklistService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/BlacklistService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -26,7 +26,6 @@ import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.dto.entity.BlacklistOutputModel; import com.welab.wefe.board.service.dto.entity.MemberOutputModel; -import com.welab.wefe.board.service.sdk.UnionService; import com.welab.wefe.common.data.mysql.Where; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.JObject; @@ -48,12 +47,6 @@ public class BlacklistService extends AbstractService { @Autowired private BlacklistRepository blacklistRepository; - @Autowired - private UnionService unionService; - - @Autowired - private GatewayService gatewayService; - public PagingOutput list(BlacklistApi.Input input) throws StatusCodeWithException { List resultList = new ArrayList<>(); @@ -90,7 +83,7 @@ public PagingOutput list(BlacklistApi.Input input) throws /** * Add blacklist */ - public void add(AddApi.Input input) { + public void add(AddApi.Input input) throws StatusCodeWithException { List list = new ArrayList<>(); if (input.getMemberIds() != null) { @@ -112,7 +105,7 @@ public void add(AddApi.Input input) { gatewayService.refreshMemberBlacklistCache(); } - public void deleteFromBlacklist(DeleteApi.Input input) { + public void deleteFromBlacklist(DeleteApi.Input input) throws StatusCodeWithException { blacklistRepository.deleteById(input.getId()); CacheObjects.refreshMemberBlacklist(); // Notify gateway to update blacklist cache diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/CacheObjects.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/CacheObjects.java index 5d7a2eeb7..f3ecb0cfe 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/CacheObjects.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/CacheObjects.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,20 +16,20 @@ package com.welab.wefe.board.service.service; -import com.alibaba.fastjson.JSONArray; -import com.alibaba.fastjson.JSONObject; -import com.welab.wefe.board.service.api.union.MemberListApi; -import com.welab.wefe.board.service.database.entity.AccountMySqlModel; +import com.welab.wefe.board.service.database.entity.AccountMysqlModel; import com.welab.wefe.board.service.database.repository.AccountRepository; import com.welab.wefe.board.service.database.repository.BlacklistRepository; -import com.welab.wefe.board.service.database.repository.DataSetRepository; +import com.welab.wefe.board.service.database.repository.data_resource.DataResourceRepository; import com.welab.wefe.board.service.dto.globalconfig.MemberInfoModel; -import com.welab.wefe.board.service.sdk.UnionService; +import com.welab.wefe.board.service.sdk.union.UnionService; +import com.welab.wefe.board.service.sdk.union.dto.MemberBaseInfo; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; import com.welab.wefe.common.Convert; +import com.welab.wefe.common.constant.SecretKeyType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.StringUtil; import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.wefe.enums.DataResourceType; import org.springframework.data.domain.Sort; import java.util.*; @@ -54,17 +54,22 @@ public class CacheObjects { private static long LAST_REFRESH_MEMBER_MAP_TIME = 0; + private 
static long LAST_REFRESH_ACCOUNT_MAP_TIME = 0; private static String MEMBER_ID; private static String RSA_PRIVATE_KEY; private static String RSA_PUBLIC_KEY; private static String MEMBER_NAME; + private static SecretKeyType SECRET_KEY_TYPE = null; /** - * Data set tags + * Data resource tags * tag : count */ - private static final TreeMap DATA_SET_TAGS = new TreeMap<>(); + private static final TreeMap DATA_RESOURCE_TAGS = new TreeMap<>(); + private static final TreeMap TABLE_DATA_SET_TAGS = new TreeMap<>(); + private static final TreeMap IMAGE_DATA_SET_TAGS = new TreeMap<>(); + private static final TreeMap BLOOM_FILTER_TAGS = new TreeMap<>(); /** * accountId : nickname @@ -77,15 +82,19 @@ public class CacheObjects { private static final List ACCOUNT_ID_LIST = new ArrayList<>(); /** - * accountId : member name + * memberId : member base info */ - private static final LinkedHashMap MEMBER_MAP = new LinkedHashMap<>(); + private static LinkedHashMap MEMBER_MAP = new LinkedHashMap<>(); /** * member blacklist */ private static final Set MEMBER_BLACKLIST = new HashSet<>(); + static { + refreshMemberInfo(); + } + public static Set getMemberBlackList() { if (MEMBER_BLACKLIST.isEmpty()) { refreshMemberBlacklist(); @@ -94,44 +103,60 @@ public static Set getMemberBlackList() { } public synchronized static void refreshMemberBlacklist() { - BlacklistRepository repository = Launcher.CONTEXT.getBean(BlacklistRepository.class); + BlacklistRepository repository = Launcher.getBean(BlacklistRepository.class); MEMBER_BLACKLIST.clear(); repository.findAll().forEach(x -> MEMBER_BLACKLIST.add(x.getBlacklistMemberId())); } public static String getMemberId() { - if (MEMBER_ID == null) { - refreshMemberInfo(); - } return MEMBER_ID; } + /** + * 判断指定的 member_id 是属于当前本地成员 + */ + public static boolean isCurrentMember(String memberId) { + return getMemberId().equals(memberId); + } + public static String getRsaPrivateKey() { - if (RSA_PRIVATE_KEY == null) { - refreshMemberInfo(); - } return 
RSA_PRIVATE_KEY; } public static String getRsaPublicKey() { - if (RSA_PUBLIC_KEY == null) { - refreshMemberInfo(); - } return RSA_PUBLIC_KEY; } public static String getMemberName() { - if (MEMBER_NAME == null) { - refreshMemberInfo(); - } return MEMBER_NAME; } - public static TreeMap getDataSetTags() { - if (DATA_SET_TAGS.isEmpty()) { - refreshDataSetTags(); + public static TreeMap getDataResourceTags(DataResourceType dataResourceType) { + TreeMap map = null; + + if (dataResourceType == null) { + map = DATA_RESOURCE_TAGS; + } else { + switch (dataResourceType) { + case TableDataSet: + map = TABLE_DATA_SET_TAGS; + break; + case ImageDataSet: + map = IMAGE_DATA_SET_TAGS; + break; + case BloomFilter: + map = BLOOM_FILTER_TAGS; + break; + default: + + } + } + + if (map.isEmpty()) { + refreshDataResourceTags(dataResourceType, map); } - return DATA_SET_TAGS; + + return map; } public static List getAccountIdList() { @@ -151,7 +176,7 @@ public static LinkedHashMap getAccountMap() { /** * Get the account's nickname */ - public static synchronized String getNickname(String accountId) { + public static String getNickname(String accountId) { if (accountId == null) { return null; } @@ -161,11 +186,11 @@ public static synchronized String getNickname(String accountId) { /** * Determine whether accountId belongs to the current member */ - public static synchronized boolean isCurrentMember(String accountId) { + public static boolean isCurrentMemberAccount(String accountId) { return getAccountIdList().contains(accountId); } - private static LinkedHashMap getMemberMap() throws StatusCodeWithException { + private static LinkedHashMap getMemberMap() throws StatusCodeWithException { if (MEMBER_MAP.isEmpty()) { refreshMemberMap(); } @@ -176,21 +201,29 @@ private static LinkedHashMap getMemberMap() throws StatusCodeWit * Check if an id is member_id */ public static boolean isMemberId(String memberId) { - return getMemberName(memberId) != null; + try { + return getMemberMap().get(memberId) 
!= null; + } catch (StatusCodeWithException e) { + return false; + } } - public static synchronized String getMemberName(String memberId) { + public static String getMemberName(String memberId) { if (StringUtil.isEmpty(memberId)) { return null; } try { - String memberName = getMemberMap().get(memberId); - if (memberName == null) { + MemberBaseInfo member = getMemberMap().get(memberId); + if (member == null) { CacheObjects.refreshMemberMap(); - memberName = getMemberMap().get(memberId); + member = getMemberMap().get(memberId); + } + + if (member == null) { + return null; } - return memberName; + return member.name; } catch (StatusCodeWithException e) { return null; @@ -201,8 +234,8 @@ public static synchronized String getMemberName(String memberId) { /** * Reload member information */ - public static synchronized void refreshMemberInfo() { - GlobalConfigService service = Launcher.CONTEXT.getBean(GlobalConfigService.class); + public static void refreshMemberInfo() { + GlobalConfigService service = Launcher.getBean(GlobalConfigService.class); MemberInfoModel model = service.getMemberInfo(); if (model == null) { @@ -213,87 +246,80 @@ public static synchronized void refreshMemberInfo() { RSA_PUBLIC_KEY = model.getRsaPublicKey(); RSA_PRIVATE_KEY = model.getRsaPrivateKey(); MEMBER_NAME = model.getMemberName(); + SECRET_KEY_TYPE = model.getSecretKeyType(); } - /** - * Reload the number of data sets corresponding to each tag - */ - public static synchronized void refreshDataSetTags() { + public static void refreshDataResourceTags(DataResourceType dataResourceType) { + TreeMap map = getDataResourceTags(dataResourceType); + refreshDataResourceTags(dataResourceType, map); + } + + public static synchronized void refreshDataResourceTags(DataResourceType dataResourceType, TreeMap map) { + // Query all tags from the database - DataSetRepository repo = Launcher.CONTEXT.getBean(DataSetRepository.class); - List rows = repo.listAllTags(); - DATA_SET_TAGS.clear(); + 
DataResourceRepository repo = Launcher.getBean(DataResourceRepository.class); + List rows = dataResourceType == null + ? repo.listAllTags() + : repo.listAllTags(dataResourceType.name()); + map.clear(); // Count the number of data sets corresponding to each tag for (Object[] row : rows) { List tags = StringUtil.splitWithoutEmptyItem(String.valueOf(row[0]), ","); - long count = Convert.toLong(row[1]); + int count = Convert.toInt(row[1]); for (String tag : tags) { - if (!DATA_SET_TAGS.containsKey(tag)) { - DATA_SET_TAGS.put(tag, 0L); + if (StringUtil.isEmpty(tag)) { + continue; } - - DATA_SET_TAGS.put(tag, DATA_SET_TAGS.get(tag) + count); - + if (!map.containsKey(tag)) { + map.put(tag, 0); + } + map.put(tag, map.get(tag) + count); } } } + /** * Reload account list + *

+ * 注意:经过测试,这个方法不能加 synchronized 关键字,否则会出现线程死锁。 */ - public static synchronized void refreshAccountMap() { - AccountRepository repo = Launcher.CONTEXT.getBean(AccountRepository.class); - List list = repo.findAll(Sort.by("nickname")); + public static void refreshAccountMap() { + if (System.currentTimeMillis() - LAST_REFRESH_ACCOUNT_MAP_TIME < 10_000) { + return; + } + LAST_REFRESH_ACCOUNT_MAP_TIME = System.currentTimeMillis(); + + AccountRepository repo = Launcher.getBean(AccountRepository.class); + List list = repo.findAll(Sort.by("nickname")); ACCOUNT_MAP.clear(); ACCOUNT_ID_LIST.clear(); - for (AccountMySqlModel item : list) { + for (AccountMysqlModel item : list) { ACCOUNT_MAP.put(item.getId(), item.getNickname()); ACCOUNT_ID_LIST.add(item.getId()); } } - /** * Reload the list of union members */ - public static synchronized void refreshMemberMap() throws StatusCodeWithException { + public static void refreshMemberMap() throws StatusCodeWithException { // Prohibit high frequency refresh if (System.currentTimeMillis() - LAST_REFRESH_MEMBER_MAP_TIME < 60_000) { return; } LAST_REFRESH_MEMBER_MAP_TIME = System.currentTimeMillis(); - UnionService service = Launcher.CONTEXT.getBean(UnionService.class); + UnionService service = Launcher.getBean(UnionService.class); MEMBER_MAP.clear(); - MemberListApi.Input input = new MemberListApi.Input(); - while (true) { - - JSONObject json = service.queryMembers(input); - - JSONArray list = json - .getJSONObject("data") - .getJSONArray("list"); - - if (list.isEmpty()) { - break; - } - - list - .stream() - .map(x -> (JSONObject) x) - .forEach(x -> MEMBER_MAP.put(x.getString("id"), x.getString("name"))); + MEMBER_MAP = service.getMemberMap(); - if (list.size() < input.getPageSize()) { - break; - } - - input.setPageIndex(input.getPageIndex() + 1); - - - } + } + public static SecretKeyType getSecretKeyType() { + return null == SECRET_KEY_TYPE ? 
SecretKeyType.rsa : SECRET_KEY_TYPE; } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ChatLastAccountService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ChatLastAccountService.java index 61ac085d9..47e31a63a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ChatLastAccountService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ChatLastAccountService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -24,7 +24,7 @@ import com.welab.wefe.board.service.database.repository.ChatUnreadMessageRepository; import com.welab.wefe.board.service.dto.entity.ChatLastAccountOutputModel; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.OrderBy; +import com.welab.wefe.common.data.mysql.enums.OrderBy; import org.apache.commons.collections4.CollectionUtils; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ChatUnreadMessageService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ChatUnreadMessageService.java index 34beaeaf1..7755d0e4e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ChatUnreadMessageService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ChatUnreadMessageService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetColumnService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetColumnService.java index fac692eb1..7bc69a4b3 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetColumnService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetColumnService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -23,8 +23,7 @@ import com.welab.wefe.board.service.dto.entity.data_set.DataSetColumnInputModel; import com.welab.wefe.board.service.dto.entity.data_set.DataSetColumnOutputModel; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.OrderBy; -import com.welab.wefe.common.web.CurrentAccount; +import com.welab.wefe.common.data.mysql.enums.OrderBy; import org.modelmapper.ModelMapper; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.jpa.domain.Specification; @@ -57,7 +56,7 @@ public PagingOutput list(String dataSetId) { ); } - public void update(String dataSetId, List list, CurrentAccount.Info userInfo) { + public void update(String dataSetId, List list) { // clear data set columns dataSetColumnRepository.deleteByDataSetId(dataSetId); @@ -66,7 +65,6 @@ public void update(String dataSetId, List list, Current DataSetColumnInputModel item = list.get(i); DataSetColumnMysqlModel column = new ModelMapper().map(item, DataSetColumnMysqlModel.class); - column.setCreatedBy(userInfo.id); column.setDataSetId(dataSetId); column.setIndex(i); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetService.java deleted file mode 100644 index 924e40494..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetService.java +++ /dev/null @@ -1,404 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.service; - -import java.io.File; -import java.sql.Connection; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Consumer; -import java.util.stream.Collectors; - -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.data.jpa.domain.Specification; -import org.springframework.stereotype.Service; - -import com.welab.wefe.board.service.api.dataset.DeleteApi; -import com.welab.wefe.board.service.api.dataset.QueryApi; -import com.welab.wefe.board.service.api.dataset.UpdateApi; -import com.welab.wefe.board.service.constant.Config; -import com.welab.wefe.board.service.constant.DataSetAddMethod; -import com.welab.wefe.board.service.database.entity.DataSourceMySqlModel; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; -import com.welab.wefe.board.service.database.entity.job.ProjectDataSetMySqlModel; -import com.welab.wefe.board.service.database.entity.job.ProjectMySqlModel; -import com.welab.wefe.board.service.database.repository.DataSetRepository; -import com.welab.wefe.board.service.database.repository.DataSourceRepository; -import com.welab.wefe.board.service.database.repository.JobMemberRepository; -import com.welab.wefe.board.service.database.repository.JobRepository; -import com.welab.wefe.board.service.database.repository.ProjectDataSetRepository; -import com.welab.wefe.board.service.database.repository.ProjectRepository; -import 
com.welab.wefe.board.service.dto.base.PagingOutput; -import com.welab.wefe.board.service.dto.entity.data_set.DataSetOutputModel; -import com.welab.wefe.board.service.dto.entity.project.ProjectUsageDetailOutputModel; -import com.welab.wefe.board.service.onlinedemo.OnlineDemoBranchStrategy; -import com.welab.wefe.board.service.sdk.UnionService; -import com.welab.wefe.board.service.util.JdbcManager; -import com.welab.wefe.board.service.util.ModelMapper; -import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.DataSetPublicLevel; -import com.welab.wefe.common.enums.OrderBy; -import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.util.StringUtil; -import com.welab.wefe.common.web.CurrentAccount; - -/** - * @author Zane - */ -@Service -public class DataSetService extends AbstractService { - - @Autowired - protected DataSetRepository repo; - @Autowired - protected DataSetColumnService dataSetColumnService; - @Autowired - protected UnionService unionService; - @Autowired - protected DataSetStorageService dataSetStorageService; - @Autowired - protected JobRepository jobRepository; - @Autowired - protected JobMemberRepository jobMemberRepository; - @Autowired - protected JobRepository featureJobRepository; - @Autowired - DataSourceRepository dataSourceRepo; - @Autowired - private DataSetRepository dataSetRepository; - @Autowired - private Config config; - @Autowired - private ProjectDataSetRepository projectDataSetRepository; - @Autowired - private ProjectRepository projectRepository; - - /** - * Get uploaded file - */ - public File getDataSetFile(DataSetAddMethod method, String filename) throws StatusCodeWithException { - File file = null; - switch (method) { - case HttpUpload: - file = new File(config.getFileUploadDir(), filename); - break; - case LocalFile: - file = new File(filename); - break; - case Database: - 
break; - default: - } - - if (null == file || !file.exists()) { - throw new StatusCodeWithException("未找到文件:" + filename, StatusCode.PARAMETER_VALUE_INVALID); - } - - return file; - } - - /** - * Paging query data set - */ - public PagingOutput query(QueryApi.Input input) { - - Specification where = Where - .create() - .equal("id", input.getId()) - .contains("name", input.getName()) - .containsItem("tags", input.getTag()) - .equal("containsY", input.getContainsY()) - .equal("createdBy", input.getCreator()) - .equal("sourceType", null, false) - .orderBy("createdTime", OrderBy.desc) - .build(DataSetMysqlModel.class); - - return repo.paging(where, input, DataSetOutputModel.class); - } - - public DataSetMysqlModel query(String sourceJobId, ComponentType sourceType) { - - Specification where = Where.create().equal("sourceJobId", sourceJobId) - .equal("sourceType", sourceType).build(DataSetMysqlModel.class); - - return repo.findOne(where).orElse(null); - } - - public DataSetMysqlModel save(DataSetMysqlModel model) { - return repo.save(model); - } - - /** - * delete data set - */ - public void delete(DeleteApi.Input input) throws StatusCodeWithException { - DataSetMysqlModel model = repo.findById(input.getId()).orElse(null); - if (model == null) { - return; - } - - OnlineDemoBranchStrategy.hackOnDelete(input, model, "只能删除自己添加的数据集。"); - - delete(model); - } - - /** - * delete data set - */ - public void delete(String dataSetId) throws StatusCodeWithException { - DataSetMysqlModel model = repo.findById(dataSetId).orElse(null); - if (model == null) { - return; - } - - delete(model); - } - - /** - * delete data set - */ - public void delete(DataSetMysqlModel model) throws StatusCodeWithException { - - // delete data set from database - repo.deleteById(model.getId()); - - // delete data set from storage - dataSetStorageService.deleteDataSet(model.getId()); - - // is raw data set - if (model.getSourceType() == null) { - // Notify the union to do not public the data set - 
unionService.dontPublicDataSet(model.getId()); - - // Refresh the data set tag list - CacheObjects.refreshDataSetTags(); - } - - } - - /** - * update data set info - */ - public void update(UpdateApi.Input input) throws StatusCodeWithException { - - if (repo.countByName(input.getName(), input.getId()) > 0) { - throw new StatusCodeWithException("此数据集名称已存在,请换一个数据集名称", StatusCode.PARAMETER_VALUE_INVALID); - } - - DataSetMysqlModel model = repo.findById(input.getId()).orElse(null); - if (model == null) { - return; - } - - model.setUpdatedBy(CurrentAccount.id()); - model.setName(input.getName()); - model.setTags(StringUtil.join(input.getTags(), ",")); - model.setDescription(input.getDescription()); - model.setPublicMemberList(input.getPublicMemberList()); - model.setPublicLevel(input.getPublicLevel()); - model.setTags(standardizeTags(input.getTags())); - - handlePublicMemberList(model); - - repo.save(model); - - // save data set column info to database - dataSetColumnService.update(input.getId(), input.getMetadataList(), CurrentAccount.get()); - - unionService.uploadDataSet(model); - - CacheObjects.refreshDataSetTags(); - } - - /** - * Process the list of visible members - *

- * When the scene is visible to the specified members, automatically add itself is also visible. - */ - public void handlePublicMemberList(DataSetMysqlModel model) { - - // When the PublicLevel is PublicWithMemberList, if list contains yourself, - // you will be removed, and union will handle the data that you must be visible. - if (model.getPublicLevel() == DataSetPublicLevel.PublicWithMemberList) { - String memberId = CacheObjects.getMemberId(); - - - if (model.getPublicMemberList().contains(memberId)) { - String list = model.getPublicMemberList() - .replace(memberId, "") - .replace(",,", ","); - - model.setPublicMemberList(list); - } - } - - } - - - /** - * Standardize the tag list - */ - public String standardizeTags(List tags) { - if (tags == null) { - return ""; - } - - tags = tags.stream() - // Remove comma(,,) - .map(x -> x.replace(",", "").replace(",", "")) - // Remove empty elements - .filter(x -> !StringUtil.isEmpty(x)) - .distinct() - .sorted() - .collect(Collectors.toList()); - - // Concatenate into a string, add a comma before and after it to facilitate like query. 
- return "," + StringUtil.join(tags, ',') + ","; - - } - - /** - * get data source by id - */ - public DataSourceMySqlModel getDataSourceById(String dataSourceId) { - return dataSourceRepo.findById(dataSourceId).orElse(null); - } - - /** - * get data sets info from local or union - */ - public DataSetOutputModel findDataSetFromLocalOrUnion(String memberId, String dataSetId) throws StatusCodeWithException { - - if (memberId.equals(CacheObjects.getMemberId())) { - DataSetMysqlModel dataSet = repo.findById(dataSetId).orElse(null); - if (dataSet == null) { - return null; - } - return ModelMapper.map(dataSet, DataSetOutputModel.class); - } else { - return unionService.queryDataSetDetail(dataSetId); - } - } - - public DataSetMysqlModel findOne(String dataSetId) { - return repo.findById(dataSetId).orElse(null); - - } - - /** - * Test whether SQL can be queried normally - */ - public boolean testSqlQuery(String dataSourceId, String sql) throws StatusCodeWithException { - DataSourceMySqlModel model = getDataSourceById(dataSourceId); - if (model == null) { - throw new StatusCodeWithException("dataSourceId在数据库不存在", StatusCode.DATA_NOT_FOUND); - } - - if (StringUtils.isEmpty(sql)) { - throw new StatusCodeWithException("请填入sql查询语句", StatusCode.PARAMETER_CAN_NOT_BE_EMPTY); - } - - Connection conn = JdbcManager.getConnection( - model.getDatabaseType(), - model.getHost(), - model.getPort(), - model.getUserName(), - model.getPassword(), - model.getDatabaseName() - ); - - return JdbcManager.testQuery(conn, sql, true); - } - - /** - * Update the number of data sets used in the project - */ - public void updateUsageCountInProject(String dataSetId) { - dataSetRepository.updateUsageCountInProject(dataSetId); - - DataSetMysqlModel model = repo.findById(dataSetId).orElse(null); - if (model == null) { - return; - } - - try { - unionService.uploadDataSet(model); - } catch (StatusCodeWithException e) { - super.log(e); - } - } - - /** - * The number of data sets used in the flow ++ - */ - 
public void usageCountInFlowIncrement(String dataSetId) throws StatusCodeWithException { - updateUsageCount(dataSetId, x -> x.setUsageCountInProject(x.getUsageCountInProject() + 1)); - } - - /** - * The number of data sets used in the flow -- - */ - public void usageCountInFlowDecrement(String dataSetId) throws StatusCodeWithException { - updateUsageCount(dataSetId, x -> x.setUsageCountInFlow(x.getUsageCountInFlow() - 1)); - } - - /** - * The number of data sets used in the job ++ - */ - public void usageCountInJobIncrement(String dataSetId) throws StatusCodeWithException { - updateUsageCount(dataSetId, x -> x.setUsageCountInJob(x.getUsageCountInJob() + 1)); - } - - /** - * Update the various usage count of the data set - */ - private void updateUsageCount(String dataSetId, Consumer func) throws StatusCodeWithException { - DataSetMysqlModel model = repo.findById(dataSetId).orElse(null); - if (model == null) { - return; - } - - func.accept(model); - repo.save(model); - - unionService.uploadDataSet(model); - } - - /** - * Query the project information used by the dataset in the project - */ - public List queryUsageInProject(String dataSetId) { - List ProjectUsageDetailOutputModelList = new ArrayList<>(); - List usageInProjectList = projectDataSetRepository.queryUsageInProject(dataSetId); - if (usageInProjectList == null || usageInProjectList.isEmpty()) { - return ProjectUsageDetailOutputModelList; - } - - for (ProjectDataSetMySqlModel usageInProject : usageInProjectList) { - ProjectMySqlModel projectMySqlModel = projectRepository.findOneById(usageInProject.getProjectId()); - ProjectUsageDetailOutputModel projectUsageDetailOutputModel = new ProjectUsageDetailOutputModel(); - projectUsageDetailOutputModel.setName(projectMySqlModel.getName()); - projectUsageDetailOutputModel.setProjectId(projectMySqlModel.getProjectId()); - ProjectUsageDetailOutputModelList.add(projectUsageDetailOutputModel); - } - - return ProjectUsageDetailOutputModelList; - } -} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetStorageService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetStorageService.java index e36dae53a..fc6e60ca2 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetStorageService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSetStorageService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -217,4 +217,8 @@ public int count(String databaseName, String tableName) { public int getAddBatchSize(int columns) { return storageService.getAddBatchSize(columns); } + + public DataItemModel getByKey(String databaseName, String tableName, String key) { + return storageService.getByKey(databaseName, tableName, key); + } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSourceService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSourceService.java index 45a813f95..b0d2e30dd 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSourceService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/DataSourceService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. 
- * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,16 +20,15 @@ import com.welab.wefe.board.service.api.datasource.DeleteApi; import com.welab.wefe.board.service.api.datasource.QueryApi; import com.welab.wefe.board.service.api.datasource.TestDBConnectApi; -import com.welab.wefe.board.service.database.entity.DataSourceMySqlModel; +import com.welab.wefe.board.service.database.entity.DataSourceMysqlModel; import com.welab.wefe.board.service.database.repository.DataSourceRepository; import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.util.JdbcManager; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.util.Md5; import com.welab.wefe.common.web.CurrentAccount; +import com.welab.wefe.common.web.util.ModelMapper; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.jpa.domain.Specification; import org.springframework.stereotype.Service; @@ -53,7 +52,7 @@ public AddApi.DataSourceAddOutput add(AddApi.DataSourceAddInput input) throws St // Test if the connection is available testdbconnect(input); - DataSourceMySqlModel model = ModelMapper.map(input, DataSourceMySqlModel.class); + DataSourceMysqlModel model = ModelMapper.map(input, DataSourceMysqlModel.class); model.setCreatedBy(CurrentAccount.id()); // model.setPassword(Md5.of(model.getPassword())); 
dataSourceRepo.save(model); @@ -67,7 +66,7 @@ public AddApi.DataSourceAddOutput add(AddApi.DataSourceAddInput input) throws St * Delete data sources */ public void delete(DeleteApi.Input input) { - DataSourceMySqlModel model = dataSourceRepo.findById(input.getId()).orElse(null); + DataSourceMysqlModel model = dataSourceRepo.findById(input.getId()).orElse(null); if (model == null) { return; } @@ -79,9 +78,9 @@ public void delete(DeleteApi.Input input) { * Query data source by pagination */ public PagingOutput query(QueryApi.Input input) { - Specification where = Where.create() + Specification where = Where.create() .equal("name", input.getName()) - .build(DataSourceMySqlModel.class); + .build(DataSourceMysqlModel.class); return dataSourceRepo.paging(where, input, QueryApi.Output.class); } @@ -95,7 +94,7 @@ public TestDBConnectApi.Output testdbconnect(AddApi.DataSourceAddInput input) th if (conn != null) { boolean success = JdbcManager.testQuery(conn); if (!success) { - throw new StatusCodeWithException(StatusCode.DATABASE_LOST, "数据库连接失败"); + throw new StatusCodeWithException(StatusCode.DATABASE_LOST, "测试连接数据库失败,请检查数据库是否正常或者账号密码是否填写错误"); } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/EmailService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/EmailService.java index 0114d071b..e5264e09f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/EmailService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/EmailService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,14 +16,14 @@ package com.welab.wefe.board.service.service; -import com.welab.wefe.board.service.database.entity.AccountMySqlModel; +import com.welab.wefe.board.service.database.entity.AccountMysqlModel; import com.welab.wefe.board.service.database.entity.MessageMysqlModel; import com.welab.wefe.board.service.dto.globalconfig.MailServerModel; import com.welab.wefe.board.service.service.account.AccountService; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; -import com.welab.wefe.common.enums.MessageLevel; -import com.welab.wefe.common.enums.ProducerType; import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.wefe.enums.MessageLevel; +import com.welab.wefe.common.wefe.enums.ProducerType; import org.apache.commons.collections4.CollectionUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.mail.MailSendException; @@ -48,9 +48,9 @@ public class EmailService extends AbstractService { private static final String MAIL_DEFAULT_ENCODING = "UTF-8"; private static final String MAIL_SMTP_AUTH = "true"; - private static final String MAIL_SMTP_WRITE_TIMEOUT = "5000"; - private static final String MAIL_SMTP_TIMEOUT = "5000"; - private static final String MAIL_SMTP_CONNECTION_TIMEOUT = "5000"; + private static final String MAIL_SMTP_WRITE_TIMEOUT = "30000"; + private static final String MAIL_SMTP_TIMEOUT = "30000"; + private static final String MAIL_SMTP_CONNECTION_TIMEOUT = "30000"; @Autowired private MessageService messageService; @@ -136,6 +136,33 @@ public Set sendMail(String from, Set to, String subject, String return 
new HashSet<>(16); } + /** + * Send multiple emails + * + * @param from Sender + * @param to Recipient + * @param subject subject + * @param content content + */ + public void sendMail(String from, String to, String subject, String content) throws Exception { + JavaMailSenderImpl javaMailSender = getMailSender(); + + MimeMessage mimeMessage = javaMailSender.createMimeMessage(); + MimeMessageHelper mineHelper = new MimeMessageHelper(mimeMessage, true); + mineHelper.setFrom(from); + mineHelper.setTo(to); + mineHelper.setSubject(subject); + mineHelper.setText(content, true); + + try { + javaMailSender.send(mimeMessage); + } catch (Exception e) { + LOG.error("Sending mail exception:", e); + throw e; + } + } + + /** * Get message sender @@ -171,6 +198,8 @@ private JavaMailSenderImpl getMailSender() throws Exception { mailProperties.setProperty("mail.smtp.writetimeout", MAIL_SMTP_WRITE_TIMEOUT); mailProperties.setProperty("mail.smtp.timeout", MAIL_SMTP_TIMEOUT); mailProperties.setProperty("mail.smtp.connectiontimeout", MAIL_SMTP_CONNECTION_TIMEOUT); + mailProperties.setProperty("mail.smtp.ssl.enable", "true"); + mailProperties.setProperty("mail.debug", "true"); javaMailSender.setJavaMailProperties(mailProperties); return javaMailSender; @@ -182,9 +211,9 @@ private JavaMailSenderImpl getMailSender() throws Exception { */ private Set getTotalEmails() { Set totalEmails = new HashSet<>(16); - List accountMySqlModelList = accountService.queryAll(); - if (CollectionUtils.isNotEmpty(accountMySqlModelList)) { - for (AccountMySqlModel model : accountMySqlModelList) { + List accountMysqlModelList = accountService.queryAll(); + if (CollectionUtils.isNotEmpty(accountMysqlModelList)) { + for (AccountMysqlModel model : accountMysqlModelList) { if (StringUtil.isNotEmpty(model.getEmail())) { totalEmails.add(model.getEmail()); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/EncryptPhoneNumberService.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/service/EncryptPhoneNumberService.java new file mode 100644 index 000000000..e4937af26 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/EncryptPhoneNumberService.java @@ -0,0 +1,75 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.service; + + +import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.board.service.database.entity.AccountMysqlModel; +import com.welab.wefe.board.service.database.entity.GlobalConfigMysqlModel; +import com.welab.wefe.board.service.database.entity.VerificationCodeMysqlModel; +import com.welab.wefe.board.service.database.repository.AccountRepository; +import com.welab.wefe.board.service.database.repository.GlobalConfigRepository; +import com.welab.wefe.board.service.database.repository.VerificationCodeRepository; +import com.welab.wefe.board.service.service.globalconfig.BaseGlobalConfigService; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; +import org.springframework.util.CollectionUtils; + +import java.util.Date; +import java.util.List; + +@Service +public class EncryptPhoneNumberService { + 
@Autowired + private Config config; + + @Autowired + private AccountRepository accountRepository; + + @Autowired + private VerificationCodeRepository verificationCodeRepository; + + @Autowired + protected GlobalConfigRepository globalConfigRepository; + + @Transactional(rollbackFor = Exception.class) + public void encrypt() { + List accountMysqlModelList = accountRepository.findAll(); + if (!CollectionUtils.isEmpty(accountMysqlModelList)) { + for (AccountMysqlModel model : accountMysqlModelList) { + model.setUpdatedTime(new Date()); + accountRepository.save(model); + } + } + List verificationCodeMysqlModelList = verificationCodeRepository.findAll(); + if (!CollectionUtils.isEmpty(verificationCodeMysqlModelList)) { + for (VerificationCodeMysqlModel model : verificationCodeMysqlModelList) { + model.setUpdatedTime(new Date()); + verificationCodeRepository.save(model); + } + } + List globalConfigMysqlModelList = globalConfigRepository.findByGroup(BaseGlobalConfigService.Group.MEMBER_INFO); + if (!CollectionUtils.isEmpty(globalConfigMysqlModelList)) { + for (GlobalConfigMysqlModel model : globalConfigMysqlModelList) { + model.setUpdatedTime(new Date()); + globalConfigRepository.save(model); + } + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/FeatureDataOutputInfoService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/FeatureDataOutputInfoService.java index 27b57a876..e90088ff4 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/FeatureDataOutputInfoService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/FeatureDataOutputInfoService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -21,8 +21,8 @@ import com.welab.wefe.board.service.database.repository.JobRepository; import com.welab.wefe.board.service.database.repository.OutputModelRepository; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.jpa.domain.Specification; import org.springframework.stereotype.Service; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowActionQueueService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowActionQueueService.java index 9524da12b..cf664f1eb 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowActionQueueService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowActionQueueService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -19,10 +19,11 @@ import com.welab.wefe.board.service.database.entity.flow.FlowActionQueueMySqlModel; import com.welab.wefe.board.service.database.entity.job.JobMySqlModel; import com.welab.wefe.board.service.database.repository.FlowActionQueueRepository; -import com.welab.wefe.common.enums.FlowActionType; -import com.welab.wefe.common.enums.ProducerType; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.dto.AbstractApiInput; +import com.welab.wefe.common.wefe.enums.FlowActionType; +import com.welab.wefe.common.wefe.enums.ProducerType; +import com.welab.wefe.common.wefe.enums.ProjectType; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -37,10 +38,47 @@ public class FlowActionQueueService extends AbstractService { @Autowired private JobService jobService; + public void runJob(AbstractApiInput input, String jobId, ProjectType projectType) { + JObject params = + projectType == ProjectType.DeepLearning + ? JObject.create("type", "visualfl") + : null; + + notifyFlow( + input, + jobId, + FlowActionType.run_job, + params + ); + } + + public void stopJob(AbstractApiInput input, String jobId, ProjectType projectType) { + JObject params = + projectType == ProjectType.DeepLearning + ? 
JObject.create("type", "visualfl") + : null; + + notifyFlow( + input, + jobId, + FlowActionType.stop_job, + params + ); + } + + public void notifyFlow(AbstractApiInput input, String jobId, FlowActionType actionType) { + + notifyFlow(input, jobId, actionType, null); + } + /** * send a action message to flow service */ - public void notifyFlow(AbstractApiInput input, String jobId, FlowActionType actionType) { + public void notifyFlow(AbstractApiInput input, String jobId, FlowActionType actionType, JObject params) { + + if (params == null) { + params = new JObject(); + } for (JobMySqlModel job : jobService.listByJobId(jobId)) { @@ -49,8 +87,7 @@ public void notifyFlow(AbstractApiInput input, String jobId, FlowActionType acti action.setProducer(input.fromGateway() ? ProducerType.gateway : ProducerType.board); action.setPriority(0); action.setParams( - JObject - .create() + params .put("jobId", job.getJobId()) .put("dstRole", job.getMyRole().name()) .toStringWithNull() diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowJobService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowJobService.java index 15baf7c2b..029459b41 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowJobService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowJobService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -26,12 +26,12 @@ import com.welab.wefe.board.service.dto.entity.job.JobListOutputModel; import com.welab.wefe.board.service.dto.entity.job.JobOutputModel; import com.welab.wefe.board.service.dto.vo.JobProgressOutput; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.JobStatus; -import com.welab.wefe.common.enums.TaskStatus; import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.JobStatus; +import com.welab.wefe.common.wefe.enums.TaskStatus; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.jpa.domain.Specification; import org.springframework.stereotype.Service; @@ -63,6 +63,7 @@ public PagingOutput query(QueryApi.Input input) { .equal("flowId", input.getFlowId()) .equal("jobId", input.getJobId()) .equal("status", input.getStatus()) + .notEqual("myRole", JobMemberRole.arbiter) .contains("name", input.getName()) .build(JobMySqlModel.class); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowTemplateService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowTemplateService.java index f74454231..9acf5485c 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowTemplateService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/FlowTemplateService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/GatewayService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/GatewayService.java index 07f69f642..c99f0a418 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/GatewayService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/GatewayService.java @@ -1,11 +1,11 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -17,32 +17,41 @@ package com.welab.wefe.board.service.service; import com.alibaba.fastjson.JSONObject; +import com.welab.wefe.board.service.api.project.flow.AddFlowApi; +import com.welab.wefe.board.service.api.project.flow.CopyFlowApi; +import com.welab.wefe.board.service.api.project.flow.DeleteApi; +import com.welab.wefe.board.service.api.project.flow.UpdateFlowBaseInfoApi; +import com.welab.wefe.board.service.api.project.flow.UpdateFlowGraphApi; +import com.welab.wefe.board.service.api.project.node.UpdateApi; import com.welab.wefe.board.service.api.project.project.AddApi; import com.welab.wefe.board.service.database.entity.job.JobMemberMySqlModel; +import com.welab.wefe.board.service.database.entity.job.ProjectFlowMySqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectMemberMySqlModel; import com.welab.wefe.board.service.database.repository.JobMemberRepository; import com.welab.wefe.board.service.exception.MemberGatewayException; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.GatewayActionType; -import com.welab.wefe.common.enums.GatewayProcessorType; -import com.welab.wefe.common.enums.JobMemberRole; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; import com.welab.wefe.common.web.api.base.Api; import com.welab.wefe.common.web.dto.AbstractApiInput; import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.checkpoint.dto.ServiceAvailableCheckOutput; +import 
com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.GatewayActionType; +import com.welab.wefe.common.wefe.enums.GatewayProcessorType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import org.apache.commons.lang.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; -import ru.yandex.clickhouse.util.apache.StringUtils; import java.util.List; import java.util.stream.Collectors; /** - * @author seven.zeng + * @author zane.luo */ @Service public class GatewayService extends BaseGatewayService { @@ -59,6 +68,8 @@ public class GatewayService extends BaseGatewayService { private JobMemberService jobMemberService; @Autowired private GlobalConfigService globalConfigService; + @Autowired + private ProjectFlowService projectFlowService; /** * Synchronize messages to all job participants @@ -88,8 +99,7 @@ public void syncToOtherJobMembers(String jobId, AbstractApiInput input, Class continue; } - sendToBoardRedirectApi(member.getMemberId(), me.getJobRole(), input, api); - + callOtherMemberBoard(member.getMemberId(), me.getJobRole(), api, input); } } @@ -138,6 +148,30 @@ private void syncToOtherProjectMembers(String projectId, AbstractApiInput input, return; } + boolean needSkipOtherPromoters = false; + String flowId = ""; + // 对流程相关操作特殊处理 + if (input instanceof UpdateApi.Input) { + flowId = ((UpdateApi.Input) input).getFlowId(); + } else if (input instanceof UpdateFlowGraphApi.Input) { + flowId = ((UpdateFlowGraphApi.Input) input).getFlowId(); + } else if (input instanceof UpdateFlowBaseInfoApi.Input) { + flowId = ((UpdateFlowBaseInfoApi.Input) input).getFlowId(); + } else if (input instanceof AddFlowApi.Input) { + flowId = ((AddFlowApi.Input) input).getFlowId(); + } else if (input instanceof CopyFlowApi.Input) { + flowId = ((CopyFlowApi.Input) input).getSourceFlowId(); + } else if (input instanceof 
DeleteApi.Input) { + flowId = ((DeleteApi.Input) input).getFlowId(); + } + if (StringUtils.isNotBlank(flowId)) { + ProjectFlowMySqlModel flow = projectFlowService.findOne(flowId); + if (flow.getFederatedLearningType() == FederatedLearningType.horizontal + || flow.getFederatedLearningType() == FederatedLearningType.vertical) { + needSkipOtherPromoters = true; + } + } + checkProjectMemberList(members); for (ProjectMemberMySqlModel member : members) { // Skip self @@ -159,8 +193,11 @@ private void syncToOtherProjectMembers(String projectId, AbstractApiInput input, if (input instanceof AddApi.Input) { ((AddApi.Input) input).setRole(member.getMemberRole()); } - sendToBoardRedirectApi(member.getMemberId(), me.getMemberRole(), input, api); + if (needSkipOtherPromoters && member.getMemberRole() == JobMemberRole.promoter) { + continue; + } + callOtherMemberBoard(member.getMemberId(), me.getMemberRole(), api, input); } } @@ -198,7 +235,7 @@ private List findMembersByProjectId(String projectId) { List members = projectMemberService.findListByProjectId(projectId); ProjectMemberMySqlModel promoter = members.stream() - .filter(x -> x.getMemberRole() == JobMemberRole.promoter && StringUtils.isBlank(x.getInviterId())) + .filter(x -> x.getMemberRole() == JobMemberRole.promoter && StringUtil.isBlank(x.getInviterId())) .findFirst().orElse(null); // Since the initiator models with itself, the records of the initiator as a provider should be eliminated to @@ -213,10 +250,10 @@ private List findMembersByProjectId(String projectId) { /** * Notify the gateway to update the system configuration cache */ - public void refreshSystemConfigCache() { + public void refreshSystemConfigCache() throws StatusCodeWithException { sendToMyselfGateway( - GatewayActionType.refresh_system_config_cache, - "refresh_system_config_cache", + GatewayActionType.none, + "", GatewayProcessorType.refreshSystemConfigCacheProcessor ); } @@ -224,10 +261,10 @@ public void refreshSystemConfigCache() { /** * Notify the 
gateway to update the member blacklist cache */ - public void refreshMemberBlacklistCache() { + public void refreshMemberBlacklistCache() throws StatusCodeWithException { sendToMyselfGateway( - GatewayActionType.refresh_system_config_cache, - "refresh_member_blacklist_cache", + GatewayActionType.none, + "", GatewayProcessorType.refreshMemberBlacklistCacheProcessor ); } @@ -235,61 +272,101 @@ public void refreshMemberBlacklistCache() { /** * Notify the gateway to update the IP whitelist cache */ - public void refreshIpWhiteListCache() { + public void refreshIpWhiteListCache() throws StatusCodeWithException { sendToMyselfGateway( - GatewayActionType.refresh_system_config_cache, - "refresh_ip_white_list_cache", + GatewayActionType.none, + "", GatewayProcessorType.refreshSystemConfigCacheProcessor ); } - /** - * Call the board of other members - */ - public ApiResult callOtherMemberBoard(String dstMemberId, Class api, Object data) throws MemberGatewayException { - Api annotation = api.getAnnotation(Api.class); - return callOtherMemberBoard(dstMemberId, annotation.path(), - data instanceof JSONObject - ? 
(JSONObject) data - : JObject.create(data) - ); + public ServiceAvailableCheckOutput getLocalGatewayAvailable() throws StatusCodeWithException { + return sendToMyselfGateway( + GatewayActionType.none, + "", + GatewayProcessorType.gatewayAvailableProcessor + ).toJavaObject(ServiceAvailableCheckOutput.class); } - /** - * Call the board of other members - */ - public ApiResult callOtherMemberBoard(String dstMemberId, String api, JSONObject data) throws MemberGatewayException { + public T callOtherMemberBoard(String dstMemberId, Class api, Class resultClass) throws StatusCodeWithException { + return callOtherMemberBoard(dstMemberId, null, api, null, resultClass); + } - String request = JObject.create() - .append("url", api) - .append("method", "POST") - .append("body", data) - .toStringWithNull(); + public void callOtherMemberBoard(String dstMemberId, Class api, Object params) throws StatusCodeWithException { + callOtherMemberBoard(dstMemberId, null, api, params, Object.class); + } - ApiResult result = sendToOtherGateway(dstMemberId, GatewayActionType.http_job, request, GatewayProcessorType.boardHttpProcessor); - if (!result.success()) { - throw new MemberGatewayException(dstMemberId, result.getMessage()); - } + public void callOtherMemberBoard(String dstMemberId, JobMemberRole senderRole, Class api, Object params) throws StatusCodeWithException { + callOtherMemberBoard(dstMemberId, senderRole, api, params, Object.class); + } - return result; + public T callOtherMemberBoard(String dstMemberId, Class api, Object params, Class resultClass) throws StatusCodeWithException { + return callOtherMemberBoard(dstMemberId, null, api, params, resultClass); } /** * Send the request to the gateway/redirect interface in the board + * + * @param dstMemberId 接收请求的成员Id + * @param senderRole 发送请求的成员角色,可以为 null。 + * @param api 被调用的接口名 + * @param params 接口请求参数 + * @param resultClass 响应结果的实体类型 */ - public ApiResult sendToBoardRedirectApi(String receiverMemberId, JobMemberRole senderRole, 
Object data, Class api) throws MemberGatewayException { + public T callOtherMemberBoard(String dstMemberId, JobMemberRole senderRole, Class api, Object params, Class resultClass) throws StatusCodeWithException { Api annotation = api.getAnnotation(Api.class); - return callOtherMemberBoard(receiverMemberId, "gateway/redirect", + JSONObject result = callOtherMemberBoard( + dstMemberId, + "gateway/redirect", JObject .create() .put("api", annotation.path()) - .put("data", data) + .put("data", params) .put("caller_member_id", CacheObjects.getMemberId()) .put("caller_member_name", CacheObjects.getMemberName()) - .put("caller_member_role", senderRole.name()) + .put("caller_member_role", senderRole == null ? "" : senderRole.name()) ); + ApiResult apiResult = result.toJavaObject(ApiResult.class); + if (!apiResult.success()) { + throw new MemberGatewayException(dstMemberId, apiResult.message); + } + + JSONObject data = result.getJSONObject("data"); + + if (data == null) { + return null; + } + + if (resultClass == JObject.class) { + return (T) JObject.create(data); + } + + return data.toJavaObject(resultClass); + } + + + /** + * Call the board of other members + */ + private JSONObject callOtherMemberBoard(String dstMemberId, String api, JSONObject data) throws StatusCodeWithException { + + String request = JObject.create() + .append("url", api) + .append("method", "POST") + .append("body", data) + .toStringWithNull(); + + JSONObject result = sendToOtherGateway( + dstMemberId, + GatewayActionType.none, + request, + GatewayProcessorType.boardHttpProcessor + ); + + + return result; } @@ -318,11 +395,7 @@ public void checkMemberRouteConnect(String gatewayUri) throws StatusCodeWithExce ) .toStringWithNull(); - ApiResult result = sendToMyselfGateway(gatewayUri, GatewayActionType.http_job, data, GatewayProcessorType.boardHttpProcessor); - if (!result.success()) { - throw new MemberGatewayException(CacheObjects.getMemberId(), result.getMessage()); - } - + 
sendToMyselfGateway(gatewayUri, GatewayActionType.http_job, data, GatewayProcessorType.boardHttpProcessor).toJavaObject(ApiResult.class); } /** @@ -336,10 +409,7 @@ public void pingGatewayAlive(String gatewayUri) throws StatusCodeWithException { gatewayUri = globalConfigService.getGatewayConfig().intranetBaseUri; } - ApiResult result = sendToMyselfGateway(gatewayUri, GatewayActionType.not_null, JObject.create().toString(), GatewayProcessorType.gatewayAliveProcessor); - if (!result.success()) { - throw new MemberGatewayException(CacheObjects.getMemberId(), result.getMessage()); - } + sendToMyselfGateway(gatewayUri, GatewayActionType.none, JObject.create().toString(), GatewayProcessorType.gatewayAliveProcessor); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/JobMemberService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/JobMemberService.java index 8190f023d..a8fdcf766 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/JobMemberService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/JobMemberService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -20,9 +20,9 @@ import com.welab.wefe.board.service.database.entity.job.JobMemberMySqlModel; import com.welab.wefe.board.service.database.repository.JobMemberRepository; import com.welab.wefe.board.service.dto.entity.job.JobMemberOutputModel; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.domain.Sort; import org.springframework.data.jpa.domain.Specification; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/JobService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/JobService.java index ada70839f..ce3b47ac0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/JobService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/JobService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -21,19 +21,17 @@ import com.welab.wefe.board.service.database.entity.job.ProjectFlowMySqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; -import com.welab.wefe.board.service.database.repository.DataSetRepository; import com.welab.wefe.board.service.database.repository.JobMemberRepository; import com.welab.wefe.board.service.database.repository.JobRepository; import com.welab.wefe.board.service.database.repository.TaskRepository; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.board.service.sdk.UnionService; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.JobStatus; -import com.welab.wefe.common.enums.TaskStatus; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.web.CurrentAccount; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.JobStatus; +import com.welab.wefe.common.wefe.enums.TaskStatus; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.jpa.domain.Specification; import org.springframework.stereotype.Service; @@ -54,10 +52,6 @@ public class JobService extends AbstractService { @Autowired JobMemberRepository jobMemberRepo; @Autowired - DataSetRepository dataSetRepository; - @Autowired - UnionService unionService; - @Autowired JobMemberService jobMemberService; @Autowired private ProjectFlowService projectFlowService; @@ -199,7 +193,7 @@ private void checkCacheEnableStatus(FlowGraph graph, JobMySqlModel lastJob) thro .stream() .filter(x -> x.getParamsVersion() >= lastJobCreateTime) .forEach(x -> x.setHasCacheResult(false)); - + List nodes = graph.getAllJobSteps(); Collections.sort(nodes, new Comparator() { @Override diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/service/MemberChatService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/MemberChatService.java index b689f1809..45dbcf32f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/MemberChatService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/MemberChatService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -22,21 +22,24 @@ import com.welab.wefe.board.service.database.entity.chat.ChatLastAccountMysqlModel; import com.welab.wefe.board.service.database.entity.chat.MemberChatMySqlModel; import com.welab.wefe.board.service.database.entity.chat.MessageQueueMySqlModel; -import com.welab.wefe.board.service.database.repository.*; +import com.welab.wefe.board.service.database.repository.ChatUnreadMessageRepository; +import com.welab.wefe.board.service.database.repository.MemberChatRepository; +import com.welab.wefe.board.service.database.repository.MessageQueueRepository; +import com.welab.wefe.board.service.database.repository.MessageRepository; import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.dto.entity.MemberChatOutputModel; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.GatewayActionType; -import com.welab.wefe.common.enums.GatewayProcessorType; -import com.welab.wefe.common.enums.OrderBy; -import com.welab.wefe.common.enums.ProducerType; +import com.welab.wefe.common.data.mysql.enums.OrderBy; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; import com.welab.wefe.common.web.CurrentAccount; -import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.service.account.AccountInfo; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.GatewayActionType; +import com.welab.wefe.common.wefe.enums.GatewayProcessorType; +import com.welab.wefe.common.wefe.enums.ProducerType; import org.apache.commons.collections4.CollectionUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.domain.Page; @@ -130,9 +133,6 @@ public JObject sendMessage(String fromAccountId, String fromAccountName, String 
.toString(); - // Push the message to the destination member through the gateway - ApiResult result = gatewayService.sendToOtherGateway(toMemberId, GatewayActionType.create_chat_msg, data, GatewayProcessorType.dbChatTableProcessor); - Date createdTime = new Date(); // Message detail object MemberChatMySqlModel memberChatModel = new MemberChatMySqlModel(); @@ -151,12 +151,18 @@ public JObject sendMessage(String fromAccountId, String fromAccountName, String memberChatModel.setUpdatedTime(createdTime); memberChatModel.setMessageId(messageId); - // Message sending failed - if (!result.success()) { + + // Push the message to the destination member through the gateway + try { + gatewayService.sendToOtherGateway(toMemberId, GatewayActionType.create_chat_msg, data, GatewayProcessorType.dbChatTableProcessor); + } catch (Exception e) { + // Message sending failed memberChatModel.setStatus(ChatConstant.MESSAGE_STATUS_SEND_FAIL); ret.append(ChatConstant.KEY_CODE, StatusCode.SYSTEM_ERROR.getCode()) - .append(ChatConstant.KEY_MESSAGE, result.getMessage()); + .append(ChatConstant.KEY_MESSAGE, e.getMessage()); } + + ret.append(ChatConstant.KEY_MEMBER_CHAT_ID, memberChatModel.getId()); // Save message details @@ -173,7 +179,7 @@ public JObject sendMessage(String fromAccountId, String fromAccountName, String public JObject sendMessage(String toMemberId, String toMemberName, String toAccountId, String toAccountName, String content) throws StatusCodeWithException { - CurrentAccount.Info info = CurrentAccount.get(); + AccountInfo info = CurrentAccount.get(); if (null == info) { throw new StatusCodeWithException("请登录后访问", StatusCode.LOGIN_REQUIRED); } @@ -210,11 +216,8 @@ public void resendMessage(String memberChatId) throws StatusCodeWithException { .toString(); // Push the message to the destination member through the gateway - ApiResult result = gatewayService.sendToOtherGateway(model.getToMemberId(), GatewayActionType.create_chat_msg, data, 
GatewayProcessorType.dbChatTableProcessor); - // Message sending failed - if (!result.success()) { - throw new StatusCodeWithException(result.getMessage(), StatusCode.RPC_ERROR); - } + gatewayService.sendToOtherGateway(model.getToMemberId(), GatewayActionType.create_chat_msg, data, GatewayProcessorType.dbChatTableProcessor); + // Update message status is successful memberChatRepository.updateById(model.getId(), "status", ChatConstant.MESSAGE_STATUS_SEND_SUCCESS, MemberChatMySqlModel.class, false); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/MessageService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/MessageService.java index 4b6e2dd05..c21f9850e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/MessageService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/MessageService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ModelOotRecordService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ModelOotRecordService.java index 17a25a93b..4bac2de19 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ModelOotRecordService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ModelOotRecordService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/OperationLogService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/OperationLogService.java index 230f62310..6b6caf9a5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/OperationLogService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/OperationLogService.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,14 @@ package com.welab.wefe.board.service.service; -import com.welab.wefe.board.service.api.operation.QueryApi; +import com.welab.wefe.board.service.api.operation.LogQueryApi; import com.welab.wefe.board.service.database.entity.OperationLogMysqlModel; import com.welab.wefe.board.service.database.repository.OperationLogRepository; import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.dto.entity.OperationLogOutputModel; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.OrderBy; +import com.welab.wefe.common.data.mysql.enums.OrderBy; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.web.CurrentAccount; import org.springframework.beans.factory.annotation.Autowired; @@ -39,15 +39,15 @@ public class OperationLogService extends AbstractService { @Autowired OperationLogRepository mOperationLogRepository; - public PagingOutput query(QueryApi.Input input) throws StatusCodeWithException { + public PagingOutput query(LogQueryApi.Input input) throws StatusCodeWithException { if (!CurrentAccount.isAdmin()) { StatusCode.PERMISSION_DENIED.throwException("普通用户无法进行此操作。"); } Specification where = Where .create() - .equal("operatorPhone", input.getOperatorPhone()) - .equal("logAction", input.getAction()) + .equal("logInterface", input.logInterface) + .equal("operatorId", input.operatorId) .betweenAndDate("createdTime", input.getStartTime(), input.getEndTime()) .orderBy("createdTime", OrderBy.desc) .build(OperationLogMysqlModel.class); diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectDataSetAuditService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectDataSetAuditService.java index c26dc31cb..74aa3f820 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectDataSetAuditService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectDataSetAuditService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,19 +16,19 @@ package com.welab.wefe.board.service.service; -import java.util.List; - -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; -import org.springframework.transaction.annotation.Transactional; - import com.welab.wefe.board.service.api.project.dataset.AuditDataSetApi; import com.welab.wefe.board.service.api.project.dataset.AuditDataSetApi.Input; import com.welab.wefe.board.service.database.entity.job.ProjectDataSetMySqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectMySqlModel; +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.AuditStatus; import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.util.List; /** * @author zane.luo @@ -70,7 +70,7 @@ public synchronized void auditDataSet(Input input) throws StatusCodeWithExceptio } ProjectDataSetMySqlModel dataSet = dataSets.stream().filter(d -> d.getAuditStatus() == AuditStatus.auditing) .findFirst().orElse(null); - + if (dataSet == null || dataSet.getAuditStatus() != AuditStatus.auditing) { throw new StatusCodeWithException("请勿重复审核!", StatusCode.ILLEGAL_REQUEST); } @@ -84,7 +84,7 @@ public synchronized void auditDataSet(Input input) throws StatusCodeWithExceptio projectDataSetService.update(dataSet, (x) -> x.setAuditStatus(input.getAuditStatus())); // Update the number of data sets used in the project - dataSetService.updateUsageCountInProject(dataSet.getDataSetId()); + tableDataSetService.updateUsageCountInProject(dataSet.getDataSetId()); gatewayService.syncToNotExistedMembers(input.getProjectId(), input, AuditDataSetApi.class); @@ -92,6 +92,6 @@ 
public synchronized void auditDataSet(Input input) throws StatusCodeWithExceptio } @Autowired - private DataSetService dataSetService; + private TableDataSetService tableDataSetService; } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectDataSetService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectDataSetService.java index e41e94beb..cf8aae39d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectDataSetService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectDataSetService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,29 +16,34 @@ package com.welab.wefe.board.service.service; -import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.api.gateway.GetDerivedDataSetDetailApi; import com.welab.wefe.board.service.api.project.dataset.QueryDerivedDataSetApi; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectDataSetMySqlModel; import com.welab.wefe.board.service.database.repository.ProjectDataSetRepository; import com.welab.wefe.board.service.dto.base.PagingOutput; -import com.welab.wefe.board.service.dto.entity.data_set.DataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.DataResourceOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.ImageDataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.TableDataSetOutputModel; import com.welab.wefe.board.service.dto.entity.job.JobMemberOutputModel; -import com.welab.wefe.board.service.dto.entity.project.DerivedProjectDataSetOutputModel; -import com.welab.wefe.board.service.dto.entity.project.ProjectDataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.project.data_set.DerivedProjectDataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.project.data_set.ProjectDataResourceOutputModel; import com.welab.wefe.board.service.dto.vo.JobMemberWithDataSetOutputModel; -import com.welab.wefe.board.service.exception.MemberGatewayException; -import com.welab.wefe.board.service.util.ModelMapper; +import com.welab.wefe.board.service.service.data_resource.bloom_filter.BloomFilterService; +import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetService; +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; import 
com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.OrderBy; +import com.welab.wefe.common.data.mysql.enums.OrderBy; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.CurrentAccount; -import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DeepLearningJobType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.jpa.domain.Specification; import org.springframework.stereotype.Service; @@ -54,7 +59,11 @@ public class ProjectDataSetService extends AbstractService { @Autowired - private DataSetService dataSetService; + private TableDataSetService tableDataSetService; + @Autowired + private ImageDataSetService imageDataSetService; + @Autowired + private BloomFilterService bloomFilterService; @Autowired private ProjectDataSetRepository projectDataSetRepo; @@ -63,7 +72,8 @@ public class ProjectDataSetService extends AbstractService { * Get the details of the derived data set */ public DerivedProjectDataSetOutputModel getDerivedDataSetDetail(GetDerivedDataSetDetailApi.Input input) throws StatusCodeWithException { - DataSetMysqlModel dataSet = dataSetService.findOne(input.getDataSetId()); + // 衍生数据集目前只有 TableDataSet + TableDataSetMysqlModel dataSet = tableDataSetService.findOneById(input.getDataSetId()); ProjectDataSetMySqlModel projectDataSet = findOne(input.getProjectId(), input.getDataSetId(), input.getMemberRole()); if (dataSet == null || projectDataSet == null) { @@ -74,17 +84,12 @@ public DerivedProjectDataSetOutputModel getDerivedDataSetDetail(GetDerivedDataSe throw new 
StatusCodeWithException("拒绝查询原始数据集信息", StatusCode.PARAMETER_VALUE_INVALID); } - JObject json = JObject.create(); - json.putAll(JObject.create(dataSet)); - json.putAll(JObject.create(projectDataSet)); + List members = ModelMapper.maps(jobMemberService.list(dataSet.getDerivedFromJobId(), false), JobMemberWithDataSetOutputModel.class); - DerivedProjectDataSetOutputModel output = json.toJavaObject(DerivedProjectDataSetOutputModel.class); - List members = ModelMapper.maps(jobMemberService.list(dataSet.getSourceJobId(), false), JobMemberWithDataSetOutputModel.class); + DerivedProjectDataSetOutputModel output = ModelMapper.map(projectDataSet, DerivedProjectDataSetOutputModel.class); + output.setDataResource(ModelMapper.map(dataSet, TableDataSetOutputModel.class)); output.setMembers(members); - LOG.info("members:" + JSON.toJSONString(members, true)); - LOG.info("input.callerMemberInfo:" + JSON.toJSONString(input.callerMemberInfo, true)); - return output; } @@ -95,6 +100,7 @@ public PagingOutput queryDerivedDataSet(QueryD Where where = Where .create() .equal("projectId", input.getProjectId()) + .equal("dataResourceType", input.getDataResourceType()) .equal("dataSetId", input.getDataSetId()) .equal("sourceFlowId", input.getSourceFlowId()) .equal("sourceJobId", input.getSourceJobId()); @@ -119,51 +125,60 @@ public PagingOutput queryDerivedDataSet(QueryD } + /** + * 组装衍生数据集 output 对象 + *

+ * tips:衍生数据集目前只有 TableDataSet 类型 + */ private DerivedProjectDataSetOutputModel buildDerivedProjectDataSetOutputModel(ProjectDataSetMySqlModel projectDataSet) { - DataSetMysqlModel dataSet = dataSetService.findOne(projectDataSet.getDataSetId()); - - JObject json = JObject.create(); - if (dataSet != null) { - json.putAll(JObject.create(dataSet)); - } - json.putAll(JObject.create(projectDataSet)); + DerivedProjectDataSetOutputModel derivedDataSet = new DerivedProjectDataSetOutputModel(); + BeanUtils.copyProperties(projectDataSet, derivedDataSet); // Create a derived dataset object - DerivedProjectDataSetOutputModel derivedDataSet = json.toJavaObject(DerivedProjectDataSetOutputModel.class); - derivedDataSet.setSourceTypeCn( - derivedDataSet.getSourceType() != null ? derivedDataSet.getSourceType().getLabel() : ""); + TableDataSetMysqlModel dataSet = tableDataSetService.findOneById(projectDataSet.getDataSetId()); if (dataSet != null) { + derivedDataSet.setDataResource(JObject.create(dataSet).toJavaObject(TableDataSetOutputModel.class)); // Query the feature list from each member - List jobMembers = jobMemberService.list(dataSet.getSourceJobId(), false); + List jobMembers = jobMemberService.list(dataSet.getDerivedFromJobId(), false); List output = jobMembers .stream() .map(m -> { JobMemberWithDataSetOutputModel member = ModelMapper.map(m, JobMemberWithDataSetOutputModel.class); + // Take your own feature list directly + TableDataSetOutputModel tableDataSet = null; if (member.getMemberId().equals(derivedDataSet.getMemberId())) { - member.setFeatureNameList(derivedDataSet.getFeatureNameList()); - member.setFeatureCount(derivedDataSet.getFeatureCount()); + tableDataSet = (TableDataSetOutputModel) derivedDataSet.getDataResource(); } // Others’ feature list should be checked remotely else { try { - ApiResult apiResult = gatewayService.callOtherMemberBoard( + + JSONObject derivedProjectDataSet = gatewayService.callOtherMemberBoard( member.getMemberId(), 
GetDerivedDataSetDetailApi.class, - new GetDerivedDataSetDetailApi.Input(member.getProjectId(), projectDataSet.getDataSetId(), member.getJobRole()) + new GetDerivedDataSetDetailApi.Input(member.getProjectId(), projectDataSet.getDataSetId(), member.getJobRole()), + /** + * 这里不能直接指定为 DerivedProjectDataSetOutputModel.class, + * 因为 dataResource 字段类型为 DataResourceOutputModel, + * 这是个父类,反射成这个对象会缺字段。 + * + * 要取 json 节点手动反射为 TableDataSetOutputModel + */ + JSONObject.class ); - if (apiResult.data != null) { - DerivedProjectDataSetOutputModel derivedProjectDataSet = ((JSONObject) apiResult.data).toJavaObject(DerivedProjectDataSetOutputModel.class); - member.setFeatureNameList(derivedProjectDataSet.getFeatureNameList()); - member.setFeatureCount(derivedProjectDataSet.getFeatureCount()); - } - - } catch (MemberGatewayException e) { + tableDataSet = derivedProjectDataSet.getJSONObject("data_resource").toJavaObject(TableDataSetOutputModel.class); + } catch (Exception e) { super.log(e); } } + if (tableDataSet != null) { + member.setFeatureNameList(tableDataSet.getFeatureNameList()); + member.setFeatureCount(tableDataSet.getFeatureCount()); + } + return member; }) .collect(Collectors.toList()); @@ -177,16 +192,21 @@ private DerivedProjectDataSetOutputModel buildDerivedProjectDataSetOutputModel(P @Autowired private JobMemberService jobMemberService; + public List listRawDataSet(String projectId, DataResourceType dataResourceType, String memberId, JobMemberRole memberRole, Boolean containsY) { + return listRawDataSet(projectId, dataResourceType, memberId, memberRole, containsY, null); + } + /** * Display the list of data sets of the specified members in the project *

* When memberId is empty, check the data sets of all members. */ - public List listRawDataSet(String projectId, String memberId, JobMemberRole memberRole, Boolean containsY) { + public List listRawDataSet(String projectId, DataResourceType dataResourceType, String memberId, JobMemberRole memberRole, Boolean containsY, DeepLearningJobType forJobType) { Specification where = Where .create() .equal("projectId", projectId) + .equal("dataResourceType", dataResourceType) .equal("memberId", memberId) .equal("memberRole", memberRole) .equal("sourceType", null, false) @@ -195,38 +215,46 @@ public List listRawDataSet(String projectId, String m List list = projectDataSetRepo.findAll(where); - List output = list + List output = list .parallelStream() .map(x -> { - DataSetOutputModel dataSet = null; + try { - dataSet = dataSetService.findDataSetFromLocalOrUnion(x.getMemberId(), x.getDataSetId()); - // The data set does not exist and is marked as deleted. - if (dataSet == null) { - ProjectDataSetOutputModel foo = JObject - .create(x) - .toJavaObject(ProjectDataSetOutputModel.class); - foo.setName("此数据集已被删除或不可见"); - foo.setRowCount(0L); - foo.setDeleted(true); - return foo; + ProjectDataResourceOutputModel projectDataResource = ModelMapper.map(x, ProjectDataResourceOutputModel.class); + DataResourceOutputModel dataResource = null; + if (x.getDataResourceType() == DataResourceType.TableDataSet) { + dataResource = tableDataSetService.findDataSetFromLocalOrUnion(x.getMemberId(), x.getDataSetId()); + } else if (x.getDataResourceType() == DataResourceType.ImageDataSet) { + dataResource = imageDataSetService.findDataSetFromLocalOrUnion(x.getMemberId(), x.getDataSetId()); + } else if (x.getDataResourceType() == DataResourceType.BloomFilter) { + dataResource = bloomFilterService.findDataSetFromLocalOrUnion(x.getMemberId(), x.getDataSetId()); + } + // 如果这里没有拿到数据集信息,说明数据集已经被删除或者不可见。 + if (dataResource == null) { + dataResource = new DataResourceOutputModel(); + String name = 
CacheObjects.getMemberId().equals(projectDataResource.getMemberId()) + ? "资源已被删除" + : "资源已被删除或不可见"; + dataResource.setName(name); + dataResource.setId(projectDataResource.getDataSetId()); + dataResource.setDeleted(true); } + projectDataResource.setDataResource(dataResource); + return projectDataResource; } catch (StatusCodeWithException e) { - e.printStackTrace(); - + super.log(e); return null; } - JObject item = JObject.create(); - item.putAll(JObject.create(dataSet)); - item.putAll(JObject.create(x)); - - return item.toJavaObject(ProjectDataSetOutputModel.class); - }) .filter(x -> { - if (containsY != null) { - return containsY.equals(x.getContainsY()); + if (containsY != null && (x.getDataResource() instanceof TableDataSetOutputModel)) { + return containsY.equals(((TableDataSetOutputModel) x.getDataResource()).isContainsY()); + } + return true; + }).filter(x -> { + if (forJobType != null && (x.getDataResource() instanceof ImageDataSetOutputModel)) { + return forJobType.equals(((ImageDataSetOutputModel) x.getDataResource()).getForJobType()); } return true; }) @@ -256,43 +284,28 @@ public List listAllRawDataSet(String projectId, String *

* When memberId is empty, check the data sets of all members. */ - public List list(String projectId, String memberId) { + public List list(String projectId, DataResourceType dataResourceType, String memberId) { Specification where = Where .create() .equal("projectId", projectId) + .equal("dataResourceType", dataResourceType) .equal("memberId", memberId) .build(ProjectDataSetMySqlModel.class); List list = projectDataSetRepo.findAll(where); - List output = list + List output = list .parallelStream() .map(x -> { - DataSetOutputModel dataSet = null; + ProjectDataResourceOutputModel projectDataSet = ModelMapper.map(x, ProjectDataResourceOutputModel.class); try { - dataSet = dataSetService.findDataSetFromLocalOrUnion(x.getMemberId(), x.getDataSetId()); - // The data set does not exist and is marked as deleted. - if (dataSet == null) { - ProjectDataSetOutputModel foo = JObject - .create(x) - .toJavaObject(ProjectDataSetOutputModel.class); - foo.setName("此数据集已被删除或不可见"); - foo.setRowCount(0L); - foo.setDeleted(true); - return foo; - } + TableDataSetOutputModel dataSet = tableDataSetService.findDataSetFromLocalOrUnion(x.getMemberId(), x.getDataSetId()); + projectDataSet.setDataResource(dataSet); } catch (StatusCodeWithException e) { - e.printStackTrace(); - - return null; + super.log(e); } - - JObject item = JObject.create(); - item.putAll(JObject.create(dataSet)); - item.putAll(JObject.create(x)); - - return item.toJavaObject(ProjectDataSetOutputModel.class); + return projectDataSet; }) .collect(Collectors.toList()); @@ -341,7 +354,7 @@ public ProjectDataSetMySqlModel findOne(String projectId, String dataSetId) { .build(ProjectDataSetMySqlModel.class) ).orElse(null); } - + public List findAll(String projectId, String dataSetId) { return projectDataSetRepo.findAll(Where.create().equal("projectId", projectId).equal("dataSetId", dataSetId) .build(ProjectDataSetMySqlModel.class)); diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowJobService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowJobService.java index a4eb36707..d477f22ef 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowJobService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowJobService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,23 +16,7 @@ package com.welab.wefe.board.service.service; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.Date; -import java.util.HashSet; -import java.util.List; -import java.util.UUID; -import java.util.stream.Collectors; - -import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.BeanUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; -import org.springframework.transaction.annotation.Transactional; - -import com.welab.wefe.board.service.api.member.ServiceStatusCheckApi; +import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.api.project.flow.StartFlowApi; import com.welab.wefe.board.service.api.project.job.ResumeJobApi; import com.welab.wefe.board.service.api.project.job.StopJobApi; @@ -40,43 +24,40 @@ import com.welab.wefe.board.service.component.DataIOComponent; import 
com.welab.wefe.board.service.component.OotComponent; import com.welab.wefe.board.service.component.base.AbstractComponent; -import com.welab.wefe.board.service.constant.Config; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; -import com.welab.wefe.board.service.database.entity.job.JobMemberMySqlModel; -import com.welab.wefe.board.service.database.entity.job.JobMySqlModel; -import com.welab.wefe.board.service.database.entity.job.ProjectDataSetMySqlModel; -import com.welab.wefe.board.service.database.entity.job.ProjectFlowMySqlModel; -import com.welab.wefe.board.service.database.entity.job.ProjectFlowNodeMySqlModel; -import com.welab.wefe.board.service.database.entity.job.ProjectMySqlModel; -import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; -import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; -import com.welab.wefe.board.service.database.repository.JobMemberRepository; -import com.welab.wefe.board.service.database.repository.JobRepository; -import com.welab.wefe.board.service.database.repository.ProjectFlowRepository; -import com.welab.wefe.board.service.database.repository.TaskRepository; -import com.welab.wefe.board.service.database.repository.TaskResultRepository; -import com.welab.wefe.board.service.dto.entity.data_set.DataSetOutputModel; -import com.welab.wefe.board.service.dto.kernel.Env; -import com.welab.wefe.board.service.dto.kernel.JobDataSet; -import com.welab.wefe.board.service.dto.kernel.KernelJob; +import com.welab.wefe.board.service.component.base.dto.AbstractDataIOParam; +import com.welab.wefe.board.service.component.base.dto.AbstractDataSetItem; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.job.*; +import com.welab.wefe.board.service.database.repository.*; +import com.welab.wefe.board.service.dto.entity.data_resource.output.DataResourceOutputModel; +import 
com.welab.wefe.board.service.dto.entity.data_resource.output.TableDataSetOutputModel; import com.welab.wefe.board.service.dto.kernel.Member; -import com.welab.wefe.board.service.dto.kernel.Project; +import com.welab.wefe.board.service.dto.kernel.machine_learning.Env; +import com.welab.wefe.board.service.dto.kernel.machine_learning.JobDataSet; +import com.welab.wefe.board.service.dto.kernel.machine_learning.KernelJob; +import com.welab.wefe.board.service.dto.kernel.machine_learning.Project; import com.welab.wefe.board.service.dto.vo.JobArbiterInfo; -import com.welab.wefe.board.service.dto.vo.MemberServiceStatusOutput; import com.welab.wefe.board.service.exception.FlowNodeException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; +import com.welab.wefe.board.service.service.data_resource.DataResourceService; +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.FlowActionType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.JobStatus; -import com.welab.wefe.common.enums.ProjectFlowStatus; import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.util.DateUtil; import com.welab.wefe.common.util.StringUtil; import com.welab.wefe.common.web.CurrentAccount; +import com.welab.wefe.common.wefe.checkpoint.dto.MemberAvailableCheckOutput; +import com.welab.wefe.common.wefe.enums.*; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import 
org.springframework.transaction.annotation.Transactional; + +import java.util.*; +import java.util.stream.Collectors; /** * @author winter.zou @@ -102,13 +83,11 @@ public class ProjectFlowJobService extends AbstractService { @Autowired private ProjectFlowRepository projectFlowRepo; @Autowired - private Config config; - @Autowired private ProjectFlowNodeService projectFlowNodeService; @Autowired private ProjectService projectService; @Autowired - private DataSetService dataSetService; + private DataResourceService dataResourceService; @Autowired private ProjectFlowService projectFlowService; @Autowired @@ -117,13 +96,14 @@ public class ProjectFlowJobService extends AbstractService { private ProjectDataSetService projectDataSetService; @Autowired private ServiceCheckService serviceCheckService; + @Autowired + private TableDataSetService tableDataSetService; + public static final int MIX_FLOW_PROMOTER_NUM = 2; /** * start flow - * - * @return jobId */ @Transactional(rollbackFor = Exception.class) public synchronized String startFlow(StartFlowApi.Input input) throws StatusCodeWithException { @@ -132,7 +112,6 @@ public synchronized String startFlow(StartFlowApi.Input input) throws StatusCode if (flow == null) { throw new StatusCodeWithException("未找到相应的流程!", StatusCode.ILLEGAL_REQUEST); } - ProjectMySqlModel project = projectService.findByProjectId(flow.getProjectId()); if (!input.fromGateway() && (!isCreator(flow, project))) { @@ -160,10 +139,11 @@ public synchronized String startFlow(StartFlowApi.Input input) throws StatusCode throw new StatusCodeWithException("当前任务不包含我方数据集,无法启动。", StatusCode.PARAMETER_VALUE_INVALID); } } - long memberCount = jobMembers.stream().filter(x -> x.getJobRole() != JobMemberRole.arbiter).count(); - if (memberCount < MIX_FLOW_PROMOTER_NUM && !isOotMode) { - throw new StatusCodeWithException("需要在【" + ComponentType.DataIO.getLabel() + "】中选择两个或两个以上的数据集", StatusCode.PARAMETER_VALUE_INVALID); - } + long memberCount = jobMembers.stream().filter(x -> 
x.getJobRole() != JobMemberRole.arbiter).count(); + if (memberCount < 2 && !isOotMode) { + throw new StatusCodeWithException("需要在【" + ComponentType.DataIO.getLabel() + "】中选择两个或两个以上的数据集", + StatusCode.PARAMETER_VALUE_INVALID); + } long promoterMemberCount = jobMembers.stream().filter(x -> x.getJobRole() == JobMemberRole.promoter).count(); if (promoterMemberCount >= MIX_FLOW_PROMOTER_NUM && !flow.getFederatedLearningType().equals(FederatedLearningType.mix)) { throw new StatusCodeWithException("【选择数据集】组件参数错误,请先移除再重新添加", StatusCode.PARAMETER_VALUE_INVALID); @@ -179,20 +159,20 @@ public synchronized String startFlow(StartFlowApi.Input input) throws StatusCode JobMySqlModel job = createJob(flow, input.getJobId(), jobMember.getJobRole()); // create Graph - FlowGraph graph = new FlowGraph(job, lastJob, jobMembers, projectFlowNodeService.findNodesByFlowId(job.getFlowId())); + FlowGraph graph = new FlowGraph(job, lastJob, jobMembers, projectFlowNodeService.findNodesByFlowId(job.getFlowId()), flow.getCreatorMemberId()); // check if (jobMember.getJobRole() == JobMemberRole.promoter) { checkBeforeStartFlow(graph, project, isOotMode); } // create task - createJobTasks(graph, input.isUseCache(), input.getEndNodeId(), flow.getFederatedLearningType()); + createJobTasks(project, graph, input.isUseCache(), input.getEndNodeId(), flow.getFederatedLearningType()); } gatewayService.syncToOtherJobMembers(input.getJobId(), input, StartFlowApi.class); - flowActionQueueService.notifyFlow(input, input.getJobId(), FlowActionType.run_job); + flowActionQueueService.runJob(input, input.getJobId(), project.getProjectType()); //update flow projectFlowService.updateFlowStatus(flow.getFlowId(), ProjectFlowStatus.running); @@ -203,13 +183,19 @@ public synchronized String startFlow(StartFlowApi.Input input) throws StatusCode public boolean isCreator(ProjectFlowMySqlModel flow, ProjectMySqlModel project) { return JobMemberRole.promoter == project.getMyRole() - && 
CacheObjects.isCurrentMember(flow.getCreatedBy()); + && CacheObjects.isCurrentMemberAccount(flow.getCreatedBy()); } public JobArbiterInfo calcArbiterInfo(ProjectFlowMySqlModel flow, StartFlowApi.Input input, ProjectMySqlModel project) { JobArbiterInfo info = new JobArbiterInfo(); info.setHasArbiter(false); + + // 深度学习没有 arbiter 角色 + if (project.getProjectType() == ProjectType.DeepLearning) { + return info; + } + if (flow.getFederatedLearningType() == FederatedLearningType.horizontal) { if (project.getMyRole() == JobMemberRole.promoter) { info.setHasArbiter(true); @@ -227,13 +213,20 @@ public JobArbiterInfo calcArbiterInfo(ProjectFlowMySqlModel flow, StartFlowApi.I * Check the effectiveness of the task before starting the task. */ private void checkBeforeStartFlow(FlowGraph graph, ProjectMySqlModel project, boolean isOotMode) throws StatusCodeWithException { + if (CollectionUtils.isEmpty(graph.getStartNodes())) { throw new StatusCodeWithException("流程中没有起始节点,无法执行该流程。", StatusCode.PARAMETER_VALUE_INVALID); } - if (graph.getStartNodes().stream().noneMatch(x -> (x.getComponentType() == ComponentType.DataIO - || x.getComponentType() == ComponentType.Oot))) { - throw new StatusCodeWithException("起始节点必须包含 " + ComponentType.DataIO.getLabel() + ",否则无法执行流程。", StatusCode.PARAMETER_VALUE_INVALID); + boolean hasDataSet = graph.getStartNodes() + .stream() + .anyMatch(x -> + x.getComponentType() == ComponentType.DataIO + || x.getComponentType() == ComponentType.Oot + || x.getComponentType() == ComponentType.ImageDataIO + ); + if (!hasDataSet) { + throw new StatusCodeWithException("流程起点必须包含数据集加载,否则无法执行流程。", StatusCode.PARAMETER_VALUE_INVALID); } if (isOotMode) { @@ -244,15 +237,13 @@ private void checkBeforeStartFlow(FlowGraph graph, ProjectMySqlModel project, bo // Check whether the services of each member are available for (JobMemberMySqlModel member : graph.getMembers()) { - ServiceStatusCheckApi.Output status = serviceCheckService.checkMemberServiceStatus(new 
ServiceStatusCheckApi.Input(member.getMemberId())); - MemberServiceStatusOutput errorService = status.getStatus().values().stream().filter(x -> !x.isSuccess()).findFirst().orElse(null); - if (errorService != null) { - throw new StatusCodeWithException("成员 " - + CacheObjects.getMemberName(member.getMemberId()) - + " 的 " + errorService.getService().name() + " 服务不可用:" - + errorService.getMessage(), - - StatusCode.REMOTE_SERVICE_ERROR + MemberAvailableCheckOutput status = serviceCheckService.getMemberAvailableInfo(member.getMemberId()); + if (!status.available) { + StatusCode.REMOTE_SERVICE_ERROR.throwException( + "成员 " + + CacheObjects.getMemberName(member.getMemberId()) + + " 的 " + status.errorServiceType.name() + " 服务不可用:" + + status.message ); } } @@ -273,14 +264,8 @@ private void checkBeforeStartFlow(FlowGraph graph, ProjectMySqlModel project, bo if (CacheObjects.getMemberId().equals(member.getMemberId())) { ProjectDataSetMySqlModel projectDataSet = projectDataSetService.findOne(project.getProjectId(), member.getDataSetId(), member.getJobRole()); - if (projectDataSet == null) { - throw new StatusCodeWithException("成员【" + memberName + " - " + member.getJobRole().name() + "】的数据集 " + member.getDataSetId() + " 不存在,可能已删除。", StatusCode.PARAMETER_VALUE_INVALID); - } - - DataSetOutputModel dataSet = dataSetService.findDataSetFromLocalOrUnion(member.getMemberId(), member.getDataSetId()); - if (dataSet == null) { - throw new StatusCodeWithException("成员【" + memberName + " - " + member.getJobRole().name() + "】的数据集 " + member.getDataSetId() + " 不存在,可能已被删除。", StatusCode.PARAMETER_VALUE_INVALID); + throw new StatusCodeWithException("成员【" + memberName + " - " + member.getJobRole().name() + "】的数据集 " + member.getDataSetId() + " 不存在,可能已删除或移除了授权。", StatusCode.PARAMETER_VALUE_INVALID); } } @@ -295,8 +280,8 @@ private void checkBeforeStartFlow(FlowGraph graph, ProjectMySqlModel project, bo if (projectDataSet.getSourceType() != null) { continue; } else { - DataSetOutputModel dataSet = 
dataSetService.findDataSetFromLocalOrUnion(member.getMemberId(), member.getDataSetId()); - if (dataSet == null) { + DataResourceOutputModel resource = dataResourceService.findDataResourceFromLocalOrUnion(projectDataSet); + if (resource == null) { throw new StatusCodeWithException("成员【" + memberName + "】的数据集 " + member.getDataSetId() + " 不存在,可能已被删除或不可见。", StatusCode.PARAMETER_VALUE_INVALID); } } @@ -305,11 +290,8 @@ private void checkBeforeStartFlow(FlowGraph graph, ProjectMySqlModel project, bo throw new StatusCodeWithException("成员【" + memberName + "】的数据集 " + member.getDataSetId() + " 尚未授权,不可使用。", StatusCode.PARAMETER_VALUE_INVALID); } } - - } - } } @@ -334,15 +316,35 @@ public synchronized void resumeJob(ResumeJobApi.Input input) throws StatusCodeWi throw new StatusCodeWithException("当前状态不允许进行继续任务操作!", StatusCode.ILLEGAL_REQUEST); } + // 如果是深度学习的任务,把之前任务的配置改为继续,就可以实现续跑。 + if (project.getProjectType() == ProjectType.DeepLearning) { + List tasks = taskService.listByJobId(job.getJobId(), job.getMyRole()); + tasks + .stream() + .filter(x -> x.getTaskType() == ComponentType.PaddleClassify || x.getTaskType() == ComponentType.PaddleDetection) + .filter(x -> x.getStatus() != TaskStatus.success) + .forEach(x -> { + JSONObject taskConfig = JSONObject.parseObject(x.getTaskConf()); + taskConfig.getJSONObject("env").put("resume", true); + x.setTaskConf(taskConfig.toJSONString()); + x.setMessage("resume task(" + DateUtil.getCurrentDate() + ")"); + x.setStatus(TaskStatus.wait_run); + taskRepository.save(x); + }); + + } + + jobs.forEach(y -> jobService.updateJob(y, (x) -> { x.setUpdatedBy(input); x.setStatus(JobStatus.wait_run); + x.setMessage("resume job(" + DateUtil.getCurrentDate() + ")"); return x; }) ); - - flowActionQueueService.notifyFlow(input, input.getJobId(), FlowActionType.run_job); + projectFlowService.updateFlowStatus(job.getFlowId(), ProjectFlowStatus.wait_run); + flowActionQueueService.runJob(input, input.getJobId(), project.getProjectType()); 
gatewayService.syncToOtherJobMembers(job.getJobId(), input, ResumeJobApi.class); @@ -383,7 +385,7 @@ public synchronized void stopFlowJob(StopJobApi.Input input) throws StatusCodeWi projectFlowService.updateFlowStatus(job.getFlowId(), ProjectFlowStatus.stop_on_running); - flowActionQueueService.notifyFlow(input, input.getJobId(), FlowActionType.stop_job); + flowActionQueueService.stopJob(input, input.getJobId(), project.getProjectType()); gatewayService.syncToOtherJobMembers(job.getJobId(), input, StopJobApi.class); @@ -410,7 +412,7 @@ private JobMySqlModel createJob(ProjectFlowMySqlModel flow, String jobId, JobMem } - private List createJobTasks(FlowGraph graph, boolean useCache, String endNodeId, + private List createJobTasks(ProjectMySqlModel project, FlowGraph graph, boolean useCache, String endNodeId, FederatedLearningType federatedLearningType) throws StatusCodeWithException { List startNodes = graph.getStartNodes(); @@ -490,7 +492,7 @@ private List createJobTasks(FlowGraph graph, boolean useCache, S tasks.addAll(subTasks); } } else { - TaskMySqlModel task = component.buildTask(graph, tasks, kernelJob, node); + TaskMySqlModel task = component.buildTask(project, graph, tasks, kernelJob, node); if (task != null) { tasks.add(task); } @@ -498,17 +500,17 @@ private List createJobTasks(FlowGraph graph, boolean useCache, S } catch (FlowNodeException e) { throw e; } catch (Exception e) { - throw new FlowNodeException(node, e.getMessage()); + super.log(e); + throw new FlowNodeException(node, e.getClass() + " " + e.getMessage()); } } /** - * If the first node to run is a modeling algorithm node and there is an available cache, - * you need to copy the previously failed task result to the current task. - * 1. Parameter specifies the use of caching(useCache == true) - * 2. The first task is the modeling node - * 3. last job is not empty, which indicates that this flow has been run before. - * 4. 
The modeling node has not been edited since the last job was created + * 如果第一个运行的节点是建模算法节点并且有可用的缓存,则需要将之前中断的任务结果复制到当前任务中。 + * 1. 参数指定使用缓存(useCache == true) + * 2. 第一个 task 是建模节点 + * 3. last job 不为空,表示该流程之前已经运行过。 + * 4. 自上次创建 Job 后,建模节点未编辑。 */ FlowGraphNode firstNode = noCacheNodes.get(0); if (useCache && firstNode.getComponentType().isModeling() && graph.getLastJob() != null) { @@ -537,7 +539,7 @@ private void updateDataSetUsageCountInJob(KernelJob kernelJob) throws StatusCode } for (String dataSetId : dataSetIds) { - dataSetService.usageCountInJobIncrement(dataSetId); + dataResourceService.usageCountInJobIncrement(dataSetId); } } @@ -563,17 +565,11 @@ private KernelJob createKernelJob(JobMySqlModel job, List m Project project = new Project(); project.setProjectId(job.getProjectId()); - Env env = new Env(); - env.setBackend(config.getBackend()); - env.setDbType(config.getDbType()); - env.setWorkMode(config.getWorkMode()); - env.setName(config.getEnvName()); - List dataSets = listJobDataSets(job, nodes); jobInfo.setFederatedLearningType(job.getFederatedLearningType()); jobInfo.setProject(project); - jobInfo.setMembers(memberList.stream().map(Member::new).collect(Collectors.toList())); + jobInfo.setMembers(Member.forMachineLearning(memberList)); Member arbiter = jobInfo .getMembers() @@ -593,16 +589,13 @@ private KernelJob createKernelJob(JobMySqlModel job, List m .orElse(null); if (promoter != null) { - arbiter = new Member(); - arbiter.setMemberId(promoter.getMemberId()); - arbiter.setMemberRole(JobMemberRole.arbiter); - arbiter.setMemberName(promoter.getMemberName()); + arbiter = Member.forMachineLearning(promoter.getMemberId(), JobMemberRole.arbiter); jobInfo.getMembers().add(arbiter); } } } - jobInfo.setEnv(env); + jobInfo.setEnv(Env.get()); jobInfo.setDataSets(dataSets); return jobInfo; @@ -628,20 +621,16 @@ private void addPreTasks(FlowGraphNode node, List tasks, List copyMixTaskInfoFromLastJob(JobMySqlModel oldJob, JobMySqlModel newJob, FlowGraphNode node, 
boolean copyTask) { - if (newJob == null) { return null; } - List oldTasks = taskService.findAll(oldJob.getJobId(), node.getNodeId(), oldJob.getMyRole()); - List newTasks = new ArrayList<>(); for (TaskMySqlModel oldTask : oldTasks) { - TaskMySqlModel newTask = null; int count = Integer.parseInt(oldTask.getTaskId().split("_")[oldTask.getTaskId().split("_").length - 1]); // copy task if (copyTask) { - newTask = new TaskMySqlModel(); + TaskMySqlModel newTask = new TaskMySqlModel(); BeanUtils.copyProperties(oldTask, newTask); newTask.setId(new TaskMySqlModel().getId()); newTask.setRole(newJob.getMyRole()); @@ -651,36 +640,39 @@ private List copyMixTaskInfoFromLastJob(JobMySqlModel oldJob, Jo newTask.setTaskId(node.createTaskId(newJob, count)); newTask.setParentTaskIdList(node.createParentTaskIds(newJob, count)); taskRepository.save(newTask); - - List oldResults = taskResultService.listAllResult(oldTask.getTaskId()); - // copy task_result - for (TaskResultMySqlModel oldResult : oldResults) { - - TaskResultMySqlModel newResult = new TaskResultMySqlModel(); - BeanUtils.copyProperties(oldResult, newResult); - - newResult.setId(new TaskResultMySqlModel().getId()); - newResult.setRole(newJob.getMyRole()); - newResult.setJobId(newJob.getJobId()); - newResult.setTaskId(node.createTaskId(newJob, count)); - taskResultRepository.save(newResult); - } - - DataSetMysqlModel dataSetModel = dataSetService.query(oldJob.getJobId(), node.getComponentType()); - if (dataSetModel != null) { - DataSetMysqlModel newDataSetModel = new DataSetMysqlModel(); - BeanUtils.copyProperties(dataSetModel, newDataSetModel); - newDataSetModel.setId(new DataSetMysqlModel().getId()); - newDataSetModel.setSourceJobId(newJob.getJobId()); - newDataSetModel.setSourceType(node.getComponentType()); - dataSetService.save(newDataSetModel); - } + if (newTask != null) { + newTasks.add(newTask); + } + } + + List oldResults = taskResultService.listAllResult(oldTask.getTaskId()); + // copy task_result + for 
(TaskResultMySqlModel oldResult : oldResults) { + + TaskResultMySqlModel newResult = new TaskResultMySqlModel(); + BeanUtils.copyProperties(oldResult, newResult); + + newResult.setId(new TaskResultMySqlModel().getId()); + newResult.setRole(newJob.getMyRole()); + newResult.setJobId(newJob.getJobId()); + newResult.setTaskId(node.createTaskId(newJob, count)); + taskResultRepository.save(newResult); } - if (newTask != null) { - newTasks.add(newTask); - } - } + List dataSetModels = tableDataSetService.queryAll(oldJob.getJobId(), + node.getComponentType()); + + if (CollectionUtils.isNotEmpty(dataSetModels)) { + for (TableDataSetMysqlModel dataSetModel : dataSetModels) { + TableDataSetMysqlModel newDataSetModel = new TableDataSetMysqlModel(); + BeanUtils.copyProperties(dataSetModel, newDataSetModel); + newDataSetModel.setId(new TableDataSetMysqlModel().getId()); + newDataSetModel.setDerivedFromJobId(newJob.getJobId()); + newDataSetModel.setDerivedFrom(node.getComponentType()); + tableDataSetService.save(newDataSetModel); + } + } + } return newTasks; } @@ -692,21 +684,15 @@ private List copyMixTaskInfoFromLastJob(JobMySqlModel oldJob, Jo */ private TaskMySqlModel copyNodeInfoFromLastJob(JobMySqlModel oldJob, JobMySqlModel newJob, FlowGraphNode node, boolean copyTask) { - if (newJob == null) { return null; } - TaskMySqlModel oldTask = taskService.findOne(oldJob.getJobId(), node.getNodeId(), oldJob.getMyRole()); - TaskMySqlModel newTask = null; - // copy task if (copyTask) { - newTask = new TaskMySqlModel(); BeanUtils.copyProperties(oldTask, newTask); - newTask.setId(new TaskMySqlModel().getId()); newTask.setRole(newJob.getMyRole()); newTask.setJobId(newJob.getJobId()); @@ -714,34 +700,37 @@ private TaskMySqlModel copyNodeInfoFromLastJob(JobMySqlModel oldJob, JobMySqlMod newTask.setPosition(node.getPosition()); newTask.setTaskId(node.createTaskId(newJob)); newTask.setParentTaskIdList(node.createParentTaskIds(newJob)); - taskRepository.save(newTask); - - List oldResults = 
taskResultService.listAllResult(oldTask.getTaskId()); - // copy task_result - for (TaskResultMySqlModel oldResult : oldResults) { + } - TaskResultMySqlModel newResult = new TaskResultMySqlModel(); - BeanUtils.copyProperties(oldResult, newResult); + List oldResults = taskResultService.listAllResult(oldTask.getTaskId()); + // copy task_result + for (TaskResultMySqlModel oldResult : oldResults) { - newResult.setId(new TaskResultMySqlModel().getId()); - newResult.setRole(newJob.getMyRole()); - newResult.setJobId(newJob.getJobId()); - newResult.setTaskId(node.createTaskId(newJob)); + TaskResultMySqlModel newResult = new TaskResultMySqlModel(); + BeanUtils.copyProperties(oldResult, newResult); - taskResultRepository.save(newResult); + newResult.setId(new TaskResultMySqlModel().getId()); + newResult.setRole(newJob.getMyRole()); + newResult.setJobId(newJob.getJobId()); + newResult.setTaskId(node.createTaskId(newJob)); + + taskResultRepository.save(newResult); + } + + List dataSetModels = tableDataSetService.queryAll(oldJob.getJobId(), + node.getComponentType()); + if (CollectionUtils.isNotEmpty(dataSetModels)) { + for (TableDataSetMysqlModel dataSetModel : dataSetModels) { + TableDataSetMysqlModel newDataSetModel = new TableDataSetMysqlModel(); + BeanUtils.copyProperties(dataSetModel, newDataSetModel); + newDataSetModel.setId(new TableDataSetMysqlModel().getId()); + newDataSetModel.setDerivedFromJobId(newJob.getJobId()); + newDataSetModel.setDerivedFrom(node.getComponentType()); + tableDataSetService.save(newDataSetModel); } - - DataSetMysqlModel dataSetModel = dataSetService.query(oldJob.getJobId(), node.getComponentType()); - if (dataSetModel != null) { - DataSetMysqlModel newDataSetModel = new DataSetMysqlModel(); - BeanUtils.copyProperties(dataSetModel, newDataSetModel); - newDataSetModel.setId(new DataSetMysqlModel().getId()); - newDataSetModel.setSourceJobId(newJob.getJobId()); - newDataSetModel.setSourceType(node.getComponentType()); - 
dataSetService.save(newDataSetModel); - } } + return newTask; } @@ -757,33 +746,39 @@ private List listJobMembers(String projectId, String flowId String promoterId = null; for (ProjectFlowNodeMySqlModel node : nodes) { - List dataSetItemList = null; - if (node.getComponentType().equals(ComponentType.Oot)) { - if (isOotMode) { - OotComponent.Params params = (OotComponent.Params) Components - .get(node.getComponentType()) - .deserializationParam(null, node.getParams()); - // oot model - dataSetItemList = StringUtil.isNotEmpty(params.getJobId()) ? params.getDataSetList() : dataSetItemList; - } - } else { - DataIOComponent.Params params = (DataIOComponent.Params) Components - .get(node.getComponentType()) - .deserializationParam(null, node.getParams()); - dataSetItemList = params.getDataSetList(); + List dataSetItemList = null; + + AbstractDataIOParam params = (AbstractDataIOParam) Components + .get(node.getComponentType()) + .deserializationParam(node.getParams()); + + switch (node.getComponentType()) { + case DataIO: + case ImageDataIO: + dataSetItemList = params.getDataSetList(); + break; + case Oot: + OotComponent.Params ootParams = (OotComponent.Params) params; + dataSetItemList = StringUtil.isNotEmpty(ootParams.getJobId()) + ? 
params.getDataSetList() + : dataSetItemList; + break; + default: + StatusCode.UNEXPECTED_ENUM_CASE.throwException(); } if (CollectionUtils.isEmpty(dataSetItemList)) { continue; } - for (DataIOComponent.DataSetItem item : dataSetItemList) { - boolean existMember = jobMembers.stream().anyMatch(x -> - x.getMemberId().equals(item.getMemberId()) - && x.getJobRole().equals(item.getMemberRole()) - ); + for (AbstractDataSetItem item : dataSetItemList) { + boolean memberExisted = jobMembers.stream() + .anyMatch(x -> + x.getMemberId().equals(item.getMemberId()) + && x.getJobRole().equals(item.getMemberRole()) + ); - if (existMember) { + if (memberExisted) { continue; } @@ -870,9 +865,9 @@ private List listJobDataSets(JobMySqlModel job, List member.memberRole = item.getMemberRole(); member.dataSetId = item.getDataSetId(); - DataSetOutputModel dataSetInfo = dataSetService.findDataSetFromLocalOrUnion(member.memberId, member.dataSetId); + TableDataSetOutputModel dataSetInfo = tableDataSetService.findDataSetFromLocalOrUnion(member.memberId, member.dataSetId); if (dataSetInfo != null) { - member.dataSetRows = dataSetInfo.getRowCount(); + member.dataSetRows = dataSetInfo.getTotalDataCount(); member.dataSetFeatures = dataSetInfo.getFeatureCount(); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowNodeService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowNodeService.java index 5c7d51745..689b03704 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowNodeService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowNodeService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -26,28 +26,29 @@ import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.database.repository.ProjectFlowNodeRepository; import com.welab.wefe.board.service.database.repository.ProjectFlowRepository; -import com.welab.wefe.board.service.dto.entity.data_set.DataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.TableDataSetOutputModel; import com.welab.wefe.board.service.dto.entity.job.ProjectFlowNodeOutputModel; -import com.welab.wefe.board.service.dto.kernel.JobDataSet; +import com.welab.wefe.board.service.dto.kernel.machine_learning.JobDataSet; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; -import com.welab.wefe.board.service.util.ModelMapper; +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.ProjectFlowStatus; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectFlowStatus; import org.apache.commons.collections4.CollectionUtils; import org.springframework.beans.factory.annotation.Autowired; import 
org.springframework.data.jpa.domain.Specification; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Objects; import java.util.stream.Collectors; /** @@ -63,7 +64,7 @@ public class ProjectFlowNodeService { @Autowired private ProjectFlowService projectFlowService; @Autowired - private DataSetService dataSetService; + private TableDataSetService tableDataSetService; @Autowired private ProjectFlowRepository projectFlowRepo; @@ -112,10 +113,10 @@ public List findFlowDataSetInfo(String flowId) throws StatusCodeWith for (DataIOComponent.DataSetItem item : params.getDataSetList()) { JobDataSet.Member member = new JobDataSet.Member(); member.memberRole = item.getMemberRole(); - DataSetOutputModel dataSetInfo = dataSetService.findDataSetFromLocalOrUnion(item.getMemberId(), + TableDataSetOutputModel dataSetInfo = tableDataSetService.findDataSetFromLocalOrUnion(item.getMemberId(), item.getDataSetId()); if (dataSetInfo != null) { - member.dataSetRows = dataSetInfo.getRowCount(); + member.dataSetRows = dataSetInfo.getTotalDataCount(); member.dataSetFeatures = dataSetInfo.getFeatureCount(); } dataSet.members.add(member); @@ -137,6 +138,7 @@ public List listAboutLoadDataSetNodes(String flowId) .in("componentType", Arrays.asList( ComponentType.DataIO, + ComponentType.ImageDataIO, ComponentType.HorzXGBoostValidationDataSetLoader, ComponentType.VertXGBoostValidationDataSetLoader, ComponentType.HorzLRValidationDataSetLoader, @@ -151,12 +153,14 @@ public List listAboutLoadDataSetNodes(String flowId) /** * Nodes in the update flow */ + @Transactional(rollbackFor = Exception.class) public List updateFlowNode(UpdateApi.Input input) throws StatusCodeWithException { // Update flow status projectFlowService.updateFlowStatus(input.getFlowId(), ProjectFlowStatus.editing); ProjectFlowNodeMySqlModel node = findOne(input.getFlowId(), 
input.getNodeId()); + List list = new ArrayList<>(); // If the node does not exist, it will be created automatically. @@ -175,22 +179,6 @@ public List updateFlowNode(UpdateApi.Input input) th } // If the node already exists, update it. else { - // If the parameters have not changed, jump out. - if (input.getParams().equals(node.getParams())) { - - // Repeat the update, the DataIO data has not changed, - // but the previous operation may not be completed. In order to have a better experience, - // the feature selection component list with empty parameters is also returned. - if (node.getComponentType() == ComponentType.DataIO) { - List nodes = findNodesByFlowId(node.getFlowId()); - - list = nodes.stream().filter(x -> Objects.requireNonNull(Components.get(x.getComponentType())).canSelectFeatures() && x.getParams() == null) - .map(x -> ModelMapper.map(x, ProjectFlowNodeOutputModel.class)) - .collect(Collectors.toList()); - } - return list; - } - node.setParams(input.getParams()); node.setParamsVersion(System.currentTimeMillis()); node.setUpdatedBy(input); @@ -201,7 +189,7 @@ public List updateFlowNode(UpdateApi.Input input) th if (node.getComponentType() == ComponentType.DataIO) { List nodes = findNodesByFlowId(node.getFlowId()); for (ProjectFlowNodeMySqlModel flowNode : nodes) { - if (Objects.requireNonNull(Components.get(flowNode.getComponentType())).canSelectFeatures()) { + if (Components.get(flowNode.getComponentType()).canSelectFeatures()) { flowNode.setParams(null); projectFlowNodeRepository.save(flowNode); list.add(ModelMapper.map(flowNode, ProjectFlowNodeOutputModel.class)); diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowService.java index 0c1db69e7..8be6a305e 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowService.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectFlowService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -37,15 +37,15 @@ import com.welab.wefe.board.service.dto.entity.project.ProjectFlowListOutputModel; import com.welab.wefe.board.service.dto.entity.project.ProjectFlowProgressOutputModel; import com.welab.wefe.board.service.onlinedemo.OnlineDemoBranchStrategy; -import com.welab.wefe.board.service.util.ModelMapper; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.*; +import com.welab.wefe.common.data.mysql.enums.OrderBy; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.DateUtil; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.web.CurrentAccount; -import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.*; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; @@ -105,10 +105,10 @@ public synchronized void delete(DeleteApi.Input input) throws StatusCodeWithExce ProjectMySqlModel project = projectService.findByProjectId(flow.getProjectId()); - if (!input.fromGateway() && !flow.getCreatedBy().equals(CurrentAccount.id()) && !CurrentAccount.isAdmin()) { - throw new StatusCodeWithException("只能删除自己创建的流程。", StatusCode.UNSUPPORTED_HANDLE); - } - + if 
(!input.fromGateway() && !flow.getCreatedBy().equals(CurrentAccount.id()) && !CurrentAccount.isAdmin()) { + throw new StatusCodeWithException("只能删除自己创建的流程。", StatusCode.UNSUPPORTED_HANDLE); + } + flow.setDeleted(true); flow.setUpdatedBy(input); projectFlowRepo.save(flow); @@ -146,8 +146,13 @@ public synchronized String addFlow(AddFlowApi.Input input) throws StatusCodeWith input.setFlowId(UUID.randomUUID().toString().replaceAll("-", "")); } + if (project.getProjectType() == ProjectType.DeepLearning && input.getDeepLearningJobType() == null) { + StatusCode.PARAMETER_CAN_NOT_BE_EMPTY.throwException("深度学习项目请指定任务类型"); + } + ProjectFlowMySqlModel flow = new ProjectFlowMySqlModel(); flow.setFederatedLearningType(input.getFederatedLearningType()); + flow.setDeepLearningJobType(input.getDeepLearningJobType()); flow.setCreatedBy(input); flow.setProjectId(input.getProjectId()); flow.setFlowId(input.getFlowId()); @@ -192,10 +197,10 @@ public synchronized void updateFlowBaseInfo(UpdateFlowBaseInfoApi.Input input) t if (flow == null) { throw new StatusCodeWithException("未找到该流程", StatusCode.ILLEGAL_REQUEST); } - if (input.getFederatedLearningType() != null - && flow.getFederatedLearningType() != input.getFederatedLearningType()) { - throw new StatusCodeWithException("训练类型不允许更改", StatusCode.ILLEGAL_REQUEST); - } + if (input.getFederatedLearningType() != null + && flow.getFederatedLearningType() != input.getFederatedLearningType()) { + throw new StatusCodeWithException("训练类型不允许更改", StatusCode.ILLEGAL_REQUEST); + } // List nodes = projectFlowNodeService.findNodesByFlowId(flow.getFlowId()); // if (nodes != null && !nodes.isEmpty()) { // for (ProjectFlowNodeMySqlModel node : nodes) { @@ -352,12 +357,13 @@ public ProjectFlowMySqlModel findOne(String flowId) { return projectFlowRepo.findOne("flowId", flowId, ProjectFlowMySqlModel.class); } - public PagingOutput query(QueryFlowListApi.Input input) { + public PagingOutput query(FlowQueryApi.Input input) { Specification where = Where 
.create() .equal("projectId", input.getProjectId()) .equal("deleted", input.isDeleted()) + .in("flowId", input.getFlowIdList()) .build(ProjectFlowMySqlModel.class); PagingOutput page = projectFlowRepo.paging(where, input, ProjectFlowListOutputModel.class); @@ -369,7 +375,7 @@ public PagingOutput query(QueryFlowListApi.Input inp if (lastJob != null) { x.setJobProgress(lastJob.getProgress()); } - x.setIsCreator(CacheObjects.isCurrentMember(x.getCreatedBy())); + x.setIsCreator(CacheObjects.isCurrentMemberAccount(x.getCreatedBy())); }); return page; } @@ -409,17 +415,27 @@ public synchronized void copy(CopyFlowApi.Input input) throws StatusCodeWithExce ProjectFlowMySqlModel sourceProjectFlow = findOne(input.getSourceFlowId()); if (sourceProjectFlow == null) { // If the source replication flow cannot be found locally, obtain the source flow from the initiator - ApiResult flowDetail = gatewayService.sendToBoardRedirectApi(targetPromoterProjectMember.getMemberId(), JobMemberRole.provider, new DetailFlowApi.Input(input.getSourceFlowId()), DetailFlowApi.class); - sourceProjectFlow = JSONObject.toJavaObject(JObject.create(flowDetail.data), ProjectFlowMySqlModel.class); + + sourceProjectFlow = gatewayService.callOtherMemberBoard( + targetPromoterProjectMember.getMemberId(), + JobMemberRole.provider, + DetailFlowApi.class, + new DetailFlowApi.Input(input.getSourceFlowId()), + ProjectFlowMySqlModel.class + ); } if (sourceProjectFlow == null) { throw new StatusCodeWithException(StatusCode.DATA_NOT_FOUND, "找不到原流程信息:" + input.getSourceFlowId()); } // Get the node information of the original process - ApiResult sourceProjectFlowNodeListApiResult = gatewayService.sendToBoardRedirectApi(targetPromoterProjectMember.getMemberId(), JobMemberRole.provider, new QueryFlowNodeListApi.Input(input.getSourceFlowId()), QueryFlowNodeListApi.class); - JObject sourceProjectFlowNodeDataObj = JObject.create(sourceProjectFlowNodeListApiResult.data); - List sourceProjectFlowNodeList = 
JObject.parseArray(sourceProjectFlowNodeDataObj.getStringByPath("list")).toJavaList(ProjectFlowNodeOutputModel.class); + ListFlowNodeApi.Output output = gatewayService.callOtherMemberBoard( + targetPromoterProjectMember.getMemberId(), + ListFlowNodeApi.class, + new ListFlowNodeApi.Input(input.getSourceFlowId()), + ListFlowNodeApi.Output.class + ); + List sourceProjectFlowNodeList = output.getList(); if (CollectionUtils.isEmpty(sourceProjectFlowNodeList)) { throw new StatusCodeWithException(StatusCode.DATA_NOT_FOUND, "找不到原流程节点信息:" + input.getSourceFlowId()); } @@ -478,7 +494,7 @@ public void updateFlowStatus(String flowId, ProjectFlowStatus projectFlowStatus) ProjectFlowMySqlModel flow = findOne(flowId); if (flow == null) { - throw new StatusCodeWithException(StatusCode.DATA_NOT_FOUND, "找不到需要更新的流程!"); + throw new StatusCodeWithException("找不到需要更新的流程!", StatusCode.DATA_NOT_FOUND); } flow.setFlowStatus(projectFlowStatus); @@ -520,7 +536,7 @@ public PagingOutput queryModelingInfo(QueryApi.Input in /** * Query model details: including model evaluation results. */ - public TaskResultOutputModel findModelingResult(DetailApi.Input input) { + public TaskResultOutputModel findModelingResult(DetailApi.Input input) throws StatusCodeWithException { TaskResultOutputModel result = null; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectMemberAuditService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectMemberAuditService.java index 6229fab30..bce57574a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectMemberAuditService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectMemberAuditService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,11 +20,14 @@ import com.welab.wefe.board.service.database.entity.job.ProjectMemberAuditMySqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectMemberMySqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectMySqlModel; -import com.welab.wefe.board.service.database.repository.*; +import com.welab.wefe.board.service.database.repository.ProjectDataSetRepository; +import com.welab.wefe.board.service.database.repository.ProjectMemberAuditRepository; +import com.welab.wefe.board.service.database.repository.ProjectMemberRepository; +import com.welab.wefe.board.service.database.repository.ProjectRepository; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.AuditStatus; import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.wefe.enums.AuditStatus; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.jpa.domain.Specification; import org.springframework.stereotype.Service; @@ -55,9 +58,6 @@ public class ProjectMemberAuditService { @Autowired ProjectDataSetService projectDataSetService; - @Autowired - DataSetRepository dataSetRepository; - @Autowired GatewayService gatewayService; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectMemberService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectMemberService.java index 32c1dbec4..8d3a06a61 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectMemberService.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectMemberService.java @@ -1,4 +1,4 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,20 +17,20 @@ package com.welab.wefe.board.service.service; import com.welab.wefe.board.service.api.project.member.AddApi; -import com.welab.wefe.board.service.api.project.member.ListApi; +import com.welab.wefe.board.service.api.project.member.ListInProjectApi; import com.welab.wefe.board.service.database.entity.job.*; import com.welab.wefe.board.service.database.repository.ProjectMemberAuditRepository; import com.welab.wefe.board.service.database.repository.ProjectMemberRepository; import com.welab.wefe.board.service.dto.entity.ProjectMemberInput; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.OrderBy; +import com.welab.wefe.common.data.mysql.enums.OrderBy; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.StringUtil; import com.welab.wefe.common.web.CurrentAccount; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.apache.commons.collections4.CollectionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -346,7 +346,7 @@ public ProjectMemberMySqlModel update(ProjectMemberMySqlModel projectMemberMySql } - public List findList(ListApi.Input input) throws StatusCodeWithException { + public List findList(ListInProjectApi.Input input) throws StatusCodeWithException { List projectMemberMySqlModelList = findListByProjectId(input.getProjectId()); if (StringUtil.isEmpty(input.getOotJobId())) { return 
projectMemberMySqlModelList; @@ -386,5 +386,19 @@ public List findList(ListApi.Input input) throws Status return resultList; } + /** + * Get the list of official providers in the project + */ + public List listFormalProjectProviders(String projectId) { + Specification where = Where + .create() + .equal("projectId", projectId) + .equal("auditStatus", AuditStatus.agree) + .equal("exited", false) + .equal("memberRole", JobMemberRole.provider) + .build(ProjectMemberMySqlModel.class); + + return projectMemberRepo.findAll(where); + } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectService.java index 39c70c8d5..7e836be2a 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ProjectService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,11 +16,10 @@ package com.welab.wefe.board.service.service; -import com.alibaba.fastjson.JSONObject; import com.welab.wefe.board.service.api.project.dataset.AddDataSetApi; import com.welab.wefe.board.service.api.project.dataset.RemoveDataSetApi; import com.welab.wefe.board.service.api.project.member.ExitProjectApi; -import com.welab.wefe.board.service.api.project.member.ListApi; +import com.welab.wefe.board.service.api.project.member.ListInProjectApi; import com.welab.wefe.board.service.api.project.member.RemoveApi; import com.welab.wefe.board.service.api.project.project.*; import com.welab.wefe.board.service.database.entity.job.*; @@ -28,25 +27,29 @@ import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.dto.entity.ProjectDataSetInput; import com.welab.wefe.board.service.dto.entity.ProjectMemberInput; -import com.welab.wefe.board.service.dto.entity.project.*; +import com.welab.wefe.board.service.dto.entity.project.ProjectDetailMemberOutputModel; +import com.welab.wefe.board.service.dto.entity.project.ProjectMemberOutputModel; +import com.welab.wefe.board.service.dto.entity.project.ProjectOutputModel; +import com.welab.wefe.board.service.dto.entity.project.ProjectQueryOutputModel; +import com.welab.wefe.board.service.dto.entity.project.data_set.ProjectDataResourceOutputModel; import com.welab.wefe.board.service.dto.vo.AuditStatusCounts; import com.welab.wefe.board.service.dto.vo.RoleCounts; import com.welab.wefe.board.service.onlinedemo.OnlineDemoBranchStrategy; -import com.welab.wefe.board.service.util.ModelMapper; +import com.welab.wefe.board.service.service.data_resource.DataResourceService; import com.welab.wefe.common.Convert; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.AuditStatus; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.JobMemberRole; -import 
com.welab.wefe.common.enums.ProjectFlowStatus; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; import com.welab.wefe.common.util.ThreadUtil; import com.welab.wefe.common.web.CurrentAccount; import com.welab.wefe.common.web.dto.AbstractApiInput; -import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.ProjectFlowStatus; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -89,10 +92,6 @@ public class ProjectService extends AbstractService { @Autowired private ProjectDataSetRepository projectDataSetRepo; - - @Autowired - private DataSetService dataSetService; - @Autowired private ProjectMemberAuditRepository projectMemberAuditRepository; @Autowired @@ -103,6 +102,8 @@ public class ProjectService extends AbstractService { @Autowired private ProjectFlowNodeRepository projectFlowNodeRepository; + @Autowired + private DataResourceService dataResourceService; /** * New Project @@ -157,6 +158,7 @@ public synchronized String addProject(AddApi.Input input) throws StatusCodeWithE .append(ProjectFlowStatus.editing.name(), 0) .append(ProjectFlowStatus.running.name(), 0) .append(ProjectFlowStatus.finished.name(), 0).toJSONString()); + project.setProjectType(input.getProjectType()); projectRepo.save(project); // create and save ProjectMember to database @@ -193,12 +195,13 @@ public synchronized String addProject(AddApi.Input input) throws StatusCodeWithE dataSet.setStatusUpdatedTime(new Date()); dataSet.setAuditStatus(auditStatus); dataSet.setSourceType(null); + 
dataSet.setDataResourceType(dataSetInput.getDataResourceType()); projectDataSetRepo.save(dataSet); // Update the usage count of the dataset in the project - if (auditStatus == AuditStatus.agree) { - dataSetService.updateUsageCountInProject(dataSet.getDataSetId()); + if (auditStatus == AuditStatus.agree && CacheObjects.isCurrentMember(dataSetInput.getMemberId())) { + dataResourceService.updateUsageCountInProject(dataSet.getDataSetId()); } } @@ -272,12 +275,12 @@ public ProjectOutputModel detail(String projectId) throws StatusCodeWithExceptio .map(x -> ModelMapper.map(x, ProjectDetailMemberOutputModel.class)) .collect(Collectors.toList()); - List allDataSetList = projectDataSetService.listRawDataSet(projectId, null, null, null); + List allDataSetList = projectDataSetService.listRawDataSet(projectId, null, null, null, null); // Populate the member's data set list allMemberList.forEach(member -> - member.setDataSetList( + member.setDataResourceList( allDataSetList .stream() .filter(dataSet -> @@ -307,9 +310,9 @@ public ProjectOutputModel detail(String projectId) throws StatusCodeWithExceptio .filter(x -> x.getMemberRole() == JobMemberRole.provider).collect(Collectors.toList()); ProjectOutputModel output = ModelMapper.map(project, ProjectOutputModel.class); - ProjectDetailMemberOutputModel newPromoter = JSONObject.parseObject(JSONObject.toJSONString(promoter), - ProjectDetailMemberOutputModel.class); - output.setPromoter(newPromoter); +// ProjectDetailMemberOutputModel newPromoter = JSONObject.parseObject(JSONObject.toJSONString(promoter), +// ProjectDetailMemberOutputModel.class); + output.setPromoter(promoter); output.setProviderList(providers); output.setPromoterList(promoters); output.setIsCreator( @@ -346,7 +349,7 @@ public synchronized void removeMember(RemoveApi.Input input) throws StatusCodeWi projectDataSetService .listAllRawDataSet(project.getProjectId(), member.getMemberId()) .stream() - .forEach(x -> 
dataSetService.updateUsageCountInProject(x.getDataSetId())); + .forEach(x -> dataResourceService.updateUsageCountInProject(x.getDataSetId())); checkAuditingRecord(input.getProjectId(), input.getMemberId()); @@ -453,11 +456,11 @@ public synchronized ProjectMySqlModel addProjectDataSet(AddDataSetApi.Input inpu } } - if (CollectionUtils.isEmpty(input.getDataSetList())) { + if (CollectionUtils.isEmpty(input.getDataResourceList())) { throw new StatusCodeWithException("数据集不能为空", StatusCode.ILLEGAL_REQUEST); } - for (ProjectDataSetInput item : input.getDataSetList()) { + for (ProjectDataSetInput item : input.getDataResourceList()) { // Determine whether the member exists ProjectMemberMySqlModel member = projectMemberService.findOneByMemberId(input.getProjectId(), item.getMemberId(), item.getMemberRole()); if (member == null) { @@ -490,13 +493,13 @@ public synchronized ProjectMySqlModel addProjectDataSet(AddDataSetApi.Input inpu projectDataSet.setMemberRole(item.getMemberRole()); projectDataSet.setStatusUpdatedTime(new Date()); projectDataSet.setSourceType(null); - + projectDataSet.setDataResourceType(item.getDataResourceType()); } projectDataSetRepo.save(projectDataSet); // Update the usage count of the dataset in the project if (projectDataSet.getAuditStatus() == AuditStatus.agree) { - dataSetService.updateUsageCountInProject(projectDataSet.getDataSetId()); + dataResourceService.updateUsageCountInProject(projectDataSet.getDataSetId()); } } @@ -546,8 +549,7 @@ public synchronized void removeDataSet(RemoveDataSetApi.Input input) throws Stat if (project.getMyRole() != JobMemberRole.promoter) { throw new StatusCodeWithException("只有 promoter 才能删除衍生数据集", StatusCode.ILLEGAL_REQUEST); } - - dataSetService.delete(projectDataSet.getDataSetId()); + dataResourceService.delete(projectDataSet.getDataSetId(), projectDataSet.getDataResourceType()); } } @@ -556,7 +558,7 @@ public synchronized void removeDataSet(RemoveDataSetApi.Input input) throws Stat 
projectDataSetRepo.deleteById(projectDataSet.getId()); // Update the usage count of the dataset in the project - dataSetService.updateUsageCountInProject(projectDataSet.getDataSetId()); + dataResourceService.updateUsageCountInProject(projectDataSet.getDataSetId()); gatewayService.syncToNotExistedMembers(input.getProjectId(), input, RemoveDataSetApi.class); @@ -623,7 +625,7 @@ public CountStatisticsApi.Output statistics(QueryApi.Input input) { public PagingOutput query(QueryApi.Input input) { StringBuffer sql = new StringBuffer( - "select distinct(p.id),p.flow_status_statistics,p.deleted,p.name,p.project_desc,p.audit_status,p.status_updated_time" + "select distinct(p.id),p.project_type,p.flow_status_statistics,p.deleted,p.name,p.project_desc,p.audit_status,p.status_updated_time" + ",p.audit_status_from_myself,p.audit_status_from_others,p.audit_comment,p.exited,p.closed" + ",p.closed_by,p.closed_time,p.exited_by,p.exited_time" + ",p.project_id,p.member_id,p.my_role" @@ -658,6 +660,11 @@ private String buildQueryWhere(QueryApi.Input input) { where.append(" and p.deleted != true "); + + if (input.getProjectType() != null) { + where.append(" and p.project_type = '" + input.getProjectType() + "'"); + } + if (StringUtil.isNotBlank(input.getName())) { where.append(" and p.name like '%" + input.getName() + "%'"); } @@ -918,9 +925,13 @@ public void syncAuditProjectInfo(String projectId, AuditApi.Input input) throws throw new StatusCodeWithException(StatusCode.DATA_NOT_FOUND, "找不到promoter成员信息"); } - ApiResult detailResult = gatewayService.sendToBoardRedirectApi(promoterProjectMember.getMemberId(), JobMemberRole.provider, new DataInfoApi.Input(projectId), DataInfoApi.class); - - DataInfoApi.Output dataInfoOutput = JSONObject.toJavaObject(JObject.create(detailResult.data), DataInfoApi.Output.class); + DataInfoApi.Output dataInfoOutput = gatewayService.callOtherMemberBoard( + promoterProjectMember.getMemberId(), + JobMemberRole.provider, + DataInfoApi.class, + new 
DataInfoApi.Input(projectId), + DataInfoApi.Output.class + ); for (ProjectMemberMySqlModel projectMemberMySqlModel : dataInfoOutput.getProjectMembers()) { @@ -940,12 +951,12 @@ public void syncAuditProjectInfo(String projectId, AuditApi.Input input) throws } - for (ProjectDataSetMySqlModel dataSetMySqlModel : dataInfoOutput.getProjectDataSets()) { + for (ProjectDataSetMySqlModel item : dataInfoOutput.getProjectDataSets()) { // Filter derived data sets - if (dataSetMySqlModel.getSourceType() != null) { + if (item.getSourceType() != null) { continue; } - ProjectDataSetMySqlModel projectDataSet = projectDataSetService.findOne(dataSetMySqlModel.getProjectId(), dataSetMySqlModel.getDataSetId(), dataSetMySqlModel.getMemberRole()); + ProjectDataSetMySqlModel projectDataSet = projectDataSetService.findOne(item.getProjectId(), item.getDataSetId(), item.getMemberRole()); if (projectDataSet != null) { projectDataSetService.update(projectDataSet, dataSet -> { dataSet.setAuditStatus(projectDataSet.getAuditStatus()); @@ -953,8 +964,7 @@ public void syncAuditProjectInfo(String projectId, AuditApi.Input input) throws }); } else { - - projectDataSetRepo.save(dataSetMySqlModel); + projectDataSetRepo.save(item); } } List excludeFlowIds = new ArrayList<>(); @@ -970,7 +980,6 @@ public void syncAuditProjectInfo(String projectId, AuditApi.Input input) throws if (projectFlowService.findOne(projectFlowMySqlModel.getFlowId()) == null) { projectFlowMySqlModel.setMyRole(project.getMyRole()); projectFlowMySqlModel.setFlowStatus(ProjectFlowStatus.editing); - // todo: put creator_member_id on next version projectFlowMySqlModel.setCreatedBy(""); projectFlowRepository.save(projectFlowMySqlModel); } @@ -1021,6 +1030,7 @@ public void pullNewestProjectInfo(AbstractApiInput input, String projectId, Stri .append(ProjectFlowStatus.editing.name(), 0) .append(ProjectFlowStatus.running.name(), 0) .append(ProjectFlowStatus.finished.name(), 0).toJSONString()); + 
project.setProjectType(projectMySqlModel.getProjectType()); projectRepo.save(project); // save ProjectMember to database @@ -1056,6 +1066,7 @@ public void pullNewestProjectInfo(AbstractApiInput input, String projectId, Stri dataSet.setStatusUpdatedTime(x.getStatusUpdatedTime()); dataSet.setAuditStatus(x.getMemberId().equals(CacheObjects.getMemberId()) ? AuditStatus.auditing : x.getAuditStatus()); dataSet.setAuditComment(x.getMemberId().equals(CacheObjects.getMemberId()) ? "" : x.getAuditComment()); + dataSet.setDataResourceType(x.getDataResourceType()); projectDataSetRepo.save(dataSet); }); @@ -1121,7 +1132,7 @@ public void pullNewestProjectInfo(AbstractApiInput input, String projectId, Stri params.put("auditStatus", model.getMemberId().equals(CacheObjects.getMemberId()) && model.getMemberRole() == myRole ? AuditStatus.auditing - : model.getAuditStatus()); + : model.getAuditStatus()); params.put("auditComment", model.getMemberId().equals(CacheObjects.getMemberId()) ? "" : model.getAuditComment()); projectDataSetRepo.updateById(model.getId(), params, ProjectDataSetMySqlModel.class); @@ -1137,13 +1148,17 @@ public void pullNewestProjectInfo(AbstractApiInput input, String projectId, Stri public DataInfoApi.Output getPromoterDataInfo(String projectId, String callerMemberId) throws StatusCodeWithException { // Get all project members from the sender - ApiResult membersResult = gatewayService.sendToBoardRedirectApi(callerMemberId, JobMemberRole.provider, new ListApi.Input(projectId), ListApi.class); + ListInProjectApi.Output output = gatewayService.callOtherMemberBoard( + callerMemberId, + JobMemberRole.provider, + ListInProjectApi.class, + new ListInProjectApi.Input(projectId), + ListInProjectApi.Output.class + ); // Find the promoter in the current project from all members of the sender - ProjectMemberOutputModel promoterMember = JObject.create(membersResult.data) - .getJSONList("list") + ProjectMemberOutputModel promoterMember = output.getList() .stream() - .map(x -> 
JSONObject.toJavaObject(x, ProjectMemberOutputModel.class)) .filter(x -> x.getMemberRole() == JobMemberRole.promoter) .findFirst() .orElse(null); @@ -1155,11 +1170,13 @@ public DataInfoApi.Output getPromoterDataInfo(String projectId, String callerMem String promoterMemberId = promoterMember.getMemberId(); // Get project details from the promoter - ApiResult detailResult = gatewayService.sendToBoardRedirectApi(promoterMemberId, JobMemberRole.provider, new DataInfoApi.Input(projectId), DataInfoApi.class); - - DataInfoApi.Output projectOutputModel = JSONObject.toJavaObject(JObject.create(detailResult.data), DataInfoApi.Output.class); - - return projectOutputModel; + return gatewayService.callOtherMemberBoard( + promoterMemberId, + JobMemberRole.provider, + DataInfoApi.class, + new DataInfoApi.Input(projectId), + DataInfoApi.Output.class + ); } @@ -1206,7 +1223,7 @@ public void exitProject(ExitProjectApi.Input input) throws StatusCodeWithExcepti projectDataSetService .listAllRawDataSet(project.getProjectId(), memberId) .stream() - .forEach(x -> dataSetService.updateUsageCountInProject(x.getDataSetId())); + .forEach(x -> dataResourceService.updateUsageCountInProject(x.getDataSetId())); } @@ -1237,7 +1254,7 @@ public void closeProject(CloseProjectApi.Input input) throws StatusCodeWithExcep projectDataSetService .listAllRawDataSet(project.getProjectId(), null) .stream() - .forEach(x -> dataSetService.updateUsageCountInProject(x.getDataSetId())); + .forEach(x -> dataResourceService.updateUsageCountInProject(x.getDataSetId())); // Notify other members that the project is closed try { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ServiceCheckService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ServiceCheckService.java index d8a3cab05..936c0e61d 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ServiceCheckService.java +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ServiceCheckService.java @@ -1,12 +1,12 @@ /* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,152 +16,88 @@ package com.welab.wefe.board.service.service; -import com.welab.wefe.board.service.api.member.ServiceStatusCheckApi; -import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.board.service.api.member.MemberAvailableCheckApi; import com.welab.wefe.board.service.database.entity.job.ProjectMemberMySqlModel; -import com.welab.wefe.board.service.dto.vo.MemberServiceStatusOutput; -import com.welab.wefe.board.service.exception.MemberGatewayException; import com.welab.wefe.board.service.sdk.FlowService; -import com.welab.wefe.board.service.sdk.UnionService; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; -import com.welab.wefe.common.data.storage.config.JdbcParamConfig; -import com.welab.wefe.common.data.storage.model.DataItemModel; -import com.welab.wefe.common.data.storage.repo.Storage; -import com.welab.wefe.common.data.storage.service.StorageService; -import com.welab.wefe.common.enums.MemberService; +import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; -import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.checkpoint.CheckpointManager; +import 
com.welab.wefe.common.wefe.checkpoint.dto.MemberAvailableCheckOutput; +import com.welab.wefe.common.wefe.checkpoint.dto.ServiceAvailableCheckOutput; +import com.welab.wefe.common.wefe.enums.ServiceType; import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.lang3.RandomStringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.ArrayList; -import java.util.LinkedHashMap; +import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; -import static com.welab.wefe.board.service.service.DataSetStorageService.DATABASE_NAME; - /** * @author lonnie */ @Service -public class ServiceCheckService { - - protected final Logger LOG = LoggerFactory.getLogger(this.getClass()); - - @Autowired - private FlowService flowService; - @Autowired - private UnionService unionService; - - @Autowired - private GatewayService gatewayService; - +public class ServiceCheckService extends AbstractService { @Autowired private ProjectMemberService projectMemberService; - @Autowired - private StorageService storageService; - - @Autowired - private JdbcParamConfig jdbcParamConfig; - - @Autowired - private Config config; - + private CheckpointManager checkpointManager; @Autowired private GlobalConfigService globalConfigService; + @Autowired + private FlowService flowService; /** - * check if each service is available + * 检查指定成员的服务是否可用 */ - public ServiceStatusCheckApi.Output checkMemberServiceStatus(ServiceStatusCheckApi.Input input) throws MemberGatewayException { + public MemberAvailableCheckOutput getMemberAvailableInfo(String memberId) throws StatusCodeWithException { - String memberId = input.getMemberId(); // If you are not checking your own status, go to the gateway. 
if (!CacheObjects.getMemberId().equals(memberId)) { - ApiResult result = gatewayService.callOtherMemberBoard(memberId, ServiceStatusCheckApi.class, JObject.create(input)); - return JObject.create(result.data).toJavaObject(ServiceStatusCheckApi.Output.class); - - } - - LinkedHashMap result = new LinkedHashMap<>(); - - if (input.getService() == null) { - result.put(MemberService.gateway, checkLocalGatewayStatus()); - result.put(MemberService.flow, checkFlowServiceStatus(!input.fromGateway())); - result.put(MemberService.union, checkUnionServiceStatus()); - result.put(MemberService.storage, checkStorageServiceStatus(!input.fromGateway())); - return new ServiceStatusCheckApi.Output(result); + return gatewayService.callOtherMemberBoard( + memberId, + MemberAvailableCheckApi.class, + new MemberAvailableCheckApi.Input(memberId), + MemberAvailableCheckOutput.class + ); } - switch (input.getService()) { - case union: - result.put(MemberService.union, checkUnionServiceStatus()); - break; - - case gateway: - result.put(MemberService.gateway, checkLocalGatewayStatus()); - break; - - case flow: - result.put(MemberService.flow, checkFlowServiceStatus(!input.fromGateway())); - break; - - case storage: - result.put(MemberService.storage, checkStorageServiceStatus(!input.fromGateway())); - break; + MemberAvailableCheckOutput result = new MemberAvailableCheckOutput(); + List serviceTypes = Arrays.asList( + ServiceType.BoardService, + ServiceType.UnionService, + ServiceType.GatewayService, + ServiceType.FlowService + ); - default: - break; + for (ServiceType type : serviceTypes) { + result.put(type, getServiceAvailableInfo(type)); } - return new ServiceStatusCheckApi.Output(result); - } - - /** - * check if the union service is available - */ - public MemberServiceStatusOutput checkUnionServiceStatus() { - MemberServiceStatusOutput output = new MemberServiceStatusOutput(MemberService.union); - output.setValue(config.getUNION_BASE_URL()); - try { - unionService.queryMember(0, 1); - 
output.setSuccess(true); - } catch (StatusCodeWithException e) { - output.setSuccess(false); - output.setMessage(e.getMessage()); - } - - return output; + return result; } - /** - * check if the gateway service is available - */ - public MemberServiceStatusOutput checkLocalGatewayStatus() { - MemberServiceStatusOutput output = new MemberServiceStatusOutput(MemberService.gateway); - output.setValue(globalConfigService.getGatewayConfig().intranetBaseUri); - + public ServiceAvailableCheckOutput getServiceAvailableInfo(ServiceType serviceType) { try { - GatewayOnlineCheckResult result = checkGatewayConnect(globalConfigService.getGatewayConfig().intranetBaseUri); - - output.setSuccess(result.online); - output.setMessage(result.error); - + switch (serviceType) { + case BoardService: + return checkpointManager.checkAll(); + case GatewayService: + return gatewayService.getLocalGatewayAvailable(); + case UnionService: + return unionService.getAvailable(); + case FlowService: + return flowService.getAvailable(); + default: + StatusCode.UNEXPECTED_ENUM_CASE.throwException(); + } } catch (Exception e) { - output.setSuccess(false); - output.setMessage(e.getMessage()); + return new ServiceAvailableCheckOutput("获取 " + serviceType + " 服务可用性状态失败:" + e.getMessage()); } - - - return output; + return null; } /** @@ -195,85 +131,6 @@ public List gatewayOnlineCheck(boolean local, String p return checkResultList; } - /** - * check if the storage service is available - */ - public MemberServiceStatusOutput checkStorageServiceStatus(boolean checkerIsMyself) { - MemberServiceStatusOutput output = new MemberServiceStatusOutput(MemberService.storage); - - // The connection string is non-exposed information and cannot be output to other members. 
- if (checkerIsMyself) { - output.setValue(jdbcParamConfig.getUrl()); - } - - - Storage storage = storageService.getStorage(); - String name = RandomStringUtils.randomAlphabetic(6); - try { - storage.put(DATABASE_NAME, name, new DataItemModel<>(name, "test")); - output.setSuccess(true); - } catch (Exception e) { - LOG.error(e.getMessage()); - e.printStackTrace(); - output.setSuccess(false); - output.setMessage(config.getDbType().name() + " put异常,请检查相关配置是否正确。"); - return output; - } - - try { - storage.dropTB(DATABASE_NAME, name); - output.setSuccess(true); - } catch (Exception e) { - output.setSuccess(false); - output.setMessage(config.getDbType().name() + " drop异常,请检查相关配置是否正确。"); - } - - return output; - } - - /** - * check if the flow service is available - */ - public MemberServiceStatusOutput checkFlowServiceStatus(boolean checkerIsMyself) { - MemberServiceStatusOutput output = new MemberServiceStatusOutput(MemberService.flow); - - if (checkerIsMyself) { - output.setValue(globalConfigService.getFlowConfig().intranetBaseUri); - } - - try { - JObject result = flowService.dashboard(); - - JObject board = result.getJObject(MemberService.board.name()); - JObject gateway = result.getJObject(MemberService.gateway.name()); - - output = buildResult(MemberService.board.name(), board, output); - output = buildResult(MemberService.gateway.name(), gateway, output); - - } catch (Exception e) { - output.setSuccess(false); - output.setMessage(e.getMessage()); - } - - return output; - } - - private MemberServiceStatusOutput buildResult(String checkpoint, JObject obj, MemberServiceStatusOutput output) { - - if (obj == null) { - output.setSuccess(false); - output.setMessage("flow 服务不可用,检查点 " + checkpoint + " 检查失败。"); - } else if (obj.getInteger("code") != null && obj.getInteger("code") == 0) { - output.setSuccess(true); - output.setMessage(obj.getString("message")); - } else { - output.setSuccess(false); - output.setMessage(obj.getString("message")); - } - - return output; - } - 
/** * check if the flow gateway is available diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ServingService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ServingService.java index 9f459bcb8..116dbf8ee 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/ServingService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/ServingService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -26,19 +26,19 @@ import com.welab.wefe.board.service.dto.entity.job.JobMemberOutputModel; import com.welab.wefe.board.service.dto.globalconfig.MemberInfoModel; import com.welab.wefe.board.service.dto.globalconfig.ServingConfigModel; -import com.welab.wefe.board.service.sdk.UnionService; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; import com.welab.wefe.common.CommonThreadPool; import com.welab.wefe.common.StatusCode; -import com.welab.wefe.common.enums.Algorithm; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.http.HttpRequest; import com.welab.wefe.common.http.HttpResponse; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.RSAUtil; import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.wefe.enums.Algorithm; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskResultType; import org.apache.commons.collections4.CollectionUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -57,9 +57,6 @@ public class ServingService extends AbstractService { private static final String SEPARATOR = "_"; - @Autowired - UnionService unionService; - @Autowired JobRepository jobRepository; @@ -84,7 +81,6 @@ public void asynRefreshMemberInfo(MemberInfoModel model) throws StatusCodeWithEx try { refreshMemberInfo(model); } catch (StatusCodeWithException e) { - e.printStackTrace(); LOG.error("serving 响应失败:" + e.getMessage(), StatusCode.REMOTE_SERVICE_ERROR); } }); @@ -126,7 +122,6 @@ private JSONObject request(String api, try { sign = RSAUtil.sign(data, CacheObjects.getRsaPrivateKey()); } catch (Exception e) { - e.printStackTrace(); throw new StatusCodeWithException(e.getMessage(), 
StatusCode.SYSTEM_ERROR); } @@ -170,12 +165,19 @@ private JSONObject request(String api, * Modeling synchronization to serving */ public void syncModelToServing(SyncModelToServingApi.Input input) throws StatusCodeWithException { + TreeMap jobj = setBody(input.getTaskId(), input.getRole()); + + request("model_save", jobj, true); + } + - TaskResultMySqlModel taskResult = taskResultService.findByTaskIdAndTypeAndRole(input.getTaskId(), TaskResultType.model_train.name(), input.getRole()); + public TreeMap setBody(String taskId, JobMemberRole role) throws StatusCodeWithException { + + TaskResultMySqlModel taskResult = taskResultService.findByTaskIdAndTypeAndRole(taskId, TaskResultType.model_train.name(), role); if (taskResult == null) { LOG.error("查询task任务异常"); - throw new StatusCodeWithException(StatusCode.PARAMETER_VALUE_INVALID); + throw new StatusCodeWithException("task 不存在!", StatusCode.PARAMETER_VALUE_INVALID); } @@ -183,33 +185,28 @@ public void syncModelToServing(SyncModelToServingApi.Input input) throws StatusC if (CollectionUtils.isEmpty(memberList)) { LOG.error("查询job_member异常"); - throw new StatusCodeWithException(StatusCode.PARAMETER_VALUE_INVALID); + throw new StatusCodeWithException("查询job_member异常!", StatusCode.PARAMETER_VALUE_INVALID); } - JobMySqlModel job = jobRepository.findByJobId(taskResult.getJobId(), input.getRole().name()); + JobMySqlModel job = jobRepository.findByJobId(taskResult.getJobId(), role.name()); if (job == null) { LOG.error("查询job异常"); - throw new StatusCodeWithException(StatusCode.PARAMETER_VALUE_INVALID); + throw new StatusCodeWithException("查询job异常!", StatusCode.PARAMETER_VALUE_INVALID); } // Feature engineering - List featureEngineerResults = taskResultService.findByTaskIdAndRoleNotEqualType(input.getTaskId(), TaskResultType.model_train.name(), input.getRole()); - Map featrueEngineerMap = new TreeMap<>(); + List featureEngineerResults = taskResultService.findByTaskIdAndRoleNotEqualType(taskId, TaskResultType.model_train.name(), 
role); + Map featureEngineerMap = new TreeMap<>(); for (TaskResultMySqlModel fe : featureEngineerResults) { TaskMySqlModel taskMySqlModel = taskService.findOne(fe.getTaskId()); if (taskMySqlModel == null) { LOG.error("查询task任务异常"); throw new StatusCodeWithException(StatusCode.PARAMETER_VALUE_INVALID); } - featrueEngineerMap.put(taskMySqlModel.getPosition(), getModelParam(fe.getResult())); + featureEngineerMap.put(taskMySqlModel.getPosition(), getModelParam(fe.getResult())); } - request(memberList, taskResult, featrueEngineerMap, job); - } - - private void request(List memberList, TaskResultMySqlModel taskResult, Map featrueEngineerMap, JobMySqlModel job) throws StatusCodeWithException { - List members = new ArrayList<>(); memberList.forEach(mem -> { @@ -229,7 +226,7 @@ private void request(List memberList, TaskResultMySqlModel getJSONObject(0). getString("public_key")); } catch (StatusCodeWithException e) { - e.printStackTrace(); + super.log(e); } members.add(member); @@ -244,24 +241,27 @@ private void request(List memberList, TaskResultMySqlModel params.put("flType", job.getFederatedLearningType().name()); params.put("modelParam", taskResult.getResult()); params.put("memberParams", members); - params.put("featrueEngineerMap", featrueEngineerMap); + params.put("featureEngineerMap", featureEngineerMap); - request("model_save", params, true); + return params; } private Algorithm getAlgorithm(ComponentType componentType) { switch (componentType) { case HorzLR: case VertLR: + case MixLR: return Algorithm.LogisticRegression; case HorzSecureBoost: case VertSecureBoost: + case MixSecureBoost: return Algorithm.XGBoost; default: throw new RuntimeException("预算之外的组件类型"); } } + private String getModelParam(String taskResult) { return JObject.create(taskResult).getString("model_param"); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/SystemInitializeService.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/service/SystemInitializeService.java index ffcf61f03..95401a213 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/SystemInitializeService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/SystemInitializeService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,18 +19,22 @@ import com.welab.wefe.board.service.api.member.InitializeApi; import com.welab.wefe.board.service.api.member.UpdateMemberInfoApi; import com.welab.wefe.board.service.api.member.UpdateMemberLogoApi; -import com.welab.wefe.board.service.database.entity.AccountMySqlModel; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.AccountMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.BloomFilterMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.ImageDataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; import com.welab.wefe.board.service.database.repository.AccountRepository; -import com.welab.wefe.board.service.database.repository.DataSetRepository; +import com.welab.wefe.board.service.database.repository.data_resource.BloomFilterRepository; +import com.welab.wefe.board.service.database.repository.data_resource.ImageDataSetRepository; +import 
com.welab.wefe.board.service.database.repository.data_resource.TableDataSetRepository; import com.welab.wefe.board.service.dto.globalconfig.MemberInfoModel; -import com.welab.wefe.board.service.sdk.UnionService; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.board.service.util.BoardSM4Util; import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.constant.SecretKeyType; import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.util.FileUtil; import com.welab.wefe.common.util.RSAUtil; -import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.util.SignUtil; import com.welab.wefe.common.web.CurrentAccount; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -50,33 +54,36 @@ public class SystemInitializeService extends AbstractService { @Autowired private AccountRepository accountRepository; - @Autowired - private UnionService unionService; + private ServingService servingService; @Autowired - private GatewayService gatewayService; - + private TableDataSetRepository tableDataSetRepository; @Autowired - private ServingService servingService; - + private ImageDataSetRepository imageDataSetRepository; @Autowired - private DataSetRepository dataSetRepository; + private BloomFilterRepository bloomFilterRepository; /** * Synchronize member information to union for the recovery of membership after union data is lost. 
*/ public synchronized void syncMemberToUnion() throws StatusCodeWithException { - AccountMySqlModel account = accountRepository.findByPhoneNumber(CurrentAccount.phoneNumber()); + AccountMysqlModel account = accountRepository.findByPhoneNumber(BoardSM4Util.encryptPhoneNumber(CurrentAccount.phoneNumber())); if (!account.getSuperAdminRole()) { throw new StatusCodeWithException("您没有初始化系统的权限,请联系超级管理员(第一个注册的人)进行操作。", StatusCode.INVALID_USER); } unionService.initializeSystem(globalConfigService.getMemberInfo()); - for (DataSetMysqlModel model : dataSetRepository.findAll()) { - unionService.uploadDataSet(model); + for (TableDataSetMysqlModel model : tableDataSetRepository.findAll()) { + unionService.upsertDataResource(model); + } + for (ImageDataSetMysqlModel model : imageDataSetRepository.findAll()) { + unionService.upsertDataResource(model); + } + for (BloomFilterMysqlModel model : bloomFilterRepository.findAll()) { + unionService.upsertDataResource(model); } } @@ -97,7 +104,7 @@ public synchronized void initialize(InitializeApi.Input input) throws StatusCode throw new StatusCodeWithException(StatusCode.UNSUPPORTED_HANDLE, "系统已初始化,不能重复操作。"); } - AccountMySqlModel account = accountRepository.findByPhoneNumber(CurrentAccount.phoneNumber()); + AccountMysqlModel account = accountRepository.findByPhoneNumber(BoardSM4Util.encryptPhoneNumber(CurrentAccount.phoneNumber())); if (!account.getSuperAdminRole()) { throw new StatusCodeWithException("您没有初始化系统的权限,请联系超级管理员(第一个注册的人)进行操作。", StatusCode.INVALID_USER); } @@ -111,11 +118,12 @@ public synchronized void initialize(InitializeApi.Input input) throws StatusCode model.setMemberHidden(false); try { - RSAUtil.RsaKeyPair pair = RSAUtil.generateKeyPair(); - model.setRsaPrivateKey(pair.privateKey); - model.setRsaPublicKey(pair.publicKey); + input.setSecretKeyType(null == input.getSecretKeyType() ? 
SecretKeyType.rsa : input.getSecretKeyType()); + SignUtil.KeyPair keyPair = SignUtil.generateKeyPair(input.getSecretKeyType()); + model.setRsaPrivateKey(keyPair.privateKey); + model.setRsaPublicKey(keyPair.publicKey); + model.setSecretKeyType(input.getSecretKeyType()); } catch (NoSuchAlgorithmException e) { - e.printStackTrace(); throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); } @@ -130,7 +138,7 @@ public synchronized void initialize(InitializeApi.Input input) throws StatusCode @Transactional(rollbackFor = Exception.class) public void updateMemberInfo(UpdateMemberInfoApi.Input input) throws StatusCodeWithException { - AccountMySqlModel account = accountRepository.findByPhoneNumber(CurrentAccount.phoneNumber()); + AccountMysqlModel account = accountRepository.findByPhoneNumber(BoardSM4Util.encryptPhoneNumber(CurrentAccount.phoneNumber())); if (!account.getSuperAdminRole()) { throw new StatusCodeWithException("您没有编辑权限,请联系超级管理员(第一个注册的人)进行操作。", StatusCode.INVALID_USER); } @@ -141,12 +149,6 @@ public void updateMemberInfo(UpdateMemberInfoApi.Input input) throws StatusCodeW model.setMemberMobile(input.getMemberMobile()); model.setMemberAllowPublicDataSet(input.getMemberAllowPublicDataSet()); model.setMemberGatewayUri(input.getMemberGatewayUri()); - if (StringUtil.isNotEmpty(input.getMemberLogo()) && input.getMemberLogo().contains(",")) { - LOG.info("压缩前:" + input.getMemberLogo().length()); - String[] strs = input.getMemberLogo().split(","); - model.setMemberLogo(strs[0] + "," + FileUtil.compressPicForScale(strs[1], 200, 0.7)); - LOG.info("压缩后:" + model.getMemberLogo().length()); - } model.setMemberHidden(input.getMemberHidden()); globalConfigService.setMemberInfo(model); @@ -162,7 +164,7 @@ public void updateMemberInfo(UpdateMemberInfoApi.Input input) throws StatusCodeW @Transactional(rollbackFor = Exception.class) public void updateMemberRsaKey() throws StatusCodeWithException { - AccountMySqlModel account = 
accountRepository.findByPhoneNumber(CurrentAccount.phoneNumber()); + AccountMysqlModel account = accountRepository.findByPhoneNumber(BoardSM4Util.encryptPhoneNumber(CurrentAccount.phoneNumber())); if (!account.getSuperAdminRole()) { throw new StatusCodeWithException("您没有编辑权限,请联系超级管理员(第一个注册的人)进行操作。", StatusCode.INVALID_USER); } @@ -170,11 +172,10 @@ public void updateMemberRsaKey() throws StatusCodeWithException { MemberInfoModel model = globalConfigService.getMemberInfo(); try { - RSAUtil.RsaKeyPair pair = RSAUtil.generateKeyPair(); - model.setRsaPrivateKey(pair.privateKey); - model.setRsaPublicKey(pair.publicKey); + SignUtil.KeyPair keyPair = SignUtil.generateKeyPair(model.getSecretKeyType()); + model.setRsaPrivateKey(keyPair.privateKey); + model.setRsaPublicKey(keyPair.publicKey); } catch (NoSuchAlgorithmException e) { - e.printStackTrace(); throw new StatusCodeWithException(e.getMessage(), StatusCode.SYSTEM_ERROR); } @@ -197,7 +198,7 @@ public void updateMemberRsaKey() throws StatusCodeWithException { */ @Transactional(rollbackFor = Exception.class) public void updateMemberLogo(UpdateMemberLogoApi.Input input) throws StatusCodeWithException { - AccountMySqlModel account = accountRepository.findByPhoneNumber(CurrentAccount.phoneNumber()); + AccountMysqlModel account = accountRepository.findByPhoneNumber(BoardSM4Util.encryptPhoneNumber(CurrentAccount.phoneNumber())); if (!account.getSuperAdminRole()) { throw new StatusCodeWithException("您没有编辑权限,请联系超级管理员(第一个注册的人)进行操作。", StatusCode.INVALID_USER); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskProgressService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskProgressService.java index 714a33ccc..bbb162518 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskProgressService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskProgressService.java @@ -1,12 +1,12 @@ -/** +/* * 
Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,7 +19,7 @@ import com.welab.wefe.board.service.database.entity.job.TaskProgressMysqlModel; import com.welab.wefe.board.service.database.repository.TaskProgressRepository; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.JobMemberRole; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.jpa.domain.Specification; import org.springframework.stereotype.Service; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskResultService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskResultService.java index ac6d96215..31b66faf0 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskResultService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskResultService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -32,7 +32,7 @@ import org.springframework.stereotype.Service; import com.alibaba.fastjson.JSONArray; -import com.welab.wefe.board.service.api.dataset.DetailApi; +import com.welab.wefe.board.service.api.data_resource.table_data_set.DetailApi; import com.welab.wefe.board.service.api.project.job.task.GetFeatureApi; import com.welab.wefe.board.service.api.project.job.task.SelectFeatureApi; import com.welab.wefe.board.service.api.project.job.task.SelectFeatureApi.Input.MemberModel; @@ -42,27 +42,27 @@ import com.welab.wefe.board.service.component.feature.FeatureSelectionComponent; import com.welab.wefe.board.service.component.feature.VertOneHotComponent; import com.welab.wefe.board.service.component.feature.VertOneHotComponent.Params.MemberInfoModel; -import com.welab.wefe.board.service.database.entity.data_set.DataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; import com.welab.wefe.board.service.database.entity.job.ProjectMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskMySqlModel; import com.welab.wefe.board.service.database.entity.job.TaskResultMySqlModel; import com.welab.wefe.board.service.database.repository.TaskRepository; import com.welab.wefe.board.service.database.repository.TaskResultRepository; import com.welab.wefe.board.service.dto.entity.MemberFeatureInfoModel; -import com.welab.wefe.board.service.dto.entity.data_set.DataSetOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.TableDataSetOutputModel; import com.welab.wefe.board.service.exception.FlowNodeException; import 
com.welab.wefe.board.service.exception.MemberGatewayException; import com.welab.wefe.board.service.model.FlowGraph; import com.welab.wefe.board.service.model.FlowGraphNode; +import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.FederatedLearningType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.TaskResultType; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; -import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.FederatedLearningType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskResultType; /** * @author zane.luo @@ -87,10 +87,10 @@ public class TaskResultService extends AbstractService { @Autowired private ProjectService projectService; - + @Autowired - private DataSetService datasetService; - + private TableDataSetService tableDataSetService; + @Autowired private GatewayService gatewayService; @@ -175,7 +175,7 @@ public JObject selectFeature(SelectFeatureApi.Input input) throws StatusCodeWith * filter the features by cv/iv */ private JObject selectByCvIv(FlowGraph flowGraph, FlowGraphNode node, SelectFeatureApi.Input input) throws FlowNodeException { - + JObject result = JObject.create(); List selectMembers = new ArrayList<>(); // mix flow @@ -189,14 +189,14 @@ private JObject selectByCvIv(FlowGraph flowGraph, FlowGraphNode node, SelectFeat if (featureBinningNode == null) { throw new FlowNodeException(node, "请添加特征分箱组件。"); } - + // Find the task corresponding to the FeatureStatistic node ProjectMySqlModel project = projectService.findProjectByJobId(input.getJobId()); TaskMySqlModel featureStatisticTask = 
taskRepository.findOne(input.getJobId(), featureStatisticNode.getNodeId(), project.getMyRole().name()); if (featureStatisticTask == null) { throw new FlowNodeException(node, "找不到对应的特征统计任务。"); } - + // Find the task result of FeatureStatistic TaskResultMySqlModel featureStatisticTaskResult = findByTaskIdAndType(featureStatisticTask.getTaskId(), TaskResultType.data_feature_statistic.name()); @@ -205,7 +205,7 @@ private JObject selectByCvIv(FlowGraph flowGraph, FlowGraphNode node, SelectFeat } JObject statisticResult = JObject.create(featureStatisticTaskResult.getResult()); - + TaskMySqlModel featureBinningTask = taskRepository.findOne(input.getJobId(), featureBinningNode.getNodeId(), project.getMyRole().name()); if (featureBinningTask == null) { throw new FlowNodeException(node, "找不到对应的特征分箱任务。"); @@ -216,16 +216,16 @@ private JObject selectByCvIv(FlowGraph flowGraph, FlowGraphNode node, SelectFeat if (featureBinningTaskResult == null) { return JObject.create(); } - + List featureBinningResults = parseBinningResult(featureBinningTaskResult); List statisticResultMembers = statisticResult.getJSONList("members"); for (JObject memberObj : statisticResultMembers) { Map cvMap = new HashMap<>(); Map ivMap = new HashMap<>(); - + String memberId = memberObj.getString("member_id"); String role = memberObj.getString("role"); - + JObject featureBinningResult = featureBinningResults.stream() .filter(s -> role.equalsIgnoreCase(s.getString("role")) && memberId.equalsIgnoreCase(s.getString("memberId"))) @@ -253,7 +253,7 @@ private JObject selectByCvIv(FlowGraph flowGraph, FlowGraphNode node, SelectFeat // Get the feature column of the current member List currentMembers = input.getMembers().stream().filter( - x -> x.getMemberId().equals(memberId) && x.getMemberRole() == JobMemberRole.valueOf(role)) + x -> x.getMemberId().equals(memberId) && x.getMemberRole() == JobMemberRole.valueOf(role)) .collect(Collectors.toList()); if (JobMemberRole.promoter.name().equalsIgnoreCase(role)) { 
currentMembers = input.getMembers().stream() @@ -276,9 +276,8 @@ private JObject selectByCvIv(FlowGraph flowGraph, FlowGraphNode node, SelectFeat } } } - - } - else { + + } else { // Find the FeatureCalculation node in the parent node FlowGraphNode featureCalculationNode = flowGraph.findOneNodeFromParent(node, ComponentType.FeatureCalculation); @@ -298,7 +297,7 @@ private JObject selectByCvIv(FlowGraph flowGraph, FlowGraphNode node, SelectFeat if (featureCalculationTaskResult == null) { return JObject.create(); } - + result = JObject.create(featureCalculationTaskResult.getResult()); List calculateResults = result.getJSONList("model_param.calculateResults"); @@ -375,7 +374,7 @@ private List parseBinningResult(TaskResultMySqlModel featureBinningTask binningResults.addAll(providerBinningResults); return binningResults; } - + /** * filter the features by missing rate */ @@ -383,7 +382,8 @@ private JObject selectByMissRate(FlowGraph flowGraph, FlowGraphNode node, Select // Find the FeatureStatistic node in the parent node FlowGraphNode featureStatisticNode = flowGraph.findOneNodeFromParent(node, x -> x.getComponentType() == ComponentType.FeatureStatistic - || x.getComponentType() == ComponentType.MixStatistic); + || x.getComponentType() == ComponentType.MixStatistic + || x.getComponentType() == ComponentType.HorzStatistic); if (featureStatisticNode == null) { throw new FlowNodeException(node, "请添加特征统计组件。"); @@ -429,10 +429,21 @@ private JObject selectByMissRate(FlowGraph flowGraph, FlowGraphNode node, Select missingValueMap.put(feature, bg.setScale(4, RoundingMode.HALF_UP).doubleValue()); } - - // Get the feature column of the current member - List currentMembers = input.getMembers().stream().filter(x -> x.getMemberId().equals(memberId) && x.getMemberRole() == JobMemberRole.valueOf(role)) - .collect(Collectors.toList()); + List currentMembers = new ArrayList<>(); + if(featureStatisticNode.getComponentType() == ComponentType.HorzStatistic) { + // Get the feature column 
of the members + currentMembers = input.getMembers().stream().collect(Collectors.toList()); + } + else if(featureStatisticNode.getComponentType() == ComponentType.MixStatistic && JobMemberRole.promoter.name().equalsIgnoreCase(memberObj.getString("role"))) { + // Get the feature column of the current member + currentMembers = input.getMembers().stream().filter(x -> x.getMemberRole() == JobMemberRole.promoter) + .collect(Collectors.toList()); + } + else { + // Get the feature column of the current member + currentMembers = input.getMembers().stream().filter(x -> x.getMemberId().equals(memberId) && x.getMemberRole() == JobMemberRole.valueOf(role)) + .collect(Collectors.toList()); + } // Assign values to features with missing values for (MemberModel model : currentMembers) { @@ -455,9 +466,18 @@ private JObject selectByMissRate(FlowGraph flowGraph, FlowGraphNode node, Select .append("featureNum", selectMembers.size()); } - /** - * Get feature list - */ + /** + * Get feature list + * + * has_feature_calculation: true 表示支持CV/IV过滤 从计算特征价值 组件获取CV值/IV值 + * has_feature_statistic: true 表示支持缺失率 特征统计组件获取缺失率 + * + * 1.做了特征统计(不管横向还是纵向还是混合),那就有 缺失率和cv + * + * 2.做了计算特征价值(只有纵向流程有),就有cv和iv。 + * + * 3.做了分箱(不管横向还是纵向还是混合),那就有iv + */ public GetFeatureApi.Output getResultFeature(GetFeatureApi.Input input) throws StatusCodeWithException { GetFeatureApi.Output out = new GetFeatureApi.Output(); FlowGraph graph = jobService.createFlowGraph(input.getFlowId()); @@ -469,7 +489,8 @@ public GetFeatureApi.Output getResultFeature(GetFeatureApi.Input input) throws S if (node.getComponentType() == ComponentType.FeatureSelection) { FlowGraphNode featureStatisticNode = graph.findOneNodeFromParent(node, x -> x.getComponentType() == ComponentType.MixStatistic - || x.getComponentType() == ComponentType.FeatureStatistic); + || x.getComponentType() == ComponentType.FeatureStatistic + || x.getComponentType() == ComponentType.HorzStatistic); out.setHasFeatureStatistic(false); 
out.setHasFeatureCalculation(false); if (featureStatisticNode != null && StringUtil.isNotEmpty(input.getJobId())) { @@ -479,7 +500,9 @@ public GetFeatureApi.Output getResultFeature(GetFeatureApi.Input input) throws S TaskResultMySqlModel featureStatisticResult = findByTaskIdAndTypeAndRole(featureStatisticTask.getTaskId(), TaskResultType.data_feature_statistic.name(), project.getMyRole()); if (featureStatisticResult != null) { - out.setHasFeatureStatistic(true); + out.setHasFeatureStatistic(true); // 缺失率 cv + out.setHasLossRate(true); + out.setHasCV(true); } } } @@ -489,29 +512,34 @@ public GetFeatureApi.Output getResultFeature(GetFeatureApi.Input input) throws S ProjectMySqlModel project = projectService.findProjectByJobId(input.getJobId()); TaskMySqlModel featureCalculationTask = taskRepository.findOne(input.getJobId(), featureCalculationNode.getNodeId(), project.getMyRole().name()); if (featureCalculationTask != null) { - TaskResultMySqlModel featureCalculationResult = findByTaskIdAndTypeAndRole(featureCalculationTask.getTaskId(), TaskResultType.model_result.name(), project.getMyRole()); if (featureCalculationResult != null) { - out.setHasFeatureCalculation(true); - } - } - } - - FlowGraphNode featureBinningNode = graph.findOneNodeFromParent(node, - x -> x.getComponentType() == ComponentType.MixBinning - || x.getComponentType() == ComponentType.Binning); - if (featureBinningNode != null && StringUtil.isNotEmpty(input.getJobId())) { - ProjectMySqlModel project = projectService.findProjectByJobId(input.getJobId()); - TaskMySqlModel featureBinningTask = taskRepository.findOne(input.getJobId(), - featureBinningNode.getNodeId(), project.getMyRole().name()); - if (featureBinningTask != null) { - TaskResultMySqlModel featureBinningResult = findByTaskIdAndTypeAndRole( - featureBinningTask.getTaskId(), TaskResultType.model_binning.name(), project.getMyRole()); - if (featureBinningResult != null) { - out.setHasFeatureCalculation(true && out.isHasFeatureStatistic()); + 
out.setHasFeatureCalculation(true); // cv_iv + out.setHasCV(true); + out.setHasIV(true); } } } + + FlowGraphNode featureBinningNode = graph.findOneNodeFromParent(node, + x -> x.getComponentType() == ComponentType.MixBinning + || x.getComponentType() == ComponentType.Binning + || x.getComponentType() == ComponentType.HorzFeatureBinning); + if (featureBinningNode != null && StringUtil.isNotEmpty(input.getJobId())) { + ProjectMySqlModel project = projectService.findProjectByJobId(input.getJobId()); + TaskMySqlModel featureBinningTask = taskRepository.findOne(input.getJobId(), + featureBinningNode.getNodeId(), project.getMyRole().name()); + if (featureBinningTask != null) { + TaskResultMySqlModel featureBinningResult = findByTaskIdAndTypeAndRole( + featureBinningTask.getTaskId(), TaskResultType.model_binning.name(), project.getMyRole()); + if (featureBinningResult != null) { + if (!out.isHasFeatureCalculation()) { + out.setHasFeatureCalculation(out.isHasFeatureStatistic()); + } + out.setHasIV(true); + } + } + } } List members = getMemberFeatures(graph, node); @@ -522,8 +550,9 @@ public GetFeatureApi.Output getResultFeature(GetFeatureApi.Input input) throws S /** * Find the feature column in the training data set: * take the feature column from (DataIO/binning/feature filtering) + * @throws StatusCodeWithException */ - public List getMemberFeatures(FlowGraph graph, FlowGraphNode node) throws FlowNodeException { + public List getMemberFeatures(FlowGraph graph, FlowGraphNode node) throws StatusCodeWithException { List nodeOutputItems = node.getComponent().findInputNodes(graph, node); // There is only one training data set by default, @@ -543,22 +572,22 @@ public List getMemberFeatures(FlowGraph graph, FlowGraph } else if (trainDataSetNodeOutputItem.getComponentType() == ComponentType.FeatureSelection) { return getFeatureSelectFeature(graph.getNode(trainDataSetNodeOutputItem.getNodeId()), graph); - } else if (trainDataSetNodeOutputItem.getComponentType() == 
ComponentType.HorzOneHot || trainDataSetNodeOutputItem.getComponentType() == ComponentType.VertOneHot){ - return getOneHotFeature(graph.getNode(trainDataSetNodeOutputItem.getNodeId()), graph); + } else if (trainDataSetNodeOutputItem.getComponentType() == ComponentType.HorzOneHot || trainDataSetNodeOutputItem.getComponentType() == ComponentType.VertOneHot) { + return getOneHotFeature(graph.getNode(trainDataSetNodeOutputItem.getNodeId()), graph); } else { return getMemberFeatures(graph, graph.getNode(trainDataSetNodeOutputItem.getNodeId())); } } - private List getOneHotFeature(FlowGraphNode node, FlowGraph flowGraph) - throws FlowNodeException { - List members = new ArrayList<>(); + private List getOneHotFeature(FlowGraphNode node, FlowGraph flowGraph) + throws StatusCodeWithException { + List members = new ArrayList<>(); - FlowGraphNode dataIONode = flowGraph.findOneNodeFromParent(node, ComponentType.DataIO); - DataIOComponent.Params dataIOParams = JObject.create(dataIONode.getParams()) - .toJavaObject(DataIOComponent.Params.class); + FlowGraphNode dataIONode = flowGraph.findOneNodeFromParent(node, ComponentType.DataIO); + DataIOComponent.Params dataIOParams = JObject.create(dataIONode.getParams()) + .toJavaObject(DataIOComponent.Params.class); - List dataSetItems = dataIOParams.getDataSetList(); + List dataSetItems = dataIOParams.getDataSetList(); // need filter VertOneHotComponent.Params params = JObject.create(node.getParams()) @@ -590,71 +619,71 @@ private List getOneHotFeature(FlowGraphNode node, FlowGr } } } - DataSetMysqlModel myTmpDataSet = datasetService.query(flowGraph.getLastJob().getJobId(), - node.getComponentType()); - if (myTmpDataSet != null) { - for (MemberFeatureInfoModel member : members) { - if (!member.getMemberId().equalsIgnoreCase(CacheObjects.getMemberId())) { - DetailApi.Input input = new DetailApi.Input(); - input.setId(myTmpDataSet.getId()); - try { - ApiResult apiResult = gatewayService.sendToBoardRedirectApi(member.getMemberId(), - 
JobMemberRole.promoter, input, DetailApi.class); - if (apiResult.data != null) { - DataSetOutputModel output = JObject.create(apiResult.data) - .toJavaObject(DataSetOutputModel.class); - LOG.info("getOneHotFeature request : " + JObject.toJSONString(input)); - List newColumnNameList = new ArrayList<>( - Arrays.asList(output.getFeatureNameList().split(","))); - List oldFeatures = member.getFeatures(); - - List newFeatures = new ArrayList<>(); - for (MemberFeatureInfoModel.Feature feature : oldFeatures) { - if (newColumnNameList.contains(feature.getName())) { - newFeatures.add(feature); - newColumnNameList.remove(feature.getName()); - } - } - if (newColumnNameList != null && !newColumnNameList.isEmpty()) { - for (String s : newColumnNameList) { - MemberFeatureInfoModel.Feature f = new MemberFeatureInfoModel.Feature(); - f.setName(s); - newFeatures.add(f); - } - } - member.setFeatures(newFeatures); - } - } catch (MemberGatewayException e) { - throw new FlowNodeException(node, member.getMemberId()); - } - } else { - List newColumnNameList = new ArrayList<>( - Arrays.asList(myTmpDataSet.getFeatureNameList().split(","))); - List oldFeatures = member.getFeatures(); - - List newFeatures = new ArrayList<>(); - for (MemberFeatureInfoModel.Feature feature : oldFeatures) { - if (newColumnNameList.contains(feature.getName())) { - newFeatures.add(feature); - newColumnNameList.remove(feature.getName()); - } - } - if (newColumnNameList != null && !newColumnNameList.isEmpty()) { - for (String s : newColumnNameList) { - MemberFeatureInfoModel.Feature f = new MemberFeatureInfoModel.Feature(); - f.setName(s); - newFeatures.add(f); - } - } - member.setFeatures(newFeatures); - } - } - } - return members; - } + if (flowGraph.getLastJob() != null) { + TableDataSetMysqlModel myTmpDataSet = tableDataSetService.query(flowGraph.getLastJob().getJobId(), + node.getComponentType()); + if (myTmpDataSet != null) { + for (MemberFeatureInfoModel member : members) { + if 
(!member.getMemberId().equalsIgnoreCase(CacheObjects.getMemberId())) { + DetailApi.Input input = new DetailApi.Input(); + input.setId(myTmpDataSet.getId()); + try { + TableDataSetOutputModel output = gatewayService.callOtherMemberBoard(member.getMemberId(), + JobMemberRole.promoter, DetailApi.class, input, TableDataSetOutputModel.class); + if (output != null) { + LOG.info("getOneHotFeature request : " + JObject.toJSONString(input)); + List newColumnNameList = new ArrayList<>( + Arrays.asList(output.getFeatureNameList().split(","))); + List oldFeatures = member.getFeatures(); + + List newFeatures = new ArrayList<>(); + for (MemberFeatureInfoModel.Feature feature : oldFeatures) { + if (newColumnNameList.contains(feature.getName())) { + newFeatures.add(feature); + newColumnNameList.remove(feature.getName()); + } + } + if (newColumnNameList != null && !newColumnNameList.isEmpty()) { + for (String s : newColumnNameList) { + MemberFeatureInfoModel.Feature f = new MemberFeatureInfoModel.Feature(); + f.setName(s); + newFeatures.add(f); + } + } + member.setFeatures(newFeatures); + } + } catch (MemberGatewayException e) { + throw new FlowNodeException(node, member.getMemberId()); + } + } else { + List newColumnNameList = new ArrayList<>( + Arrays.asList(myTmpDataSet.getFeatureNameList().split(","))); + List oldFeatures = member.getFeatures(); + + List newFeatures = new ArrayList<>(); + for (MemberFeatureInfoModel.Feature feature : oldFeatures) { + if (newColumnNameList.contains(feature.getName())) { + newFeatures.add(feature); + newColumnNameList.remove(feature.getName()); + } + } + if (newColumnNameList != null && !newColumnNameList.isEmpty()) { + for (String s : newColumnNameList) { + MemberFeatureInfoModel.Feature f = new MemberFeatureInfoModel.Feature(); + f.setName(s); + newFeatures.add(f); + } + } + member.setFeatures(newFeatures); + } + } + } + } + return members; + } - /** + /** * From the feature column in the DataIO node params */ public List 
getDataIOFeature(FlowGraphNode node) { diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskService.java index fd6e46c0f..7ab4c025f 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/TaskService.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -21,19 +21,19 @@ import com.welab.wefe.board.service.api.project.job.task.DetailApi; import com.welab.wefe.board.service.component.OotComponent; import com.welab.wefe.board.service.database.entity.job.*; +import com.welab.wefe.board.service.database.repository.JobRepository; import com.welab.wefe.board.service.database.repository.TaskRepository; import com.welab.wefe.board.service.dto.entity.DataIoTaskFeatureInfoOutputModel; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.ComponentType; -import com.welab.wefe.common.enums.JobMemberRole; -import com.welab.wefe.common.enums.OrderBy; -import com.welab.wefe.common.enums.TaskStatus; +import com.welab.wefe.common.data.mysql.enums.OrderBy; import com.welab.wefe.common.exception.StatusCodeWithException; import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; import com.welab.wefe.common.web.CurrentAccount; -import com.welab.wefe.common.web.dto.ApiResult; +import com.welab.wefe.common.wefe.enums.ComponentType; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.TaskStatus; import org.apache.commons.collections4.CollectionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -76,6 +76,8 @@ public class TaskService { @Autowired private ProjectFlowNodeService projectFlowNodeService; + @Autowired + private JobRepository jobRepository; /** * Query all execution records of a node @@ -94,18 +96,38 @@ public List findTaskHistory(String flowId, String flowNodeId, Jo ); } - public TaskMySqlModel findOne(DetailApi.Input input) { + + public TaskMySqlModel findOne(DetailApi.Input input) throws StatusCodeWithException { if (StringUtil.isNotEmpty(input.getTaskId())) { return findOne(input.getTaskId()); } else { - ProjectMySqlModel project = projectService.findProjectByJobId(input.getJobId()); + String jobId = input.getJobId(); + ProjectMySqlModel project; + if 
(StringUtil.isEmpty(jobId)) { + if (StringUtil.isEmpty(input.getFlowNodeId())) { + StatusCode + .PARAMETER_VALUE_INVALID + .throwException("job_id 不传的时候 flow_id 必须要指定"); + } + + // 通过 flow_id 获取最后一个 job + ProjectFlowMySqlModel flow = projectFlowService.findOne(input.getFlowNodeId()); + project = projectService.findByProjectId(flow.getProjectId()); + JobMySqlModel job = jobRepository.findLastByFlowId(input.getFlowId(), project.getMyRole().name()); + if (job == null) { + return null; + } + jobId = job.getJobId(); + } else { + project = projectService.findProjectByJobId(input.getJobId()); + } if (project == null) { return null; } - return taskRepo.findOne(input.getJobId(), input.getFlowNodeId(), project.getMyRole().name()); + return taskRepo.findOne(jobId, input.getFlowNodeId(), project.getMyRole().name()); } } @@ -343,11 +365,15 @@ private List findDataIoTaskFeaturesWithOot(Que } } else if (JobMemberRole.provider.equals(jobMemberMySqlModel.getJobRole())) { // The provider needs to send a request to the other party to obtain - ApiResult apiResult = gatewayService.sendToBoardRedirectApi(jobMemberMySqlModel.getMemberId(), JobMemberRole.promoter, queryTaskConfigInput, QueryDataIoTaskConfigApi.class); - if (0 != apiResult.code) { - throw new StatusCodeWithException("获取成员[" + memberName + "]的入模特征失败,原因:" + apiResult.message, StatusCode.SYSTEM_ERROR); - } - JObject data = JObject.create(apiResult.data); + Object result = gatewayService.callOtherMemberBoard( + jobMemberMySqlModel.getMemberId(), + JobMemberRole.promoter, + QueryDataIoTaskConfigApi.class, + queryTaskConfigInput, + Object.class + ); + + JObject data = JObject.create(result); if (null == data || data.isEmpty()) { throw new StatusCodeWithException("获取成员[" + memberName + "]的入模特征为空。", StatusCode.DATA_NOT_FOUND); } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/WebSocketServer.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/WebSocketServer.java index 
f5ca53484..5fddb3dd5 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/WebSocketServer.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/WebSocketServer.java @@ -1,12 +1,12 @@ -/** +/* * Copyright 2021 Tianmian Tech. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -24,6 +24,7 @@ import com.welab.wefe.common.util.JObject; import com.welab.wefe.common.util.StringUtil; import com.welab.wefe.common.web.CurrentAccount; +import com.welab.wefe.common.web.service.account.AccountInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Component; @@ -80,7 +81,7 @@ public void onOpen(Session session, @PathParam("token") String token) { this.session = session; this.token = token; - CurrentAccount.Info info = CurrentAccount.get(token); + AccountInfo info = CurrentAccount.get(token); if (info == null) { log.error("Illegal user, the token does not exist: " + token); try { @@ -119,7 +120,7 @@ public void onClose() { ONLINE_COUNT.decrementAndGet(); } - CurrentAccount.Info info = CurrentAccount.get(token); + AccountInfo info = CurrentAccount.get(token); log.info("User exit: " + info.phoneNumber + ", token: " + token + ",the number of people currently online is:" + ONLINE_COUNT.get()); } @@ -132,7 +133,7 @@ public void onClose() { public void onMessage(String message, Session session) throws IOException { log.info("User message: {},content: {}", token, message); - CurrentAccount.Info info = 
CurrentAccount.get(token); + AccountInfo info = CurrentAccount.get(token); if (info == null) { sendMessage(responseNonchatMessage(StatusCode.LOGIN_REQUIRED.getCode(), "token无效,请重新登录再试", null)); return; diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/account/AccountService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/account/AccountService.java index 57016fa9c..70d8ca1dd 100644 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/account/AccountService.java +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/account/AccountService.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,45 +16,54 @@ package com.welab.wefe.board.service.service.account; -import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONArray; import com.welab.wefe.board.service.api.account.*; -import com.welab.wefe.board.service.database.entity.AccountMySqlModel; +import com.welab.wefe.board.service.database.entity.AccountMysqlModel; import com.welab.wefe.board.service.database.repository.AccountRepository; import com.welab.wefe.board.service.dto.base.PagingOutput; import com.welab.wefe.board.service.dto.entity.AccountOutputModel; import com.welab.wefe.board.service.dto.vo.AccountInputModel; import com.welab.wefe.board.service.dto.vo.OnlineAccountOutput; -import com.welab.wefe.board.service.sdk.UnionService; -import com.welab.wefe.board.service.service.AbstractService; import com.welab.wefe.board.service.service.CacheObjects; import com.welab.wefe.board.service.service.GatewayService; import 
com.welab.wefe.board.service.service.WebSocketServer; import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.board.service.service.verificationcode.VerificationCodeService; +import com.welab.wefe.board.service.util.BoardSM4Util; import com.welab.wefe.common.StatusCode; import com.welab.wefe.common.data.mysql.Where; -import com.welab.wefe.common.enums.*; +import com.welab.wefe.common.data.mysql.enums.OrderBy; import com.welab.wefe.common.exception.StatusCodeWithException; -import com.welab.wefe.common.util.*; +import com.welab.wefe.common.util.JObject; +import com.welab.wefe.common.util.Md5; +import com.welab.wefe.common.util.Sha1; +import com.welab.wefe.common.util.StringUtil; import com.welab.wefe.common.web.CurrentAccount; -import com.welab.wefe.common.web.LoginSecurityPolicy; -import com.welab.wefe.common.web.dto.ApiResult; -import com.welab.wefe.common.web.service.CaptchaService; +import com.welab.wefe.common.web.service.account.AbstractAccountService; +import com.welab.wefe.common.web.service.account.AccountInfo; +import com.welab.wefe.common.web.service.account.HistoryPasswordItem; +import com.welab.wefe.common.wefe.enums.AuditStatus; +import com.welab.wefe.common.wefe.enums.BoardUserSource; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import com.welab.wefe.common.wefe.enums.VerificationCodeBusinessType; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.RandomStringUtils; -import org.modelmapper.ModelMapper; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.jpa.domain.Specification; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -import java.security.SecureRandom; -import java.util.*; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Random; /** * @author Zane */ @Service -public class AccountService extends 
AbstractService { +public class AccountService extends AbstractAccountService { @Autowired private AccountRepository accountRepository; @@ -65,20 +74,20 @@ public class AccountService extends AbstractService { private GlobalConfigService globalConfigService; @Autowired - private UnionService unionService; + private VerificationCodeService verificationCodeService; /** * Paging query account */ public PagingOutput query(QueryApi.Input input) throws StatusCodeWithException { - Specification where = Where + Specification where = Where .create() - .contains("phoneNumber", input.getPhoneNumber()) + .contains("phoneNumber", BoardSM4Util.encryptPhoneNumber(input.getPhoneNumber())) .equal("auditStatus", input.getAuditStatus()) .contains("nickname", input.getNickname()) .orderBy("createdTime", OrderBy.desc) - .build(AccountMySqlModel.class); + .build(AccountMysqlModel.class); return accountRepository.paging(where, input, AccountOutputModel.class); } @@ -89,7 +98,7 @@ public PagingOutput query(QueryApi.Input input) throws Statu public void register(AccountInputModel input, BoardUserSource userSource) throws StatusCodeWithException { // Determine whether the account is registered - AccountMySqlModel one = accountRepository.findOne("phoneNumber", input.getPhoneNumber(), AccountMySqlModel.class); + AccountMysqlModel one = accountRepository.findOne("phoneNumber", BoardSM4Util.encryptPhoneNumber(input.getPhoneNumber()), AccountMysqlModel.class); if (one != null) { throw new StatusCodeWithException("该手机号已被注册!", StatusCode.DATA_EXISTED); } @@ -100,7 +109,7 @@ public void register(AccountInputModel input, BoardUserSource userSource) throws // sha hash String password = Sha1.of(input.getPassword() + salt); - AccountMySqlModel model = new AccountMySqlModel(); + AccountMysqlModel model = new AccountMysqlModel(); model.setCreatedBy(CurrentAccount.id()); model.setPhoneNumber(input.getPhoneNumber()); model.setNickname(input.getNickname()); @@ -110,6 +119,7 @@ public void 
register(AccountInputModel input, BoardUserSource userSource) throws model.setSuperAdminRole(accountRepository.count() < 1); model.setAdminRole(model.getSuperAdminRole()); model.setEnable(true); + model.setLastActionTime(new Date()); // Super administrator does not need to review if (model.getSuperAdminRole() || userSource == BoardUserSource.online_demo) { @@ -134,121 +144,32 @@ public void register(AccountInputModel input, BoardUserSource userSource) throws CacheObjects.refreshAccountMap(); } - /** - * login - */ - public LoginApi.Output login(String phoneNumber, String password, String key, String code) throws StatusCodeWithException { - - if (!config.getEnvName().isTestEnv()) { - // Verification code verification - if (!CaptchaService.verify(key, code)) { - throw new StatusCodeWithException("验证码错误!", StatusCode.PARAMETER_VALUE_INVALID); - } - } - - // Check if it's in the small black room - if (LoginSecurityPolicy.inDarkRoom(phoneNumber)) { - throw new StatusCodeWithException("账号已被禁止登陆,请一个小时后再试,或联系管理员。", StatusCode.PARAMETER_VALUE_INVALID); - } - - AccountMySqlModel model = accountRepository.findOne("phoneNumber", phoneNumber, AccountMySqlModel.class); - // phone number error - if (model == null) { - throw new StatusCodeWithException("手机号错误,该用户不存在。", StatusCode.PARAMETER_VALUE_INVALID); - } - - if (!model.getEnable()) { - throw new StatusCodeWithException("用户被禁用,请联系管理员。", StatusCode.PERMISSION_DENIED); - } - - // wrong password - if (!model.getPassword().equals(Sha1.of(password + model.getSalt()))) { - - // Log a login failure event - LoginSecurityPolicy.onLoginFail(phoneNumber); - throw new StatusCodeWithException("手机号或密码错误,连续错误 6 次会被禁止登陆,可以联系管理员重置密码找回账号。", StatusCode.PARAMETER_VALUE_INVALID); - } - - // Check audit status - if (model.getAuditStatus() != null) { - switch (model.getAuditStatus()) { - case auditing: - AccountMySqlModel superAdmin = findSuperAdmin(); - - throw new StatusCodeWithException("账号尚未审核,请联系管理员 " + superAdmin.getNickname() + " 
(或其他任意管理员)对您的账号进行审核后再尝试登录!", StatusCode.PARAMETER_VALUE_INVALID); - case disagree: - throw new StatusCodeWithException("账号审核不通过:" + model.getAuditComment(), StatusCode.PARAMETER_VALUE_INVALID); - default: - } - } - - String token = UUID.randomUUID().toString(); - CurrentAccount.logined(token, model.getId(), model.getPhoneNumber(), model.getAdminRole(), model.getSuperAdminRole()); - - LoginApi.Output output = new ModelMapper().map(model, LoginApi.Output.class); - output.setToken(token); - - // Record a successful login event - LoginSecurityPolicy.onLoginSuccess(phoneNumber); - - return output; - } - - /** - * update password - */ - public void updatePassword(String oldPassword, String newPassword) throws StatusCodeWithException { - - String phoneNumber = CurrentAccount.phoneNumber(); - if (phoneNumber == null) { - throw new StatusCodeWithException(StatusCode.LOGIN_REQUIRED); - } - - AccountMySqlModel model = accountRepository.findByPhoneNumber(phoneNumber); - - // Check old password - if (!StringUtil.equals(model.getPassword(), Sha1.of(oldPassword + model.getSalt()))) { - throw new StatusCodeWithException("您输入的旧密码不正确", StatusCode.PARAMETER_VALUE_INVALID); - } - - // Regenerate salt - String salt = createRandomSalt(); - - // sha hash - newPassword = Sha1.of(newPassword + salt); - + @Override + public void saveSelfPassword(String password, String salt, JSONArray historyPasswords) throws StatusCodeWithException { + AccountMysqlModel model = accountRepository.findById(CurrentAccount.id()).orElse(null); + model.setPassword(password); model.setSalt(salt); - model.setPassword(newPassword); - + model.setHistoryPasswordList(historyPasswords); accountRepository.save(model); } - /** * query all of account */ - public List queryAll() { + public List queryAll() { return accountRepository.findAll(); } - private String createRandomSalt() { - final Random r = new SecureRandom(); - byte[] salt = new byte[16]; - r.nextBytes(salt); - - return Base64Util.encode(salt); - } - /** * The 
administrator reviews the account */ public void audit(AuditApi.Input input) throws StatusCodeWithException { - AccountMySqlModel auditor = accountRepository.findById(CurrentAccount.id()).orElse(null); + AccountMysqlModel auditor = accountRepository.findById(CurrentAccount.id()).orElse(null); if (!auditor.getAdminRole()) { throw new StatusCodeWithException("您不是管理员,无权执行审核操作!", StatusCode.PARAMETER_VALUE_INVALID); } - AccountMySqlModel account = accountRepository.findById(input.getAccountId()).orElse(null); + AccountMysqlModel account = accountRepository.findById(input.getAccountId()).orElse(null); if (account.getAuditStatus() != AuditStatus.auditing) { throw new StatusCodeWithException("该用户已被审核,请勿重复操作!", StatusCode.PARAMETER_VALUE_INVALID); } @@ -260,21 +181,47 @@ public void audit(AuditApi.Input input) throws StatusCodeWithException { } - /** - * Query super administrator - */ - public AccountMySqlModel findSuperAdmin() { - List list = accountRepository.findAll(Where + @Override + public AccountInfo getAccountInfo(String phoneNumber) throws StatusCodeWithException { + AccountMysqlModel model = accountRepository.findByPhoneNumber(BoardSM4Util.encryptPhoneNumber(phoneNumber)); + return toAccountInfo(model); + } + + private AccountInfo toAccountInfo(AccountMysqlModel model) throws StatusCodeWithException { + if (model == null) { + return null; + } + + AccountInfo info = new AccountInfo(); + info.setId(model.getId()); + info.setPhoneNumber(model.getPhoneNumber()); + info.setNickname(model.getNickname()); + info.setPassword(model.getPassword()); + info.setSalt(model.getSalt()); + info.setAuditStatus(model.getAuditStatus()); + info.setAuditComment(model.getAuditComment()); + info.setAdminRole(model.getAdminRole()); + info.setSuperAdminRole(model.getSuperAdminRole()); + info.setEnable(model.getEnable()); + info.setCancelled(model.isCancelled()); + info.setHistoryPasswordList(model.getHistoryPasswordList()); + return info; + } + + + @Override + public AccountInfo 
getSuperAdmin() throws StatusCodeWithException { + List list = accountRepository.findAll(Where .create() .equal("superAdminRole", true) - .build(AccountMySqlModel.class) + .build(AccountMysqlModel.class) ); if (list.isEmpty()) { return null; } - return list.get(0); + return toAccountInfo(list.get(0)); } /** @@ -302,7 +249,7 @@ private void updateAdminRole(UpdateApi.Input input) throws StatusCodeWithExcepti return; } - AccountMySqlModel account = accountRepository.findById(input.id).orElse(null); + AccountMysqlModel account = accountRepository.findById(input.id).orElse(null); if (account == null) { throw new StatusCodeWithException("找不到更新的用户信息。", StatusCode.DATA_NOT_FOUND); @@ -316,7 +263,7 @@ private void updateAdminRole(UpdateApi.Input input) throws StatusCodeWithExcepti } private void updateBaseInfo(UpdateApi.Input input) throws StatusCodeWithException { - AccountMySqlModel account = accountRepository.findById(CurrentAccount.id()).orElse(null); + AccountMysqlModel account = accountRepository.findById(CurrentAccount.id()).orElse(null); if (StringUtil.isNotEmpty(input.getNickname())) { account.setNickname(input.getNickname()); @@ -345,7 +292,7 @@ public void enable(EnableApi.Input input) throws StatusCodeWithException { throw new StatusCodeWithException("无法对自己进行此操作。", StatusCode.PERMISSION_DENIED); } - AccountMySqlModel account = accountRepository.findById(input.getId()).orElse(null); + AccountMysqlModel account = accountRepository.findById(input.getId()).orElse(null); if (account == null) { throw new StatusCodeWithException("找不到更新的用户信息。", StatusCode.DATA_NOT_FOUND); } @@ -370,7 +317,14 @@ public void enable(EnableApi.Input input) throws StatusCodeWithException { * Reset user password (administrator rights) */ public String resetPassword(ResetPasswordApi.Input input) throws StatusCodeWithException { - AccountMySqlModel model = accountRepository.findById(input.getId()).orElse(null); + // 操作者 + AccountMysqlModel operator = 
accountRepository.findById(CurrentAccount.id()).orElse(null); + if (!super.verifyPassword(operator.getPassword(), input.getOperatorPassword(), operator.getSalt())) { + throw new StatusCodeWithException("密码错误,身份核实失败,已退出登录。", StatusCode.PERMISSION_DENIED); + } + + // 被重置密码的账号 + AccountMysqlModel model = accountRepository.findById(input.getId()).orElse(null); if (model == null) { throw new StatusCodeWithException("找不到更新的用户信息。", StatusCode.DATA_NOT_FOUND); @@ -380,6 +334,14 @@ public String resetPassword(ResetPasswordApi.Input input) throws StatusCodeWithE throw new StatusCodeWithException("非管理员无法重置密码。", StatusCode.PERMISSION_DENIED); } + if (model.getSuperAdminRole()) { + throw new StatusCodeWithException("不能重置超级管理员密码。", StatusCode.PERMISSION_DENIED); + } + + if (model.getAdminRole() && !CurrentAccount.isSuperAdmin()) { + throw new StatusCodeWithException("只有超级管理员才能重置管理员的密码", StatusCode.PERMISSION_DENIED); + } + String salt = createRandomSalt(); String newPassword = RandomStringUtils.randomAlphanumeric(2) + new Random().nextInt(999999); @@ -405,16 +367,13 @@ public PagingOutput queryMemberAccounts(QueryMemberAccountsA if (CacheObjects.getMemberId().equals(input.getMemberId())) { pagingOutput = query(input); } else { - ApiResult apiResult = gatewayService.sendToBoardRedirectApi(input.getMemberId(), JobMemberRole.promoter, input, QueryApi.class); - if (0 == apiResult.code) { - if (null == apiResult.data) { - return null; - } - JObject dataObj = JObject.create(apiResult.data); - pagingOutput = JObject.parseObject(dataObj.toJSONString(), pagingOutput.getClass()); - } else { - throw new StatusCodeWithException(apiResult.message, StatusCode.SYSTEM_ERROR); - } + pagingOutput = gatewayService.callOtherMemberBoard( + input.getMemberId(), + JobMemberRole.promoter, + QueryApi.class, + input, + pagingOutput.getClass() + ); } List accountOutputModelList = new ArrayList<>(); @@ -455,11 +414,15 @@ public List queryOnlineAccount(QueryOnlineApi.Input input) try { JObject data = 
JObject.create().append("memberId", input.getMemberId()) .append("accountId", input.getAccountId()); - ApiResult apiResult = gatewayService.sendToBoardRedirectApi(input.getMemberId(), JobMemberRole.promoter, data, QueryOnlineApi.class); - if (apiResult.code != 0) { - throw new StatusCodeWithException("系统异常: " + apiResult.message, StatusCode.SYSTEM_ERROR); - } - QueryOnlineApi.Output output = JSONObject.toJavaObject(JObject.create(apiResult.data), QueryOnlineApi.Output.class); + + QueryOnlineApi.Output output = gatewayService.callOtherMemberBoard( + input.getMemberId(), + JobMemberRole.promoter, + QueryOnlineApi.class, + data, + QueryOnlineApi.Output.class + ); + return output.getList(); } catch (Exception e) { throw new StatusCodeWithException("系统异常: " + e.getMessage(), StatusCode.SYSTEM_ERROR); @@ -469,8 +432,8 @@ public List queryOnlineAccount(QueryOnlineApi.Input input) /** * Check whether the user with the specified mobile phone number exists */ - public boolean exist(String phoneNumber) { - AccountMySqlModel model = accountRepository.findOne("phoneNumber", phoneNumber, AccountMySqlModel.class); + public boolean exist(String phoneNumber) throws StatusCodeWithException { + AccountMysqlModel model = accountRepository.findOne("phoneNumber", BoardSM4Util.encryptPhoneNumber(phoneNumber), AccountMysqlModel.class); return model != null; } @@ -479,7 +442,7 @@ public boolean exist(String phoneNumber) { * Transfer the super administrator status to another account */ @Transactional(rollbackFor = Exception.class) - public void changeSuperAdmin(AccountMySqlModel account) throws StatusCodeWithException { + public void changeSuperAdmin(AccountMysqlModel account) throws StatusCodeWithException { account.setAdminRole(true); account.setSuperAdminRole(true); account.setUpdatedBy(CurrentAccount.id()); @@ -489,6 +452,10 @@ public void changeSuperAdmin(AccountMySqlModel account) throws StatusCodeWithExc accountRepository.save(account); // Cancel the super administrator privileges of 
the current account accountRepository.cancelSuperAdmin(CurrentAccount.id()); + + CurrentAccount.logout(account.getId()); + CurrentAccount.logout(CurrentAccount.id()); + } public void forgetPassword(ForgetPasswordApi.Input input) throws StatusCodeWithException { @@ -499,10 +466,10 @@ public void forgetPassword(ForgetPasswordApi.Input input) throws StatusCodeWithE throw new StatusCodeWithException("密码不能为空。", StatusCode.PARAMETER_VALUE_INVALID); } if (StringUtil.isEmpty(input.getSmsVerificationCode())) { - throw new StatusCodeWithException("短信验证码不能为空。", StatusCode.PARAMETER_VALUE_INVALID); + throw new StatusCodeWithException("验证码不能为空。", StatusCode.PARAMETER_VALUE_INVALID); } - AccountMySqlModel model = accountRepository.findOne("phoneNumber", input.getPhoneNumber(), AccountMySqlModel.class); + AccountMysqlModel model = accountRepository.findOne("phoneNumber", BoardSM4Util.encryptPhoneNumber(input.getPhoneNumber()), AccountMysqlModel.class); // phone number error if (model == null) { throw new StatusCodeWithException("手机号错误,该用户不存在。", StatusCode.PARAMETER_VALUE_INVALID); @@ -511,12 +478,25 @@ public void forgetPassword(ForgetPasswordApi.Input input) throws StatusCodeWithE throw new StatusCodeWithException("用户被禁用,请联系管理员。", StatusCode.PERMISSION_DENIED); } - unionService.checkVerificationCode(input.getPhoneNumber(), input.getSmsVerificationCode(), SmsBusinessType.AccountForgetPasswordVerificationCode); + AccountInfo accountInfo = toAccountInfo(model); + int historyCount = 4; + if (inHistoryPassword(input.getPassword(), historyCount, accountInfo)) { + StatusCode.PARAMETER_VALUE_INVALID.throwException("您输入的新密码必须与前四次设置的密码不一致"); + } + + // Check verification code is valid? 
+ verificationCodeService.checkVerificationCode(input.getPhoneNumber(), input.getSmsVerificationCode(), VerificationCodeBusinessType.accountForgetPassword); + + // 当前密码成为历史 + accountInfo.getHistoryPasswordList().add(new HistoryPasswordItem(accountInfo.getPassword(), accountInfo.getSalt())); + // 历史密码 + String historyPasswordListString = JSON.toJSONString(accountInfo.getPasswordHistoryList(historyCount - 1)); // Regenerate salt String salt = createRandomSalt(); model.setSalt(salt); model.setPassword(Sha1.of(input.getPassword() + salt)); + model.setHistoryPasswordList(JSON.parseArray(historyPasswordListString)); accountRepository.save(model); } } diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/ServerAvailableCheckService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/ServerAvailableCheckService.java deleted file mode 100644 index fe6ba934f..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/ServerAvailableCheckService.java +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.service.available; - -import com.welab.wefe.board.service.service.AbstractService; -import org.springframework.stereotype.Service; - -/** - * @author zane - */ -@Service -public class ServerAvailableCheckService extends AbstractService { - -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/AbstractCheckpoint.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/AbstractCheckpoint.java deleted file mode 100644 index ede56cd06..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/AbstractCheckpoint.java +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.service.available.checkpoint; - -import com.welab.wefe.board.service.constant.Config; -import com.welab.wefe.board.service.dto.vo.ServerCheckPointOutput; -import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; -import com.welab.wefe.common.CommonThreadPool; -import com.welab.wefe.common.web.Launcher; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; - -/** - * @author zane - */ -public abstract class AbstractCheckpoint { - protected final Logger LOG = LoggerFactory.getLogger(this.getClass()); - protected final Config config = Launcher.CONTEXT.getBean(Config.class); - protected final GlobalConfigService globalConfigService = Launcher.CONTEXT.getBean(GlobalConfigService.class); - - public abstract String desc(); - - public abstract String value(); - - protected abstract void doCheck() throws Exception; - - public ServerCheckPointOutput check() { - long start = System.currentTimeMillis(); - - Future future = CommonThreadPool.submit(() -> { - try { - if (value() == null) { - throw new Exception("相关配置为空,请进行设置后再进行检查。"); - } - doCheck(); - } catch (Exception e) { - return e; - } - return null; - }); - - Exception e; - try { - e = future.get(5, TimeUnit.SECONDS); - } catch (Exception ex) { - e = ex; - } - - ServerCheckPointOutput output = new ServerCheckPointOutput(); - output.setDesc(desc()); - output.setValue(value()); - output.setSpend(System.currentTimeMillis() - start); - - if (e == null) { - output.setSuccess(true); - output.setMessage("success"); - } else { - output.setSuccess(false); - output.setMessage(e.getMessage()); - } - return output; - } - - protected void log(Exception e) { - LOG.error(e.getClass() + " " + e.getMessage(), e); - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/GatewayInternetCheckpoint.java 
b/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/GatewayInternetCheckpoint.java deleted file mode 100644 index 8e9b47132..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/GatewayInternetCheckpoint.java +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.service.available.checkpoint; - -import com.welab.wefe.board.service.dto.globalconfig.MemberInfoModel; - -/** - * @author zane - */ -public class GatewayInternetCheckpoint extends AbstractCheckpoint { - - @Override - public String desc() { - return "检查 board 与 gateway 服务在公网的的连通性"; - } - - @Override - public String value() { - MemberInfoModel memberInfo = globalConfigService.getMemberInfo(); - if (memberInfo == null) { - return null; - } - return memberInfo.getMemberGatewayUri(); - } - - @Override - protected void doCheck() throws Exception { - - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/GatewayIntranetCheckpoint.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/GatewayIntranetCheckpoint.java deleted file mode 100644 index 13eef59bb..000000000 --- 
a/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/GatewayIntranetCheckpoint.java +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.welab.wefe.board.service.service.available.checkpoint; - -import com.welab.wefe.board.service.service.GatewayService; -import com.welab.wefe.common.web.Launcher; - -/** - * @author zane - */ -public class GatewayIntranetCheckpoint extends AbstractCheckpoint { - - @Override - public String desc() { - return "检查 board 与 gateway 服务在内网的的连通性"; - } - - @Override - public String value() { - return globalConfigService.getGatewayConfig().intranetBaseUri; - } - - @Override - protected void doCheck() throws Exception { - GatewayService gatewayService = Launcher.CONTEXT.getBean(GatewayService.class); - - // Since the gateway does not currently have an alive interface, - // temporarily adjust a method to test the connectivity between the board and the gateway. 
- gatewayService.refreshMemberBlacklistCache(); - - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/StorageCheckpoint.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/StorageCheckpoint.java deleted file mode 100644 index 543a5cb1b..000000000 --- a/board/board-service/src/main/java/com/welab/wefe/board/service/service/available/checkpoint/StorageCheckpoint.java +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2021 Tianmian Tech. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.welab.wefe.board.service.service.available.checkpoint; - -import com.welab.wefe.common.data.storage.config.JdbcParamConfig; -import com.welab.wefe.common.data.storage.model.DataItemModel; -import com.welab.wefe.common.data.storage.repo.Storage; -import com.welab.wefe.common.data.storage.service.StorageService; -import com.welab.wefe.common.web.Launcher; -import org.apache.commons.lang3.RandomStringUtils; - -import static com.welab.wefe.board.service.service.DataSetStorageService.DATABASE_NAME; - -/** - * @author zane - */ -public class StorageCheckpoint extends AbstractCheckpoint { - - @Override - public String desc() { - return "检查 board 对 storage 服务的访问是否正常"; - } - - @Override - public String value() { - JdbcParamConfig storageConfig = Launcher.CONTEXT.getBean(JdbcParamConfig.class); - return storageConfig.getUrl(); - } - - @Override - protected void doCheck() throws Exception { - - StorageService service = Launcher.CONTEXT.getBean(StorageService.class); - Storage storage = service.getStorage(); - String name = RandomStringUtils.randomAlphabetic(6); - try { - storage.put(DATABASE_NAME, name, new DataItemModel<>(name, "test")); - } catch (Exception e) { - super.log(e); - throw new Exception(config.getDbType().name() + " put异常,请检查相关配置是否正确。"); - } - - try { - storage.dropTB(DATABASE_NAME, name); - } catch (Exception e) { - super.log(e); - throw new Exception(config.getDbType().name() + " drop异常,请检查相关配置是否正确。"); - } - - } -} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/FlowCheckpoint.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/FlowCheckpoint.java new file mode 100644 index 000000000..ea468d0e9 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/FlowCheckpoint.java @@ -0,0 +1,65 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.service.checkpoint; + +import com.welab.wefe.board.service.dto.globalconfig.FlowConfigModel; +import com.welab.wefe.board.service.sdk.FlowService; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.common.wefe.checkpoint.AbstractCheckpoint; +import com.welab.wefe.common.wefe.enums.ServiceType; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +/** + * @author zane + * @date 2021/12/22 + */ +@Service +public class FlowCheckpoint extends AbstractCheckpoint { + @Autowired + private GlobalConfigService globalConfigService; + @Autowired + private FlowService flowService; + + @Override + protected ServiceType service() { + return ServiceType.FlowService; + } + + @Override + protected String desc() { + return "检查与 flow 服务的连通性"; + } + + @Override + protected String getConfigValue() { + FlowConfigModel flowConfig = globalConfigService.getFlowConfig(); + if (flowConfig == null) { + return null; + } + return flowConfig.intranetBaseUri; + } + + @Override + protected String messageWhenConfigValueEmpty() { + return "请在[全局设置]-[系统设置]中对 flow 的内网地址进行设置"; + } + + @Override + protected void doCheck(String value) throws Exception { + flowService.alive(); + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/GatewayInternetCheckpoint.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/GatewayInternetCheckpoint.java new file mode 100644 index 000000000..33d45ad49 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/GatewayInternetCheckpoint.java @@ -0,0 +1,62 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.service.checkpoint; + +import com.welab.wefe.board.service.dto.globalconfig.MemberInfoModel; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.common.wefe.checkpoint.AbstractCheckpoint; +import com.welab.wefe.common.wefe.enums.ServiceType; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +/** + * @author zane + */ +@Service +public class GatewayInternetCheckpoint extends AbstractCheckpoint { + @Autowired + protected GlobalConfigService globalConfigService; + + @Override + public ServiceType service() { + return ServiceType.GatewayService; + } + + @Override + public String desc() { + return "检查 board 与 gateway 服务在公网的连通性"; + } + + @Override + public String getConfigValue() { + MemberInfoModel memberInfo = globalConfigService.getMemberInfo(); + if (memberInfo == null) { + return null; + } + return memberInfo.getMemberGatewayUri(); + } + + @Override + protected String messageWhenConfigValueEmpty() { + return "请在[全局设置]-[成员设置]中对 gateway 的对外通信地址进行设置"; + } + + @Override + protected void doCheck(String value) throws Exception { + + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/GatewayIntranetCheckpoint.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/GatewayIntranetCheckpoint.java new file mode 100644 index 000000000..e2e85a999 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/GatewayIntranetCheckpoint.java @@ -0,0 +1,65 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.service.checkpoint; + +import com.welab.wefe.board.service.service.GatewayService; +import com.welab.wefe.board.service.service.globalconfig.GlobalConfigService; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.wefe.checkpoint.AbstractCheckpoint; +import com.welab.wefe.common.wefe.enums.ServiceType; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +/** + * @author zane + */ +@Service +public class GatewayIntranetCheckpoint extends AbstractCheckpoint { + + @Autowired + protected GlobalConfigService globalConfigService; + + @Override + public ServiceType service() { + return ServiceType.GatewayService; + } + + @Override + public String desc() { + return "检查 board 与 gateway 服务在内网的连通性"; + } + + @Override + public String getConfigValue() { + return globalConfigService.getGatewayConfig().intranetBaseUri; + } + + @Override + protected String messageWhenConfigValueEmpty() { + return "请在[全局设置]-[系统设置]中对 gateway 的内网地址进行设置"; + } + + @Override + protected void doCheck(String value) throws Exception { + GatewayService gatewayService = Launcher.getBean(GatewayService.class); + + // Since the gateway does not currently have an alive interface, + // temporarily adjust a method to test the connectivity between the board and the gateway. 
+ gatewayService.refreshMemberBlacklistCache(); + + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/StorageCheckpoint.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/StorageCheckpoint.java new file mode 100644 index 000000000..f5328a015 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/StorageCheckpoint.java @@ -0,0 +1,83 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.service.checkpoint; + +import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.common.data.storage.config.JdbcParamConfig; +import com.welab.wefe.common.data.storage.model.DataItemModel; +import com.welab.wefe.common.data.storage.repo.Storage; +import com.welab.wefe.common.data.storage.service.StorageService; +import com.welab.wefe.common.web.Launcher; +import com.welab.wefe.common.wefe.checkpoint.AbstractCheckpoint; +import com.welab.wefe.common.wefe.enums.ServiceType; +import org.apache.commons.lang3.RandomStringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import static com.welab.wefe.board.service.service.DataSetStorageService.DATABASE_NAME; + +/** + * @author zane + */ +@Service +public class StorageCheckpoint extends AbstractCheckpoint { + @Autowired + protected Config config; + + @Override + public ServiceType service() { + return ServiceType.StorageService; + } + + @Override + public String desc() { + return "检查 board 对 storage 服务的访问是否正常"; + } + + @Override + public String getConfigValue() { + JdbcParamConfig storageConfig = Launcher.getBean(JdbcParamConfig.class); + return storageConfig.getUrl(); + } + + @Override + protected String messageWhenConfigValueEmpty() { + return null; + } + + @Override + protected void doCheck(String value) throws Exception { + + StorageService service = Launcher.getBean(StorageService.class); + Storage storage = service.getStorage(); + String name = RandomStringUtils.randomAlphabetic(6); + try { + storage.put(DATABASE_NAME, name, new DataItemModel<>(name, "test")); + } catch (Exception e) { + super.log(e); + throw new Exception(config.getDbType().name() + " put异常,请检查相关配置是否正确。"); + } + + try { + storage.dropTB(DATABASE_NAME, name); + } catch (Exception e) { + super.log(e); + throw new Exception(config.getDbType().name() + " drop异常,请检查相关配置是否正确。"); + } + + } +} diff --git 
a/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/UnionConnectionCheckpoint.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/UnionConnectionCheckpoint.java new file mode 100644 index 000000000..872738b63 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/checkpoint/UnionConnectionCheckpoint.java @@ -0,0 +1,37 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.service.checkpoint; + +import com.welab.wefe.board.service.constant.Config; +import com.welab.wefe.common.wefe.checkpoint.AbstractUnionConnectionCheckpoint; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +/** + * @author zane + */ +@Service +public class UnionConnectionCheckpoint extends AbstractUnionConnectionCheckpoint { + @Autowired + protected Config config; + + @Override + public String getConfigValue() { + return config.getUnionBaseUrl(); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/AbstractDataResourceService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/AbstractDataResourceService.java new file mode 100644 index 000000000..6d8a7e80c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/AbstractDataResourceService.java @@ -0,0 +1,31 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.welab.wefe.board.service.service.data_resource; + +import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel; +import com.welab.wefe.board.service.dto.vo.data_resource.AbstractDataResourceUpdateInputModel; +import com.welab.wefe.board.service.service.AbstractService; + +/** + * @author zane + * @date 2021/12/1 + */ +public abstract class AbstractDataResourceService extends AbstractService { + + public abstract DataResourceMysqlModel findOneById(String dataSetId); + + protected abstract void beforeUpdate(DataResourceMysqlModel m, AbstractDataResourceUpdateInputModel in); +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/DataResourceService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/DataResourceService.java new file mode 100644 index 000000000..a119deb36 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/DataResourceService.java @@ -0,0 +1,392 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/*
 * Copyright 2021 Tianmian Tech. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.welab.wefe.board.service.service.data_resource;

import com.welab.wefe.board.service.api.data_resource.DataResourceQueryApi;
import com.welab.wefe.board.service.database.entity.data_resource.BloomFilterMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.ImageDataSetMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel;
import com.welab.wefe.board.service.database.entity.job.ProjectDataSetMySqlModel;
import com.welab.wefe.board.service.database.entity.job.ProjectMySqlModel;
import com.welab.wefe.board.service.database.repository.ProjectDataSetRepository;
import com.welab.wefe.board.service.database.repository.ProjectRepository;
import com.welab.wefe.board.service.database.repository.base.BaseRepository;
import com.welab.wefe.board.service.database.repository.base.RepositoryManager;
import com.welab.wefe.board.service.database.repository.data_resource.BloomFilterRepository;
import com.welab.wefe.board.service.database.repository.data_resource.DataResourceRepository;
import com.welab.wefe.board.service.database.repository.data_resource.ImageDataSetRepository;
import com.welab.wefe.board.service.database.repository.data_resource.TableDataSetRepository;
import com.welab.wefe.board.service.dto.base.PagingOutput;
import com.welab.wefe.board.service.dto.entity.data_resource.output.BloomFilterOutputModel;
import com.welab.wefe.board.service.dto.entity.data_resource.output.DataResourceOutputModel;
import com.welab.wefe.board.service.dto.entity.data_resource.output.ImageDataSetOutputModel;
import com.welab.wefe.board.service.dto.entity.data_resource.output.TableDataSetOutputModel;
import com.welab.wefe.board.service.dto.entity.project.ProjectUsageDetailOutputModel;
import com.welab.wefe.board.service.dto.vo.data_resource.AbstractDataResourceUpdateInputModel;
import com.welab.wefe.board.service.service.CacheObjects;
import com.welab.wefe.board.service.service.data_resource.bloom_filter.BloomFilterService;
import com.welab.wefe.board.service.service.data_resource.image_data_set.ImageDataSetService;
import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService;
import com.welab.wefe.common.StatusCode;
import com.welab.wefe.common.data.mysql.Where;
import com.welab.wefe.common.data.mysql.enums.OrderBy;
import com.welab.wefe.common.exception.StatusCodeWithException;
import com.welab.wefe.common.util.StringUtil;
import com.welab.wefe.common.web.util.ModelMapper;
import com.welab.wefe.common.wefe.enums.DataResourceType;
import com.welab.wefe.common.wefe.enums.DataResourcePublicLevel;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;

/**
 * Shared service for the three data-resource flavours
 * (table data set, image data set, bloom filter): usage counters,
 * common update logic, tag/visibility normalization and unified paging query.
 *
 * @author zane
 * @date 2021/12/1
 */
@Service
public class DataResourceService extends AbstractDataResourceService {
    @Autowired
    private ProjectDataSetRepository projectDataSetRepository;
    @Autowired
    private ProjectRepository projectRepository;
    @Autowired
    private DataResourceRepository dataResourceRepository;
    @Autowired
    private TableDataSetService tableDataSetService;
    @Autowired
    private ImageDataSetService imageDataSetService;
    @Autowired
    private BloomFilterService bloomFilterSetService;
    @Autowired
    private TableDataSetRepository tableDataSetRepository;
    @Autowired
    private ImageDataSetRepository imageDataSetRepository;
    @Autowired
    private BloomFilterRepository bloomFilterRepository;

    /**
     * Update the number of projects the data set is used in, then push the new
     * counter to union. Union failures are logged, never propagated.
     */
    public void updateUsageCountInProject(String dataSetId) {
        dataResourceRepository.updateUsageCountInProject(dataSetId);

        DataResourceMysqlModel model = (DataResourceMysqlModel) dataResourceRepository.findById(dataSetId).orElse(null);
        if (model == null) {
            return;
        }

        try {
            unionService.lazyUpdateDataResource(model);
        } catch (StatusCodeWithException e) {
            super.log(e);
        }
    }

    /**
     * The number of data sets used in the flow ++
     */
    public void usageCountInFlowIncrement(String dataSetId, Class<? extends DataResourceMysqlModel> clazz) throws StatusCodeWithException {
        // BUGFIX: previously this incremented usageCountInProject by mistake;
        // the decrement twin below correctly touches usageCountInFlow.
        updateUsageCount(dataSetId, clazz, x -> x.setUsageCountInFlow(x.getUsageCountInFlow() + 1));
    }

    /**
     * The number of data sets used in the flow --
     */
    public void usageCountInFlowDecrement(String dataSetId, Class<? extends DataResourceMysqlModel> clazz) throws StatusCodeWithException {
        updateUsageCount(dataSetId, clazz, x -> x.setUsageCountInFlow(x.getUsageCountInFlow() - 1));
    }

    /**
     * The number of data sets used in the job ++
     */
    public void usageCountInJobIncrement(String dataSetId) throws StatusCodeWithException {
        DataResourceMysqlModel one = (DataResourceMysqlModel) dataResourceRepository.findById(dataSetId).orElse(null);
        if (one == null) {
            return;
        }
        Class<? extends DataResourceMysqlModel> clazz = null;
        switch (one.getDataResourceType()) {
            case ImageDataSet:
                clazz = ImageDataSetMysqlModel.class;
                break;
            case TableDataSet:
                clazz = TableDataSetMysqlModel.class;
                break;
            case BloomFilter:
                clazz = BloomFilterMysqlModel.class;
                break;
            default:
        }
        // BUGFIX: guard against an unknown resource type; previously a null
        // clazz fell through into RepositoryManager.get(null).
        if (clazz == null) {
            return;
        }
        updateUsageCount(dataSetId, clazz, x -> x.setUsageCountInJob(x.getUsageCountInJob() + 1));
    }

    /**
     * Update one of the usage counters of a data resource and sync it to union.
     *
     * @param func mutation applied to the loaded entity before saving
     */
    private <T extends DataResourceMysqlModel> void updateUsageCount(String dataSetId, Class<T> clazz, Consumer<T> func) throws StatusCodeWithException {
        BaseRepository repo = RepositoryManager.get(clazz);
        T model = (T) repo.findById(dataSetId).orElse(null);
        if (model == null) {
            return;
        }

        func.accept(model);
        repo.save(model);

        unionService.lazyUpdateDataResource(model);
    }


    /**
     * Query the projects that reference the given data resource.
     */
    public List<ProjectUsageDetailOutputModel> queryUsageInProject(String dataResourceId) {
        List<ProjectUsageDetailOutputModel> result = new ArrayList<>();

        // Fetch the reference records of the resource.
        List<ProjectDataSetMySqlModel> usageInProjectList = projectDataSetRepository.queryUsageInProject(dataResourceId);
        if (usageInProjectList == null || usageInProjectList.isEmpty()) {
            return result;
        }

        // Fetch the details of each referencing project.
        for (ProjectDataSetMySqlModel usageInProject : usageInProjectList) {
            ProjectMySqlModel project = projectRepository.findOneById(usageInProject.getProjectId());
            result.add(ModelMapper.map(project, ProjectUsageDetailOutputModel.class));
        }

        return result;
    }

    /**
     * Update the common fields of a data resource, delegate type-specific
     * fields to the subclass hook, then sync the record to union.
     */
    public void update(AbstractDataResourceUpdateInputModel input) throws StatusCodeWithException {
        DataResourceMysqlModel model = findOneById(input.getId());
        if (model == null) {
            return;
        }

        model.setUpdatedBy(input);
        model.setName(input.getName());
        model.setDescription(input.getDescription());
        model.setPublicMemberList(input.getPublicMemberList());
        model.setPublicLevel(input.getPublicLevel());
        model.setTags(standardizeTags(input.getTags()));
        handlePublicMemberList(model);

        beforeUpdate(model, input);
        RepositoryManager.get(model.getClass()).save(model);


        unionService.upsertDataResource(model);
        CacheObjects.refreshDataResourceTags(model.getDataResourceType());
    }

    /**
     * Process the list of visible members.
     *
     * When visibility is restricted to a member list, the owner itself is
     * removed from the list: union guarantees the owner can always see its
     * own resource, so keeping it would be redundant.
     */
    public void handlePublicMemberList(DataResourceMysqlModel model) {

        // When the PublicLevel is PublicWithMemberList, if list contains yourself,
        // you will be removed, and union will handle the data that you must be visible.
        if (model.getPublicLevel() == DataResourcePublicLevel.PublicWithMemberList) {
            String memberId = CacheObjects.getMemberId();


            if (model.getPublicMemberList().contains(memberId)) {
                String list = model.getPublicMemberList()
                        .replace(memberId, "")
                        .replace(",,", ",");

                model.setPublicMemberList(list);
            }
        }

    }

    /**
     * Standardize the tag list: strip commas, drop empties/duplicates, sort,
     * and wrap with leading/trailing commas to ease LIKE queries.
     */
    public String standardizeTags(List<String> tags) {
        if (tags == null) {
            return "";
        }

        tags = tags.stream()
                // Remove both full-width and ASCII commas from each tag.
                .map(x -> x.replace("，", "").replace(",", ""))
                // Remove empty elements
                .filter(x -> !StringUtil.isEmpty(x))
                .distinct()
                .sorted()
                .collect(Collectors.toList());

        // Concatenate into a string, add a comma before and after it to facilitate like query.
        return "," + StringUtil.join(tags, ',') + ",";

    }

    @Override
    public DataResourceMysqlModel findOneById(String dataSetId) {
        // This common service has no concrete resource type; subclasses implement this.
        throw new UnsupportedOperationException();
    }

    @Override
    protected void beforeUpdate(DataResourceMysqlModel m, AbstractDataResourceUpdateInputModel in) {
        // This common service has no concrete resource type; subclasses implement this.
        throw new UnsupportedOperationException();
    }

    /**
     * Fetch a DataResource's detail either locally (when owned by this member)
     * or from union (when owned by another member).
     *
     * @param memberId       owner member id
     * @param dataResourceId resource id
     * @param mysqlClass     the MysqlModel.class of the resource
     * @param outputClass    the desired output class
     */
    public <M extends DataResourceMysqlModel, O extends DataResourceOutputModel> O
    findDataResourceFromLocalOrUnion(String memberId, String dataResourceId, Class<M> mysqlClass, Class<O> outputClass) throws StatusCodeWithException {

        BaseRepository repository = RepositoryManager.get(mysqlClass);

        if (memberId.equals(CacheObjects.getMemberId())) {
            Object obj = repository.findById(dataResourceId).orElse(null);
            if (obj == null) {
                return null;
            }
            return ModelMapper.map(obj, outputClass);
        } else {
            return unionService.getDataResourceDetail(dataResourceId, outputClass);
        }
    }

    /**
     * Fetch the detail of the data set referenced by a project, locally or from union.
     */
    public DataResourceOutputModel findDataResourceFromLocalOrUnion(ProjectDataSetMySqlModel projectDataSet) throws StatusCodeWithException {

        if (CacheObjects.getMemberId().equals(projectDataSet.getMemberId())) {
            // BUGFIX: previously passed the whole entity to findById instead of its data-set id.
            Object obj = dataResourceRepository.findById(projectDataSet.getDataSetId()).orElse(null);
            if (obj == null) {
                return null;
            }
            return ModelMapper.map(obj, DataResourceOutputModel.class);
        } else {
            return unionService.getDataResourceDetail(
                    projectDataSet.getDataSetId(),
                    projectDataSet.getDataResourceType(),
                    DataResourceOutputModel.class
            );
        }
    }


    /**
     * Delete a data resource, dispatching to the service of its concrete type.
     */
    public void delete(String dataResourceId, DataResourceType dataResourceType) throws StatusCodeWithException {
        switch (dataResourceType) {
            case ImageDataSet:
                imageDataSetService.delete(dataResourceId);
                break;
            case TableDataSet:
                tableDataSetService.delete(dataResourceId);
                break;
            case BloomFilter:
                bloomFilterSetService.delete(dataResourceId);
                // BUGFIX: break was missing here (harmless fall-through to the
                // empty default, but fragile if a default action is ever added).
                break;
            default:
        }
    }


    /**
     * Unified paging query over all resource types. When the query targets a
     * single type, the type-specific repository (with extra filters) is used;
     * otherwise results are fetched from the common table and mapped to the
     * matching output class per row.
     */
    public PagingOutput<?> query(DataResourceQueryApi.Input input) throws StatusCodeWithException {
        Where where = Where
                .create()
                .equal("id", input.getId())
                .in("dataResourceType", input.getDataResourceType())
                .contains("name", input.getName())
                .containsItem("tags", input.getTag())
                .equal("createdBy", input.getCreator())
                .equal("derivedResource", false)
                .orderBy("createdTime", OrderBy.asc);

        // Query resources of all types.
        if (CollectionUtils.isEmpty(input.getDataResourceType()) || input.getDataResourceType().size() > 1) {
            PagingOutput<DataResourceMysqlModel> page = dataResourceRepository.paging(
                    where.build(DataResourceMysqlModel.class),
                    input
            );

            // Convert each row to the output class matching its concrete type.
            List<DataResourceOutputModel> list = new ArrayList<>();
            for (Object item : page.getList()) {
                DataResourceMysqlModel dataResource = (DataResourceMysqlModel) item;
                Class<? extends DataResourceOutputModel> targetClass = null;

                switch (dataResource.getDataResourceType()) {
                    case BloomFilter:
                        targetClass = BloomFilterOutputModel.class;
                        break;
                    case ImageDataSet:
                        targetClass = ImageDataSetOutputModel.class;
                        break;
                    case TableDataSet:
                        targetClass = TableDataSetOutputModel.class;
                        break;
                    default:
                        StatusCode.UNEXPECTED_ENUM_CASE.throwException();
                }

                list.add(ModelMapper.map(item, targetClass));
            }

            return PagingOutput.of(page.getTotal(), list);
        }

        // Query resources of the single requested type.
        switch (input.getDataResourceType().get(0)) {
            case TableDataSet:
                return tableDataSetRepository.paging(
                        where
                                .equal("containsY", input.getContainsY())
                                .equal("derivedResource", false)
                                .build(TableDataSetMysqlModel.class),
                        input,
                        TableDataSetOutputModel.class
                );
            case ImageDataSet:
                return imageDataSetRepository.paging(
                        where
                                .equal("forJobType", input.getForJobType())
                                .build(ImageDataSetMysqlModel.class),
                        input,
                        ImageDataSetOutputModel.class
                );
            case BloomFilter:
                return bloomFilterRepository.paging(
                        where.build(BloomFilterMysqlModel.class),
                        input,
                        BloomFilterOutputModel.class
                );
            default:
                return null;
        }

    }
}
/*
 * Copyright 2021 Tianmian Tech. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.welab.wefe.board.service.service.data_resource;

import com.welab.wefe.board.service.api.data_resource.upload_task.DataResourceUploadTaskQueryApi;
import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.DataResourceUploadTaskMysqlModel;
import com.welab.wefe.board.service.database.repository.data_resource.DataResourceRepository;
import com.welab.wefe.board.service.database.repository.data_resource.DataResourceUploadTaskRepository;
import com.welab.wefe.board.service.dto.base.PagingOutput;
import com.welab.wefe.board.service.dto.entity.data_resource.output.DataResourceUploadTaskOutputModel;
import com.welab.wefe.board.service.dto.vo.data_resource.AbstractDataResourceUpdateInputModel;
import com.welab.wefe.board.service.service.AbstractService;
import com.welab.wefe.common.Convert;
import com.welab.wefe.common.TimeSpan;
import com.welab.wefe.common.data.mysql.Where;
import com.welab.wefe.common.util.DateUtil;
import com.welab.wefe.common.wefe.enums.DataResourceType;
import com.welab.wefe.common.wefe.enums.DataResourceUploadStatus;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.stereotype.Service;

import java.util.Date;
import java.util.function.Consumer;

/**
 * Tracks the lifecycle of a data-resource upload: task creation, progress
 * accounting, completion and failure. All task mutations are serialized on a
 * single lock because the storage pipeline updates tasks concurrently.
 *
 * @author lonnie
 */
@Service
public class DataResourceUploadTaskService extends AbstractService {

    // Serializes onError / updateProgress / complete, which may race on the same task row.
    private static final Object LOCKER = new Object();
    @Autowired
    protected DataResourceRepository dataResourceRepository;
    @Autowired
    private DataResourceUploadTaskRepository dataResourceUploadTaskRepository;

    /**
     * Create a new upload task for a data resource.
     */
    public DataResourceUploadTaskMysqlModel newTask(DataResourceType dataResourceType, AbstractDataResourceUpdateInputModel input) {

        DataResourceUploadTaskMysqlModel task = new DataResourceUploadTaskMysqlModel();
        task.setDataResourceName(input.getName());
        task.setProgressRatio(0);
        // A throwaway entity is instantiated only to borrow its freshly
        // generated id as the id of the resource this task will produce.
        task.setDataResourceId(new DataResourceMysqlModel().getId());
        task.setStatus(DataResourceUploadStatus.uploading);
        task.setDataResourceType(dataResourceType);
        dataResourceUploadTaskRepository.save(task);
        return task;
    }

    /**
     * Find the upload task of a data resource; null when none exists.
     */
    public DataResourceUploadTaskMysqlModel findByDataResourceId(String dataResource) {
        Specification<DataResourceUploadTaskMysqlModel> where = Where
                .create()
                .equal("dataResourceId", dataResource)
                .build(DataResourceUploadTaskMysqlModel.class);

        return dataResourceUploadTaskRepository.findOne(where).orElse(null);
    }

    /**
     * Update the task message before progress accounting starts.
     */
    public void updateMessageBeforeStart(String dataResourceId, String message) {
        updateProgress(dataResourceId, 0, 0, 0, message);
    }

    public void updateProgress(String dataResourceId, long totalDataRowCount, long completedDataCount, long invalidDataCount) {
        updateProgress(dataResourceId, totalDataRowCount, completedDataCount, invalidDataCount, null);
    }

    /**
     * Update upload progress.
     *
     * Since storing data sets into storage is a concurrent operation, onError,
     * updateProgress and complete may run simultaneously against the same task;
     * the lock keeps the update sequence consistent.
     */
    public void updateProgress(String dataResourceId, long totalDataRowCount, long completedDataCount, long invalidDataCount, String message) {
        synchronized (LOCKER) {
            DataResourceUploadTaskMysqlModel task = findByDataResourceId(dataResourceId);
            // BUGFIX: guard against a missing task row (onError already had this
            // null check; here it would previously throw a NullPointerException).
            if (task == null) {
                return;
            }

            // Finished tasks no longer accept progress updates.
            if (task.getStatus() != DataResourceUploadStatus.uploading) {
                return;
            }

            if (completedDataCount > totalDataRowCount) {
                completedDataCount = totalDataRowCount;
            }

            int progress = 0;
            if (totalDataRowCount > 0) {
                // Calculate progress
                progress = Convert.toInt(completedDataCount * 100L / totalDataRowCount);
            }

            if (completedDataCount > 0) {
                // When the early reading speed is slow, force progress++
                if (task.getProgressRatio() < 5
                        && completedDataCount < 10000
                        && completedDataCount > task.getCompletedDataCount()
                        && progress <= task.getProgressRatio()
                ) {
                    progress = task.getProgressRatio() + 1;
                }
            }

            // Avoid dividing by 0 in the remaining-time estimate below.
            if (progress == 0) {
                progress = 1;
            }

            // Because the data_set has not been updated yet. The progress cannot be set
            // to 100 temporarily, otherwise the front end will jump in advance.
            if (progress == 100) {
                progress = 99;
            }

            // Calculate estimated time
            long estimateTime = 0;
            if (progress < 100) {
                long spend = System.currentTimeMillis() - task.getCreatedTime().getTime();
                estimateTime = spend / progress * (100 - progress);
            }

            task.setTotalDataCount(totalDataRowCount);
            task.setInvalidDataCount(invalidDataCount);
            task.setCompletedDataCount(completedDataCount);
            task.setEstimateRemainingTime(estimateTime);
            task.setProgressRatio(progress);
            task.setErrorMessage(message);
            task.setUpdatedTime(new Date());

            dataResourceUploadTaskRepository.save(task);

            LOG.info("资源上传任务进度:" + task.getProgressRatio() + " , " + completedDataCount + "/" + totalDataRowCount);
        }
    }

    /**
     * Mark the upload as complete: progress forced to 100, status to completed.
     */
    public void complete(String dataResourceId) {
        synchronized (LOCKER) {
            DataResourceUploadTaskMysqlModel task = findByDataResourceId(dataResourceId);
            // BUGFIX: null guard, consistent with onError/updateProgress.
            if (task == null) {
                return;
            }
            task.setCompletedDataCount(task.getTotalDataCount());
            task.setEstimateRemainingTime(0);
            task.setProgressRatio(100);
            task.setUpdatedTime(new Date());
            task.setStatus(DataResourceUploadStatus.completed);
            task.setErrorMessage("已完成");
            dataResourceUploadTaskRepository.save(task);
        }
    }

    public DataResourceUploadTaskMysqlModel findById(String id) {
        return dataResourceUploadTaskRepository.findById(id).orElse(null);
    }

    /**
     * Apply an arbitrary mutation to a task and persist it; no-op on null.
     */
    public void update(DataResourceUploadTaskMysqlModel dataSetTask, Consumer<DataResourceUploadTaskMysqlModel> func) {
        if (dataSetTask == null) {
            return;
        }

        func.accept(dataSetTask);
        dataSetTask.setUpdatedTime(new Date());
        dataResourceUploadTaskRepository.save(dataSetTask);
    }

    /**
     * Page through tasks updated within the last 10 minutes.
     */
    public PagingOutput<DataResourceUploadTaskOutputModel> query(DataResourceUploadTaskQueryApi.Input input) {
        Specification<DataResourceUploadTaskMysqlModel> where = Where
                .create()
                .greaterThan("updatedTime", DateUtil.getDate(System.currentTimeMillis() - TimeSpan.fromMinute(10).toMs()))
                .build(DataResourceUploadTaskMysqlModel.class);

        return dataResourceUploadTaskRepository.paging(where, input, DataResourceUploadTaskOutputModel.class);
    }

    /**
     * Record a failure that occurred while saving the data set.
     */
    public void onError(String dataResourceId, Exception e) {
        synchronized (LOCKER) {
            DataResourceUploadTaskMysqlModel task = findByDataResourceId(dataResourceId);
            if (task == null) {
                return;
            }

            task.setErrorMessage(e.getMessage());
            task.setUpdatedTime(new Date());
            task.setStatus(DataResourceUploadStatus.failed);
            dataResourceUploadTaskRepository.save(task);
        }
    }
}
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.service.data_resource.add; + +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.board.service.database.entity.data_resource.BloomFilterMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.DataResourceUploadTaskMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel; +import com.welab.wefe.board.service.dto.vo.data_resource.AbstractDataResourceUpdateInputModel; +import com.welab.wefe.board.service.dto.vo.data_resource.DataResourceAddOutputModel; +import com.welab.wefe.board.service.service.AbstractService; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.board.service.service.DataSetStorageService; +import com.welab.wefe.board.service.service.ServiceCheckService; +import com.welab.wefe.board.service.service.checkpoint.StorageCheckpoint; +import com.welab.wefe.board.service.service.data_resource.DataResourceService; +import com.welab.wefe.board.service.service.data_resource.DataResourceUploadTaskService; +import com.welab.wefe.common.CommonThreadPool; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.util.FileUtil; +import 
com.welab.wefe.common.wefe.checkpoint.dto.ServiceCheckPointOutput; +import com.welab.wefe.common.wefe.enums.DataResourceStorageType; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import org.modelmapper.ModelMapper; +import org.springframework.beans.factory.annotation.Autowired; + +/** + * @author zane + * @date 2021/12/2 + */ +public abstract class AbstractDataResourceAddService extends AbstractService { + @Autowired + protected DataResourceUploadTaskService dataResourceUploadTaskService; + @Autowired + private DataResourceService dataResourceService; + @Autowired + protected DataSetStorageService dataSetStorageService; + @Autowired + private ServiceCheckService serviceCheckService; + + // region abstract method + + protected abstract void doAdd(AbstractDataResourceUpdateInputModel in, DataResourceUploadTaskMysqlModel task, DataResourceMysqlModel m) throws StatusCodeWithException; + + protected abstract Class getMysqlModelClass(); + + protected abstract DataResourceType getDataResourceType(); + + // endregion + + + /** + * 添加资源的公共方法 + * + * @return 资源Id + */ + public DataResourceAddOutputModel add(AbstractDataResourceUpdateInputModel input) throws StatusCodeWithException { + DataResourceType dataResourceType = getDataResourceType(); + DataResourceUploadTaskMysqlModel task = dataResourceUploadTaskService.newTask(dataResourceType, input); + + DataResourceMysqlModel model = new ModelMapper().map(input, getMysqlModelClass()); + model.setId(task.getDataResourceId()); + model.setCreatedBy(input); + model.setDataResourceType(dataResourceType); + model.setTags(dataResourceService.standardizeTags(input.getTags())); + dataResourceService.handlePublicMemberList(model); + checkAndSetStorageLocation(model); + + // 异步执行资源保存动作 + CommonThreadPool.run(() -> { + try { + doAdd(input, task, model); + unionService.upsertDataResource(model); + dataResourceUploadTaskService.complete(task.getDataResourceId()); + } catch (Exception e) { + 
LOG.error(e.getClass().getSimpleName() + " " + e.getMessage(), e); + dataResourceUploadTaskService.onError(model.getId(), e); + } + }); + + // Refresh the data set tag list + CacheObjects.refreshDataResourceTags(model.getDataResourceType()); + + return new DataResourceAddOutputModel(task.getDataResourceId(), task.getId()); + } + + @Autowired + private StorageCheckpoint storageCheckpoint; + + /** + * 检查并设置资源的存储位置 + */ + private void checkAndSetStorageLocation(DataResourceMysqlModel model) throws StatusCodeWithException { + Class mysqlModelClass = getMysqlModelClass(); + // table data set + if (mysqlModelClass == TableDataSetMysqlModel.class) { + ServiceCheckPointOutput availableInfo = storageCheckpoint.check(); + if (!availableInfo.isSuccess()) { + StatusCode + .DATABASE_LOST + .throwException("storage 服务访问失败:" + availableInfo.getMessage() + ",请检服务是否正常:" + config.getDbType()); + } + + model.setStorageType(DataResourceStorageType.StorageService); + model.setStorageNamespace(DataSetStorageService.DATABASE_NAME); + model.setStorageResourceName(dataSetStorageService.createRawDataSetTableName(model.getId())); + } + // image data set & bloom filter + else { + model.setStorageType(DataResourceStorageType.LocalFileSystem); + model.setStorageNamespace( + WeFeFileSystem + .getFileDir(model.getDataResourceType()) + .resolve(model.getId()) + .toAbsolutePath() + .toString() + ); + + // 生成的过滤器文件统一文件名 + if (mysqlModelClass == BloomFilterMysqlModel.class) { + model.setStorageResourceName("bloom_filter.data"); + } + + FileUtil.createDir(model.getStorageNamespace()); + } + + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/add/BloomFilterAddService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/add/BloomFilterAddService.java new file mode 100644 index 000000000..8a7f298bd --- /dev/null +++ 
/*
 * Copyright 2021 Tianmian Tech. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.welab.wefe.board.service.service.data_resource.add;


import com.welab.wefe.board.service.base.file_system.WeFeFileSystem;
import com.welab.wefe.board.service.constant.DataSetAddMethod;
import com.welab.wefe.board.service.database.entity.DataSourceMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.BloomFilterMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.DataResourceUploadTaskMysqlModel;
import com.welab.wefe.board.service.database.repository.data_resource.BloomFilterRepository;
import com.welab.wefe.board.service.dto.vo.data_resource.AbstractDataResourceUpdateInputModel;
import com.welab.wefe.board.service.dto.vo.data_resource.BloomFilterAddInputModel;
import com.welab.wefe.board.service.service.data_resource.DataResourceUploadTaskService;
import com.welab.wefe.board.service.service.data_resource.bloom_filter.BloomFilterColumnService;
import com.welab.wefe.board.service.service.data_resource.bloom_filter.BloomFilterService;
import com.welab.wefe.board.service.service.data_resource.bloom_filter.BloomFilterStorageService;
import com.welab.wefe.board.service.service.fusion.FieldInfoService;
import com.welab.wefe.board.service.util.*;
import com.welab.wefe.common.StatusCode;
import com.welab.wefe.common.exception.StatusCodeWithException;
import com.welab.wefe.common.wefe.enums.DataResourceType;
import org.apache.commons.io.FileUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.io.File;
import java.io.IOException;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;

/**
 * The service class for adding a bloom_filter data resource.
 *
 * <p>Persists the bloom_filter metadata, then parses the source data
 * (csv / excel / sql) and writes it into the filter file.</p>
 *
 * @author jacky.jiang
 */
@Service
public class BloomFilterAddService extends AbstractDataResourceAddService {

    @Autowired
    protected BloomFilterRepository bloomFilterRepository;
    @Autowired
    protected BloomFilterService bloomfilterService;
    @Autowired
    protected BloomFilterStorageService bloomfilterStorageService;
    @Autowired
    protected BloomFilterColumnService bloomfilterColumnService;
    @Autowired
    protected DataResourceUploadTaskService dataResourceUploadTaskService;
    @Autowired
    protected FieldInfoService fieldInfoService;


    /**
     * Add a bloom_filter: persist its metadata, parse the raw data into the
     * filter file, then clean up any file uploaded over HTTP.
     *
     * <p>Parse failures are reported on the upload task (via
     * {@code dataResourceUploadTaskService.onError}) instead of being re-thrown,
     * so the task record reflects the failure for the caller to inspect.</p>
     */
    @Override
    protected void doAdd(AbstractDataResourceUpdateInputModel in, DataResourceUploadTaskMysqlModel task, DataResourceMysqlModel m) throws StatusCodeWithException {
        BloomFilterAddInputModel input = (BloomFilterAddInputModel) in;
        BloomFilterMysqlModel model = (BloomFilterMysqlModel) m;

        String sourceFilePath = WeFeFileSystem.getFilePath(DataResourceType.BloomFilter, input.getFilename())
                .toAbsolutePath()
                .toString();

        model.setSourcePath(sourceFilePath);
        model.setDataSourceId(input.getDataSourceId());
        model.setHashFunction(input.getHashFunction());
        fieldInfoService.saveAll(model.getId(), input.getFieldInfoList());

        // save bloom_filter info to database
        // (the previous comment said "to file", but this persists the mysql model)
        model.setUpdatedTime(new Date());
        bloomFilterRepository.save(model);

        // Parse the original data and write it into the filter file.
        try {
            AbstractBloomFilterReader bloomfilterReader = createBloomfilterReader(input);
            readAllToFilterFile(model, bloomfilterReader, input.isDeduplication());
        } catch (Exception e) {
            LOG.error(e.getClass().getSimpleName() + " " + e.getMessage(), e);
            dataResourceUploadTaskService.onError(task.getId(), e);
            return;
        }

        // NOTE(review): saving the bloom_filter column info is intentionally disabled:
        // bloomfilterColumnService.update(model.getId(), input.getMetadataList());

        // Delete files uploaded by HttpUpload
        try {
            if (input.getBloomfilterAddMethod().equals(DataSetAddMethod.HttpUpload)) {
                File file = bloomfilterService.getBloomfilterFile(input.getBloomfilterAddMethod(), input.getFilename());
                FileUtils.deleteQuietly(file);
            }
        } catch (StatusCodeWithException e) {
            super.log(e);
        }

    }

    /**
     * Create the AbstractBloomFilterReader matching the add method
     * (database / http upload / local file).
     */
    private AbstractBloomFilterReader createBloomfilterReader(BloomFilterAddInputModel input) throws StatusCodeWithException {
        switch (input.getBloomfilterAddMethod()) {
            case Database:
                return createSqlBloomfilterReader(input);
            case HttpUpload:
            case LocalFile:
                return createFileBloomfilterReader(input);
            default:
                StatusCode
                        .UNEXPECTED_ENUM_CASE
                        .throwException("暂不支持的过滤器解析方式:" + input.getBloomfilterAddMethod());
        }

        return null;
    }

    /**
     * Create a CsvBloomFilterReader or ExcelBloomfilterReader,
     * chosen by the uploaded file's extension.
     */
    private AbstractBloomFilterReader createFileBloomfilterReader(BloomFilterAddInputModel input) throws StatusCodeWithException {
        try {
            File file = bloomfilterService.getBloomfilterFile(input.getBloomfilterAddMethod(), input.getFilename());
            boolean isCsv = file.getName().endsWith("csv");
            return isCsv
                    ? new CsvBloomFilterReader(input.getMetadataList(), file)
                    : new ExcelBloomfilterReader(input.getMetadataList(), file);

        } catch (IOException e) {
            StatusCode.FILE_IO_ERROR.throwException(e);
            return null;
        }
    }

    /**
     * Create a SqlBloomFilterReader backed by the data source referenced in the input.
     *
     * @throws StatusCodeWithException if the dataSourceId does not exist
     */
    private SqlBloomFilterReader createSqlBloomfilterReader(BloomFilterAddInputModel input) throws StatusCodeWithException {
        DataSourceMysqlModel dataSource = bloomfilterService.getDataSourceById(input.getDataSourceId());
        if (dataSource == null) {
            throw new StatusCodeWithException("此dataSourceId在数据库不存在", StatusCode.DATA_NOT_FOUND);
        }
        Connection conn = JdbcManager.getConnection(
                dataSource.getDatabaseType(),
                dataSource.getHost(),
                dataSource.getPort(),
                dataSource.getUserName(),
                dataSource.getPassword(),
                dataSource.getDatabaseName()
        );

        return new SqlBloomFilterReader(input.getMetadataList(), conn, input.getSql());
    }

    /**
     * Parse the bloom_filter source data and write it to the filter file.
     *
     * @param deduplication Do you need to de-duplicate the bloom_filter
     */
    private void readAllToFilterFile(BloomFilterMysqlModel model, AbstractBloomFilterReader bloomfilterReader, boolean deduplication) throws StatusCodeWithException {
        long start = System.currentTimeMillis();
        LOG.info("开始解析过滤器:" + model.getId());

        // update data set upload task info
        DataResourceUploadTaskMysqlModel uploadProgress = dataResourceUploadTaskService.findByDataResourceId(model.getId());
        dataResourceUploadTaskService.update(uploadProgress, x -> x.setTotalDataCount(bloomfilterReader.getTotalDataRowCount()));

        // Read the header before consuming rows. The returned list is not used here,
        // but the call presumably positions/caches the reader's header — TODO confirm
        // against AbstractBloomFilterReader before removing it.
        bloomfilterReader.getHeader();

        // data row consumption method
        BloomFilterAddServiceDataRowConsumer dataRowConsumer = new BloomFilterAddServiceDataRowConsumer(model, deduplication, bloomfilterReader);

        // read all data rows of the raw bloom_filter
        bloomfilterReader.readAll(dataRowConsumer);

        // wait for the consumption queue to finish
        dataRowConsumer.waitForFinishAndClose();

        LOG.info("过滤器解析完毕:" + model.getId() + " spend:" + ((System.currentTimeMillis() - start) / 1000) + "s");
    }


    @Override
    protected Class getMysqlModelClass() {
        return BloomFilterMysqlModel.class;
    }

    @Override
    protected DataResourceType getDataResourceType() {
        return DataResourceType.BloomFilter;
    }
}
/*
 * Copyright 2021 Tianmian Tech. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.welab.wefe.board.service.service.data_resource.add;

import com.welab.wefe.board.service.constant.Config;
import com.welab.wefe.board.service.database.entity.data_resource.BloomFilterMysqlModel;
import com.welab.wefe.board.service.database.repository.data_resource.BloomFilterRepository;
import com.welab.wefe.board.service.service.data_resource.DataResourceUploadTaskService;
import com.welab.wefe.board.service.service.data_resource.bloom_filter.BloomFilterStorageService;
import com.welab.wefe.board.service.service.fusion.FieldInfoService;
import com.welab.wefe.board.service.util.AbstractBloomFilterReader;
import com.welab.wefe.board.service.util.primarykey.FieldInfo;
import com.welab.wefe.board.service.util.primarykey.PrimaryKeyUtils;
import com.welab.wefe.board.service.util.unique.AbstractDataSetUniqueFilter;
import com.welab.wefe.board.service.util.unique.DataSetBloomUniqueFilter;
import com.welab.wefe.board.service.util.unique.DataSetMemoryUniqueFilter;
import com.welab.wefe.common.BatchConsumer;
import com.welab.wefe.common.exception.StatusCodeWithException;
import com.welab.wefe.common.util.JObject;
import com.welab.wefe.common.web.Launcher;
import com.welab.wefe.fusion.core.utils.CryptoUtils;
import com.welab.wefe.fusion.core.utils.PSIUtils;
import com.welab.wefe.fusion.core.utils.bf.BloomFilters;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.params.RSAKeyParameters;
import org.bouncycastle.crypto.params.RSAPrivateCrtKeyParameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Consumer;

/**
 * Consumes raw data rows while a bloom_filter source is being parsed:
 * each row's primary key is RSA-signed (CRT-accelerated) and added to an
 * in-memory bloom filter, which is flushed to the filter file per batch.
 *
 * <p>The RSA key pair is generated in the constructor and its components are
 * persisted on the {@link BloomFilterMysqlModel} so the filter can later be
 * used for PSI-style intersection.</p>
 *
 * @author jacky.jiang
 */
public class BloomFilterAddServiceDataRowConsumer implements Consumer<LinkedHashMap<String, Object>> {
    private final Logger LOG = LoggerFactory.getLogger(BloomFilterAddServiceDataRowConsumer.class);

    protected Config config;
    //region construction parameters
    private BloomFilterRepository bloomFilterRepository;

    /**
     * bloom_filter id
     */
    private String bloomfilterId;
    /**
     * Do you need to de-duplicate
     */
    private boolean deduplication;

    /**
     * RSA key pair generated at construction time (1024 bit).
     */
    private AsymmetricCipherKeyPair keyPair;

    /**
     * The bloom filter being built; sized from the source's total row count.
     */
    private BloomFilters bf;

    /**
     * RSA public key.
     */
    private RSAKeyParameters rsaPK;

    /**
     * Public exponent e.
     */
    private BigInteger rsaE;

    /**
     * Modulus n.
     */
    private BigInteger rsaN;

    /**
     * RSA private key (CRT form).
     */
    private RSAPrivateCrtKeyParameters rsaSK;

    /**
     * Private exponent d.
     */
    private BigInteger rsaD;

    /**
     * Prime factor p of n.
     */
    private BigInteger rsaP;

    /**
     * Prime factor q of n.
     */
    private BigInteger rsaQ;

    /**
     * CRT recombination coefficient for the mod-p residue: (q^-1 mod p) * q.
     */
    private BigInteger cp;

    /**
     * CRT recombination coefficient for the mod-q residue: (p^-1 mod q) * p.
     */
    private BigInteger cq;

    // Field definitions used to build each row's primary key.
    public List<FieldInfo> fieldInfoList;

    // Number of rows folded into the filter so far.
    private Integer processCount = 0;

    // Total row count of the source, used to size the bloom filter.
    private Integer totalDataCount = 0;

    // Path of the filter file on disk.
    private String bloomfilterPath;

    //endregion

    /**
     * To increase the writing speed, batch processing is used.
     */
    private BatchConsumer<LinkedHashMap<String, Object>> batchConsumer;
    private int maxBatchSize = 0;
    /**
     * deduplication filter
     */
    private AbstractDataSetUniqueFilter uniqueFilter;
    private BloomFilterStorageService bloomfilterStorageService;
    private DataResourceUploadTaskService dataResourceUploadTaskService;
    private AbstractBloomFilterReader bloomfilterReader;

    /**
     * The number of duplicate data in the primary key
     */
    private final LongAdder repeatDataCount = new LongAdder();

    /**
     * Generates the RSA key pair, precomputes the CRT coefficients, persists
     * the key material on the model, and sets up the batch consumer that
     * writes rows into the filter and reports upload progress.
     */
    public BloomFilterAddServiceDataRowConsumer(BloomFilterMysqlModel model, boolean deduplication, AbstractBloomFilterReader bloomfilterReader) throws StatusCodeWithException {
        this.bloomfilterId = model.getId();
        this.deduplication = deduplication;
        this.bloomfilterReader = bloomfilterReader;
        this.keyPair = CryptoUtils.generateKeys(1024);
        this.totalDataCount = (int) bloomfilterReader.getTotalDataRowCount();
        // false-positive rate 0.0001, capacity = total source rows
        this.bf = new BloomFilters(0.0001, totalDataCount);
        this.rsaPK = (RSAKeyParameters) keyPair.getPublic();
        this.rsaE = rsaPK.getExponent();
        this.rsaN = rsaPK.getModulus();
        this.rsaSK = (RSAPrivateCrtKeyParameters) keyPair.getPrivate();
        this.rsaD = rsaSK.getExponent();
        this.rsaP = rsaSK.getP();
        this.rsaQ = rsaSK.getQ();
        // CRT recombination coefficients, precomputed once for generateFilter()
        this.cp = rsaQ.modInverse(rsaP).multiply(rsaQ);
        this.cq = rsaP.modInverse(rsaQ).multiply(rsaP);

        if (deduplication) {
            this.uniqueFilter = createUniqueFilter(bloomfilterReader.getTotalDataRowCount());
        }
        // Beans are fetched from the Launcher context because this class is
        // constructed manually, not managed by Spring.
        this.config = Launcher.CONTEXT.getBean(Config.class);
        this.bloomfilterStorageService = Launcher.CONTEXT.getBean(BloomFilterStorageService.class);
        this.dataResourceUploadTaskService = Launcher.CONTEXT.getBean(DataResourceUploadTaskService.class);
        this.bloomFilterRepository = Launcher.CONTEXT.getBean(BloomFilterRepository.class);
        FieldInfoService service = Launcher.CONTEXT.getBean(FieldInfoService.class);
        this.fieldInfoList = service.fieldInfoList(bloomfilterId);
        File outFile = model.file();


        // Persist the RSA key components so the filter can be used later.
        model.setRsaD(this.rsaD.toString());
        model.setRsaN(this.rsaN.toString());
        model.setRsaE(this.rsaE.toString());
        model.setRsaP(this.rsaP.toString());
        model.setRsaQ(this.rsaQ.toString());
        model.setTotalDataCount(this.totalDataCount);
        this.bloomFilterRepository.save(model);

        this.bloomfilterPath = outFile.getPath();
        batchConsumer = new BatchConsumer<>(10, 1_000, rows -> {
            try {
                generateFilter(bloomfilterId, rows);
                // update bloom_filter upload progress
                dataResourceUploadTaskService.updateProgress(
                        bloomfilterId,
                        bloomfilterReader.getTotalDataRowCount(),
                        bloomfilterReader.getReadDataRows(),
                        getRepeatDataCount()
                );
            } catch (Exception e) {
                LOG.error(e.getMessage(), e);
                dataResourceUploadTaskService.onError(bloomfilterId, e);
            }

        });

    }

    /**
     * Sign each row's primary key with the RSA private key (CRT-optimized)
     * and add it to the bloom filter, then flush the filter to disk.
     *
     * <p>NOTE(review): the whole filter is re-written to {@code bloomfilterPath}
     * on every batch (the stream is opened without append), so the file always
     * holds the filter state as of the last completed batch.</p>
     */
    public void generateFilter(String bloomfilterId, List<LinkedHashMap<String, Object>> rows) throws IOException {

        for (LinkedHashMap<String, Object> data : rows) {
            try {
                String key = PrimaryKeyUtils.create(JObject.create(data), fieldInfoList);
                BigInteger h = PSIUtils.stringToBigInteger(key);
                // Encryption method before optimization:
                // BigInteger z = h.modPow(rsaD, rsaN);

                // CRT-optimized: compute h^d mod p and h^d mod q separately
                // (using d mod (p-1) and d mod (q-1)), then recombine mod n.
                BigInteger rp = h.modPow(rsaD.remainder(rsaP.subtract(BigInteger.valueOf(1))), rsaP);
                BigInteger rq = h.modPow(rsaD.remainder(rsaQ.subtract(BigInteger.valueOf(1))), rsaQ);

                BigInteger z = (rp.multiply(cp).add(rq.multiply(cq))).remainder(rsaN);

                this.bf.add(z);
                this.processCount = this.processCount + 1;
            } catch (Exception e) {
                // NOTE(review): a failed row is skipped silently (stack trace only).
                e.printStackTrace();
            }
        }

        FileOutputStream outputStream = new FileOutputStream(this.bloomfilterPath);
        this.bf.writeTo(outputStream);
        outputStream.close();

    }


    /**
     * Queue one raw data row for batched processing, adapting the batch size
     * as reading progresses.
     */
    @Override
    public void accept(LinkedHashMap<String, Object> row) {
        // In order to enable the upload progress bar to start as soon as possible,
        // the initial batch size is set to be smaller.
        if (bloomfilterReader.getReadDataRows() < 100) {
            batchConsumer.setMaxBatchSize(50);
        } else if (bloomfilterReader.getReadDataRows() < 1000) {
            batchConsumer.setMaxBatchSize(100);
        }
        // Later processing according to reasonable batch size
        else {
            // Update the batch size of batch write according to the number of columns
            if (this.maxBatchSize < 1) {
                this.maxBatchSize = bloomfilterStorageService.getAddBatchSize(row.size());
                batchConsumer.setMaxBatchSize(this.maxBatchSize);
            }
        }

        // Save the data row

        batchConsumer.add(row);

    }

    /**
     * Wait for the consumption queue to finish
     */
    public void waitForFinishAndClose() {
        batchConsumer.waitForFinishAndClose();
    }

    /**
     * The count of data duplicated by the primary key
     */
    public long getRepeatDataCount() {
        return repeatDataCount.longValue();
    }


    /**
     * Save data to storage and ensure that the data is not duplicated.
     *
     * <p>NOTE(review): the whole implementation is disabled; this method is
     * currently a no-op and is not called anywhere in this class.</p>
     */
    private void saveRowWithDeduplication(List row) {
//        String id = String.valueOf(row.get(0));
//
//        ContainResult containResult = uniqueFilter.contains(id);
//        while (true) {
//            switch (containResult) {
//                // Already exists: discard duplicate data
//                case In:
//                    repeatDataCount.increment();
//                    return;
//
//                // Does not exist, write
//                case NotIn:
//                    batchConsumer.add(row);
//                    return;
//
//                // Not sure: Wait for the data written in the queue to be written to confirm the query
//                case MaybeIn:
//                    // Waiting for all data in the queue to be written to storage
//                    batchConsumer.waitForClean();
//
//                    // Query in the storage to confirm whether it exists
//                    containResult = bloomfilterStorageService.containsKey(bloomfilterId, id)
//                            ? ContainResult.In
//                            : ContainResult.NotIn;
//                    continue;
//
//                default:
//                    return;
//            }
//        }
    }

    /**
     * Create a deduplication filter
     */
    private AbstractDataSetUniqueFilter createUniqueFilter(long totalDataRowCount) {

        // Use memory filters when the amount of data is small
        if (totalDataRowCount > 100_000) {
            return new DataSetBloomUniqueFilter(totalDataRowCount);
        } else {
            return new DataSetMemoryUniqueFilter();
        }
    }

}
/*
 * Copyright 2021 Tianmian Tech. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.welab.wefe.board.service.service.data_resource.add;

import com.welab.wefe.board.service.base.file_system.WeFeFileSystem;
import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.DataResourceUploadTaskMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.ImageDataSetMysqlModel;
import com.welab.wefe.board.service.database.entity.data_set.ImageDataSetSampleMysqlModel;
import com.welab.wefe.board.service.database.repository.ImageDataSetSampleRepository;
import com.welab.wefe.board.service.database.repository.data_resource.ImageDataSetRepository;
import com.welab.wefe.board.service.dto.vo.data_resource.AbstractDataResourceUpdateInputModel;
import com.welab.wefe.board.service.dto.vo.data_resource.ImageDataSetAddInputModel;
import com.welab.wefe.board.service.service.data_resource.image_data_set.data_set_parser.AbstractImageDataSetParser;
import com.welab.wefe.common.StatusCode;
import com.welab.wefe.common.exception.StatusCodeWithException;
import com.welab.wefe.common.file.decompression.SuperDecompressor;
import com.welab.wefe.common.file.decompression.dto.DecompressionResult;
import com.welab.wefe.common.util.FileUtil;
import com.welab.wefe.common.util.ListUtil;
import com.welab.wefe.common.util.StringUtil;
import com.welab.wefe.common.wefe.enums.DataResourceType;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.io.File;
import java.util.List;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * The service class for adding an image data set: decompresses the uploaded
 * archive, parses its files into samples, persists data set and samples,
 * then deletes the original archive and the decompressed directory.
 *
 * @author zane
 * @date 2021/12/2
 */
@Service
public class ImageDataSetAddService extends AbstractDataResourceAddService {

    @Autowired
    private ImageDataSetRepository imageDataSetRepository;
    @Autowired
    private ImageDataSetSampleRepository imageDataSetSampleRepository;

    /**
     * Add an image data set.
     *
     * <p>Order matters: decompress → parse samples → fill model statistics →
     * save the data set row → save samples in parallel → delete source files.
     * Any failure up to and including statistics filling is converted to a
     * FILE_IO_ERROR StatusCodeWithException; after that point the sample-save
     * loop only logs per-sample failures.</p>
     */
    @Override
    protected void doAdd(AbstractDataResourceUpdateInputModel in, DataResourceUploadTaskMysqlModel task, DataResourceMysqlModel m) throws StatusCodeWithException {

        LOG.info("{} 开始解析图片数据集文件...", m.getId());

        ImageDataSetAddInputModel input = (ImageDataSetAddInputModel) in;
        ImageDataSetMysqlModel model = (ImageDataSetMysqlModel) m;

        File inputFile = WeFeFileSystem.getFilePath(DataResourceType.ImageDataSet, input.getFilename()).toFile();
        LOG.info("{} 获取到图片数据集文件:{}", m.getId(), inputFile.getAbsolutePath());

        DecompressionResult fileDecompressionResult = null;
        List<ImageDataSetSampleMysqlModel> sampleList = null;
        try {
            dataResourceUploadTaskService.updateMessageBeforeStart(model.getId(), "解压中...");
            fileDecompressionResult = SuperDecompressor.decompression(inputFile, true);
            dataResourceUploadTaskService.updateMessageBeforeStart(model.getId(), "解压完成,正在解析样本...");
            LOG.info("{} 完成解压,包含文件 {} 个", m.getId(), fileDecompressionResult.files.size());

            // Pick the parser matching the job type (e.g. classify / detection)
            sampleList = AbstractImageDataSetParser
                    .getParser(input.forJobType)
                    .parseFilesToSamples(model, fileDecompressionResult.files);
            LOG.info("{} 完成样本解析,包含样本 {} 个", m.getId(), sampleList.size());
            dataResourceUploadTaskService.updateProgress(model.getId(), sampleList.size(), 1, 0, "已完成样本解析");

            setImageDataSetModel(input, model, sampleList);
            dataResourceUploadTaskService.updateProgress(model.getId(), sampleList.size(), 2, 0);
        } catch (Exception e) {
            // re-thrown as StatusCodeWithException, so sampleList is non-null below
            super.log(e);
            StatusCode.FILE_IO_ERROR.throwException(e);
        }

        // save models to database
        imageDataSetRepository.save(model);
        LOG.info("{} 数据集信息已入库,开始保存 {} 个样本信息。", m.getId(), sampleList.size());

        AtomicInteger count = new AtomicInteger();
        int totalCount = sampleList.size();

        // Save samples in parallel; progress is reported every 50 saved samples.
        ListUtil.parallelEach(
                sampleList,
                sample -> {
                    try {
                        imageDataSetSampleRepository.save(sample);
                        count.incrementAndGet();
                        if (count.get() % 50 == 0) {
                            dataResourceUploadTaskService.updateProgress(model.getId(), totalCount, count.get(), 0, "正在保存样本信息...");
                            LOG.info("{} 样本信息保存中,当前进度 {}/{}", m.getId(), count.get(), totalCount);
                        }
                    } catch (Exception e) {
                        // A failed sample is logged and skipped; the rest continue.
                        LOG.error(e.getMessage(), e);
                    }
                }
        );

        LOG.info("{} 样本保存完毕 {}/{}", m.getId(), count.get(), totalCount);

        // delete source images
        FileUtil.deleteFileOrDir(inputFile);
        LOG.info("{} 原始数据集文件已删除:{}", m.getId(), inputFile.getAbsolutePath());

        fileDecompressionResult.deleteAllDirAndFiles();
        LOG.info("{} 原始数据集解压后的文件夹已删除:{}", m.getId(), fileDecompressionResult.baseDir);
    }

    /**
     * Fill the data set model's statistics from the parsed samples:
     * distinct label list, total/labeled counts, completion flag, file size.
     */
    private void setImageDataSetModel(ImageDataSetAddInputModel input, ImageDataSetMysqlModel dataSet, List<ImageDataSetSampleMysqlModel> sampleList) {
        dataSet.setForJobType(input.forJobType);

        // distinct labels (TreeSet keeps them sorted and unique)
        TreeSet<String> labelSet = new TreeSet<>();
        sampleList
                .stream()
                .filter(x -> x.isLabeled())
                .forEach(x ->
                        labelSet.addAll(x.getLabelSet())
                );
        dataSet.setLabelList(
                StringUtil.joinByComma(labelSet)
        );

        dataSet.setTotalDataCount(sampleList.size());
        dataSet.setLabeledCount(
                sampleList.stream().filter(x -> x.isLabeled()).count()
        );

        // labeling is complete only when every sample carries a label
        dataSet.setLabelCompleted(
                sampleList.stream().allMatch(x -> x.isLabeled())
        );
        dataSet.setFilesSize(
                ListUtil.sumLong(sampleList, x -> x.getFileSize())
        );

    }

    @Override
    protected Class getMysqlModelClass() {
        return ImageDataSetMysqlModel.class;
    }

    @Override
    protected DataResourceType getDataResourceType() {
        return DataResourceType.ImageDataSet;
    }
}
/*
 * Copyright 2021 Tianmian Tech. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.welab.wefe.board.service.service.data_resource.add;

import com.welab.wefe.board.service.constant.DataSetAddMethod;
import com.welab.wefe.board.service.database.entity.DataSourceMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.DataResourceUploadTaskMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.TableDataSetMysqlModel;
import com.welab.wefe.board.service.database.repository.data_resource.TableDataSetRepository;
import com.welab.wefe.board.service.dto.vo.data_resource.AbstractDataResourceUpdateInputModel;
import com.welab.wefe.board.service.dto.vo.data_resource.TableDataSetAddInputModel;
import com.welab.wefe.board.service.service.DataSetColumnService;
import com.welab.wefe.board.service.service.DataSetStorageService;
import com.welab.wefe.board.service.service.data_resource.DataResourceUploadTaskService;
import com.welab.wefe.board.service.service.data_resource.table_data_set.TableDataSetService;
import com.welab.wefe.board.service.util.*;
import com.welab.wefe.common.StatusCode;
import com.welab.wefe.common.exception.StatusCodeWithException;
import com.welab.wefe.common.util.StringUtil;
import com.welab.wefe.common.wefe.enums.DataResourceType;
import org.apache.commons.io.FileUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.io.File;
import java.io.IOException;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * The service class for add data set
 *
 * @author Zane
 */
@Service
public class TableDataSetAddService extends AbstractDataResourceAddService {

    @Autowired
    protected TableDataSetRepository tableDataSetRepository;
    @Autowired
    protected TableDataSetService tableDataSetService;
    @Autowired
    protected DataSetStorageService dataSetStorageService;
    @Autowired
    protected DataSetColumnService dataSetColumnService;
    @Autowired
    protected DataResourceUploadTaskService dataResourceUploadTaskService;


    /**
     * Add a table data set: parse the source into storage, persist the
     * data set row and its column metadata, then delete the uploaded file.
     *
     * <p>On a parse failure, the uploaded file is kept only when the error is a
     * form error (ERROR_IN_DATA_RESOURCE_ADD_FORM), so the user can fix the
     * form and resubmit without re-uploading.</p>
     */
    @Override
    public void doAdd(AbstractDataResourceUpdateInputModel in, DataResourceUploadTaskMysqlModel task, DataResourceMysqlModel m) throws StatusCodeWithException {
        TableDataSetAddInputModel input = (TableDataSetAddInputModel) in;
        TableDataSetMysqlModel model = (TableDataSetMysqlModel) m;

        // Parse and save the original data set
        AbstractTableDataSetReader dataSetReader = createDataSetReader(input);
        try {
            readAllToStorage(model, dataSetReader, input.isDeduplication());
        } catch (Exception e) {
            // If it is a form error, the user can simply re-edit the form and
            // submit it again; the file does not need to be uploaded again.
            boolean isFormError = false;
            if (e instanceof StatusCodeWithException) {
                isFormError = ((StatusCodeWithException) e).getStatusCode().equals(StatusCode.ERROR_IN_DATA_RESOURCE_ADD_FORM);
            }
            if (!isFormError) {
                deleteFile(input);
            }

            throw e;
        }

        // save data set info to database
        tableDataSetRepository.save(model);

        // save data set column info to database
        dataSetColumnService.update(model.getId(), input.getMetadataList());

        // Delete files uploaded by HttpUpload
        deleteFile(input);

    }

    /**
     * Quietly delete the file uploaded via HttpUpload (no-op for other add methods).
     */
    private void deleteFile(TableDataSetAddInputModel input) {
        // Delete files uploaded by HttpUpload
        try {
            if (input.getDataSetAddMethod().equals(DataSetAddMethod.HttpUpload)) {
                File file = tableDataSetService.getDataSetFile(input.getDataSetAddMethod(), input.getFilename());
                FileUtils.deleteQuietly(file);
            }
        } catch (StatusCodeWithException e) {
            super.log(e);
        }
    }


    @Override
    protected Class getMysqlModelClass() {
        return TableDataSetMysqlModel.class;
    }

    @Override
    protected DataResourceType getDataResourceType() {
        return DataResourceType.TableDataSet;
    }

    /**
     * Create the AbstractTableDataSetReader matching the add method
     * (database / http upload / local file).
     */
    private AbstractTableDataSetReader createDataSetReader(TableDataSetAddInputModel input) throws StatusCodeWithException {
        switch (input.getDataSetAddMethod()) {
            case Database:
                return createSqlDataSetReader(input);
            case HttpUpload:
            case LocalFile:
                return createFileDataSetReader(input);
            default:
                StatusCode
                        .UNEXPECTED_ENUM_CASE
                        .throwException("暂不支持的数据集解析方式:" + input.getDataSetAddMethod());
        }

        return null;
    }

    /**
     * Create a CsvTableDataSetReader or ExcelTableDataSetReader,
     * chosen by the uploaded file's extension.
     */
    private AbstractTableDataSetReader createFileDataSetReader(TableDataSetAddInputModel input) throws StatusCodeWithException {
        try {
            File file = tableDataSetService.getDataSetFile(input.getDataSetAddMethod(), input.getFilename());
            boolean isCsv = file.getName().endsWith("csv");
            return isCsv
                    ? new CsvTableDataSetReader(input.getMetadataList(), file)
                    : new ExcelTableDataSetReader(input.getMetadataList(), file);

        } catch (IOException e) {
            StatusCode.FILE_IO_ERROR.throwException(e);
            return null;
        }
    }

    /**
     * Create a SqlTableDataSetReader backed by the data source referenced in the input.
     *
     * @throws StatusCodeWithException if the dataSourceId does not exist
     */
    private SqlTableDataSetReader createSqlDataSetReader(TableDataSetAddInputModel input) throws StatusCodeWithException {
        DataSourceMysqlModel dataSource = tableDataSetService.getDataSourceById(input.getDataSourceId());
        if (dataSource == null) {
            throw new StatusCodeWithException("此dataSourceId在数据库不存在", StatusCode.DATA_NOT_FOUND);
        }
        Connection conn = JdbcManager.getConnection(
                dataSource.getDatabaseType(),
                dataSource.getHost(),
                dataSource.getPort(),
                dataSource.getUserName(),
                dataSource.getPassword(),
                dataSource.getDatabaseName()
        );

        return new SqlTableDataSetReader(input.getMetadataList(), conn, input.getSql());
    }

    /**
     * Parse the dataset file and save it to lmdb/clickhouse
     *
     * @param deduplication Do you need to de-duplicate the data set
     */
    private void readAllToStorage(TableDataSetMysqlModel model, AbstractTableDataSetReader dataSetReader, boolean deduplication) throws StatusCodeWithException {
        long start = System.currentTimeMillis();
        LOG.info("开始解析数据集:" + model.getId());

        // update data set upload task info
        DataResourceUploadTaskMysqlModel uploadProgress = dataResourceUploadTaskService.findByDataResourceId(model.getId());
        dataResourceUploadTaskService.update(uploadProgress, x -> x.setTotalDataCount(dataSetReader.getTotalDataRowCount()));

        // get data set headers
        List<String> rawHeaders = dataSetReader.getHeader();
        // order headers (moves "y" to the second column)
        List<String> sortedHeaders = sortHeaders(rawHeaders);

        // save headers info to storage
        dataSetStorageService.saveHeaderRow(model.getId(), sortedHeaders);
        // data row consumption method
        TableDataSetAddServiceDataRowConsumer dataRowConsumer = new TableDataSetAddServiceDataRowConsumer(model.getId(), deduplication, dataSetReader);

        // read all data rows of the raw data set
        dataSetReader.readAll(dataRowConsumer);

        // wait for the consumption queue to finish
        dataRowConsumer.waitForFinishAndClose();

        LOG.info("数据集解析完毕:" + model.getId() + " spend:" + ((System.currentTimeMillis() - start) / 1000) + "s");

        // fill model statistics after the read completes
        model.setContainsY(dataSetReader.isContainsY());
        // first raw column is treated as the primary key
        model.setPrimaryKeyColumn(rawHeaders.get(0));
        model.setTotalDataCount(dataSetReader.getReadDataRows() - dataRowConsumer.getRepeatDataCount());
        model.setColumnCount(rawHeaders.size());
        model.setColumnNameList(StringUtil.join(sortedHeaders, ','));
        // feature count excludes the primary key, and "y" when present
        model.setFeatureCount(dataSetReader.isContainsY() ? rawHeaders.size() - 2 : rawHeaders.size() - 1);
        model.setFeatureNameList(StringUtil.join(rawHeaders.stream().filter(x -> !model.getPrimaryKeyColumn().equals(x) && !"y".equals(x)).collect(Collectors.toList()), ','));
        model.setyCount(dataSetReader.isContainsY() ? 1 : 0);
        model.setyNameList(dataSetReader.isContainsY() ? "y" : null);
        model.setyPositiveSampleCount(dataRowConsumer.getPositiveExampleCount());
        model.setyPositiveSampleRatio(dataRowConsumer.getPositiveExampleRatio());

    }

    /**
     * sort headers, move column y to the second column.
     */
    private List<String> sortHeaders(List<String> headers) {
        if (!headers.contains("y")) {
            return headers;
        }

        // A new list must be opened here, and the original column header cannot be modified.
        List<String> list = new ArrayList<>();
        for (String name : headers) {
            if ("y".equals(name)) {
                continue;
            }
            list.add(name);
        }
        list.add(1, "y");
        return list;
    }

}
+ */ + +package com.welab.wefe.board.service.service.data_resource.add; + + +import com.welab.wefe.board.service.service.DataSetStorageService; +import com.welab.wefe.board.service.service.data_resource.DataResourceUploadTaskService; +import com.welab.wefe.board.service.util.AbstractTableDataSetReader; +import com.welab.wefe.board.service.util.unique.AbstractDataSetUniqueFilter; +import com.welab.wefe.board.service.util.unique.ContainResult; +import com.welab.wefe.board.service.util.unique.DataSetBloomUniqueFilter; +import com.welab.wefe.board.service.util.unique.DataSetMemoryUniqueFilter; +import com.welab.wefe.common.BatchConsumer; +import com.welab.wefe.common.Validator; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.util.ListUtil; +import com.welab.wefe.common.util.Md5; +import com.welab.wefe.common.util.StringUtil; +import com.welab.wefe.common.web.Launcher; +import org.apache.commons.collections4.CollectionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.LongAdder; +import java.util.function.Consumer; + +/** + * @author zane.luo + */ +public class TableDataSetAddServiceDataRowConsumer implements Consumer> { + private final Logger LOG = LoggerFactory.getLogger(TableDataSetAddServiceDataRowConsumer.class); + + //region construction parameters + + /** + * data set id + */ + private final String dataSetId; + /** + * Do you need to de-duplicate + */ + private final boolean deduplication; + + //endregion + + /** + * To increase the writing speed, batch processing is used. 
+ */ + private final BatchConsumer> batchConsumer; + private int maxBatchSize = 0; + /** + * deduplication filter + */ + private AbstractDataSetUniqueFilter uniqueFilter; + private final DataSetStorageService dataSetStorageService; + private final DataResourceUploadTaskService dataResourceUploadTaskService; + private final AbstractTableDataSetReader dataSetReader; + + /** + * first column name in headers + */ + private final String firstColumnName; + /** + * is headers contains y column + */ + private final boolean containsY; + /** + * index of y in headers + */ + private final int yIndex; + + /** + * Number of positive cases + */ + private final AtomicLong yPositiveExampleCount = new AtomicLong(0); + + /** + * The number of duplicate data in the primary key + */ + private final LongAdder repeatDataCount = new LongAdder(); + + public TableDataSetAddServiceDataRowConsumer(String dataSetId, boolean deduplication, AbstractTableDataSetReader dataSetReader) throws StatusCodeWithException { + this.dataSetId = dataSetId; + this.deduplication = deduplication; + this.dataSetReader = dataSetReader; + + if (deduplication) { + uniqueFilter = createUniqueFilter(dataSetReader.getTotalDataRowCount()); + } + + List headers = dataSetReader.getHeader(); + this.firstColumnName = headers.get(0); + this.containsY = headers.contains("y"); + this.yIndex = headers.indexOf("y"); + + this.dataSetStorageService = Launcher.getBean(DataSetStorageService.class); + this.dataResourceUploadTaskService = Launcher.getBean(DataResourceUploadTaskService.class); + + batchConsumer = new BatchConsumer<>(10, 1_000, rows -> { + + try { + // save data row to storage + dataSetStorageService.saveDataRows(dataSetId, rows); + + // statistic positive rate + statisticPositiveExampleCount(this.containsY, this.yIndex, rows); + + // update data set upload progress + dataResourceUploadTaskService.updateProgress( + dataSetId, + dataSetReader.getTotalDataRowCount(), + dataSetReader.getReadDataRows(), + 
getRepeatDataCount() + ); + } catch (Exception e) { + LOG.error(e.getMessage(), e); + dataResourceUploadTaskService.onError(dataSetId, e); + } + + }); + + } + + + @Override + public void accept(LinkedHashMap row) { + + // In order to enable the upload progress bar to start as soon as possible, + // the initial batch size is set to be smaller. + if (dataSetReader.getReadDataRows() < 100) { + batchConsumer.setMaxBatchSize(50); + } else if (dataSetReader.getReadDataRows() < 1000) { + batchConsumer.setMaxBatchSize(100); + } + // Later processing according to reasonable batch size + else { + // Update the batch size of batch write according to the number of columns + if (this.maxBatchSize < 1) { + this.maxBatchSize = dataSetStorageService.getAddBatchSize(row.size()); + batchConsumer.setMaxBatchSize(this.maxBatchSize); + } + } + + + // Salt and hash the primary key + String id = String.valueOf(row.get(firstColumnName)); + id = Md5.of("hello" + id + "world"); + row.put(firstColumnName, id); + + + List values = new ArrayList<>(row.values()); + + // Move column y to the second column (the first column is the primary key) + if (containsY) { + moveY(values, values.get(yIndex)); + } + + // Save the data row + if (deduplication) { + saveRowWithDeduplication(values); + } else { + batchConsumer.add(values); + } + } + + /** + * Wait for the consumption queue to finish + */ + public void waitForFinishAndClose() { + batchConsumer.waitForFinishAndClose(); + } + + /** + * The count of data duplicated by the primary key + */ + public long getRepeatDataCount() { + return repeatDataCount.longValue(); + } + + /** + * Move column y to the second column (the first column is the primary key) + */ + private void moveY(List values, Object y) { + + if (!Validator.isInteger(y)) { + throw new RuntimeException( + "y 列必须为整数,数据集第 " + + dataSetReader.getReadDataRows() + + " 行附近发现非整数:" + + (StringUtil.isEmpty(String.valueOf(y)) ? 
"空" : y) + + ",请修正数据集后重试。" + ); + } + + ListUtil.moveElement(values, yIndex, 1); + } + + /** + * Save data to storage and ensure that the data is not duplicated + */ + private void saveRowWithDeduplication(List row) { + String id = String.valueOf(row.get(0)); + + ContainResult containResult = uniqueFilter.contains(id); + while (true) { + switch (containResult) { + // Already exists: discard duplicate data + case In: + repeatDataCount.increment(); + return; + + // Does not exist, write + case NotIn: + batchConsumer.add(row); + return; + + // Not sure: Wait for the data written in the queue to be written to confirm the query + case MaybeIn: + // Waiting for all data in the queue to be written to storage + batchConsumer.waitForClean(); + + // Query in the storage to confirm whether it exists + containResult = dataSetStorageService.containsKey(dataSetId, id) + ? ContainResult.In + : ContainResult.NotIn; + continue; + + default: + return; + } + } + } + + /** + * Create a deduplication filter + */ + private AbstractDataSetUniqueFilter createUniqueFilter(long totalDataRowCount) { + + // Use memory filters when the amount of data is small + if (totalDataRowCount > 100_000) { + return new DataSetBloomUniqueFilter(totalDataRowCount); + } else { + return new DataSetMemoryUniqueFilter(); + } + } + + /** + * Count the number of positive cases + */ + private void statisticPositiveExampleCount(boolean containsY, int yIndex, List> rows) { + if (!containsY || yIndex < 0 || CollectionUtils.isEmpty(rows)) { + return; + } + + // When it comes in, the y of row has moved to the position with index 1 + yIndex = 1; + for (List row : rows) { + Object value = row.get(yIndex); + if (null == value) { + continue; + } + String yValue = String.valueOf(value); + + if ("0".equals(yValue)) { + continue; + } + + yPositiveExampleCount.incrementAndGet(); + } + } + + /** + * Calculate the proportion of positive examples + */ + public double getPositiveExampleRatio() { + long totalCount = 
this.dataSetReader.getReadDataRows() - this.getRepeatDataCount(); + if (totalCount <= 0) { + return 0; + } + return new BigDecimal(this.yPositiveExampleCount.get()) + .divide(new BigDecimal(totalCount), 4, RoundingMode.HALF_UP) + .doubleValue(); + } + + /** + * Get the number of positive cases + */ + public long getPositiveExampleCount() { + return yPositiveExampleCount.get(); + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterColumnService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterColumnService.java new file mode 100644 index 000000000..bbc312179 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterColumnService.java @@ -0,0 +1,74 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.service.data_resource.bloom_filter; + +import com.welab.wefe.board.service.database.entity.fusion.bloomfilter.BloomFilterColumnMysqlModel; +import com.welab.wefe.board.service.database.repository.fusion.BloomFilterColumnRepository; +import com.welab.wefe.board.service.dto.base.PagingInput; +import com.welab.wefe.board.service.dto.base.PagingOutput; +import com.welab.wefe.board.service.dto.fusion.BloomFilterColumnInputModel; +import com.welab.wefe.board.service.dto.fusion.BloomFilterColumnOutputModel; +import com.welab.wefe.board.service.service.AbstractService; +import com.welab.wefe.common.data.mysql.Where; +import com.welab.wefe.common.data.mysql.enums.OrderBy; +import org.modelmapper.ModelMapper; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.jpa.domain.Specification; +import org.springframework.stereotype.Service; + +import java.util.List; + +/** + * @author jacky.jiang + */ +@Service +public class BloomFilterColumnService extends AbstractService { + @Autowired + BloomFilterColumnRepository bloomFilterColumnRepository; + + public PagingOutput list(String dataSetId) { + Specification where = Where + .create() + .equal("dataSetId", dataSetId) + .orderBy("index", OrderBy.asc) + .build(BloomFilterColumnMysqlModel.class); + + // The front end does not do paging, + // but considering that there may be a data set with a particularly large number of fields, + // the paging method is used to query here. 
+ return bloomFilterColumnRepository.paging( + where, + new PagingInput(0, 10000), BloomFilterColumnOutputModel.class + ); + } + + public void update(String dataSetId, List list) { + // clear data set columns + bloomFilterColumnRepository.deleteByBloomFilterId(dataSetId); + + // save data set columns + for (int i = 0; i < list.size(); i++) { + BloomFilterColumnInputModel item = list.get(i); + + BloomFilterColumnMysqlModel column = new ModelMapper().map(item, BloomFilterColumnMysqlModel.class); + column.setBloomFilterId(dataSetId); + column.setIndex(i); + + bloomFilterColumnRepository.save(column); + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterService.java new file mode 100644 index 000000000..ab3273d3c --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterService.java @@ -0,0 +1,280 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.welab.wefe.board.service.service.data_resource.bloom_filter; + +import com.welab.wefe.board.service.api.data_resource.bloom_filter.BloomFilterDataResourceListApi; +import com.welab.wefe.board.service.api.data_resource.bloom_filter.BloomFilterDeleteApi; +import com.welab.wefe.board.service.base.file_system.WeFeFileSystem; +import com.welab.wefe.board.service.constant.BloomfilterAddMethod; +import com.welab.wefe.board.service.database.entity.DataSourceMysqlModel; +import com.welab.wefe.board.service.database.entity.data_resource.BloomFilterMysqlModel; +import com.welab.wefe.board.service.database.entity.job.ProjectMemberMySqlModel; +import com.welab.wefe.board.service.database.entity.job.ProjectMySqlModel; +import com.welab.wefe.board.service.database.repository.DataSourceRepository; +import com.welab.wefe.board.service.database.repository.JobMemberRepository; +import com.welab.wefe.board.service.database.repository.JobRepository; +import com.welab.wefe.board.service.database.repository.ProjectRepository; +import com.welab.wefe.board.service.database.repository.base.RepositoryManager; +import com.welab.wefe.board.service.database.repository.data_resource.BloomFilterRepository; +import com.welab.wefe.board.service.dto.entity.BloomFilterDataResourceListOutputModel; +import com.welab.wefe.board.service.dto.entity.data_resource.output.BloomFilterOutputModel; +import com.welab.wefe.board.service.dto.entity.project.ProjectDetailMemberOutputModel; +import com.welab.wefe.board.service.dto.entity.project.data_set.ProjectDataResourceOutputModel; +import com.welab.wefe.board.service.dto.vo.data_resource.BloomFilterUpdateInputModel; +import com.welab.wefe.board.service.onlinedemo.OnlineDemoBranchStrategy; +import com.welab.wefe.board.service.service.CacheObjects; +import com.welab.wefe.board.service.service.ProjectDataSetService; +import com.welab.wefe.board.service.service.ProjectMemberService; +import 
com.welab.wefe.board.service.service.data_resource.DataResourceService; +import com.welab.wefe.board.service.util.JdbcManager; +import com.welab.wefe.common.StatusCode; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.web.util.ModelMapper; +import com.welab.wefe.common.wefe.enums.DataResourceType; +import com.welab.wefe.common.wefe.enums.DataResourcePublicLevel; +import com.welab.wefe.common.wefe.enums.JobMemberRole; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import java.io.File; +import java.sql.Connection; +import java.util.List; +import java.util.stream.Collectors; + +/** + * @author jacky.jiang + */ +@Service +public class BloomFilterService extends DataResourceService { + + @Autowired + protected BloomFilterRepository repo; + @Autowired + protected BloomFilterStorageService bloomfilterStorageService; + @Autowired + protected JobRepository jobRepository; + @Autowired + protected JobMemberRepository jobMemberRepository; + @Autowired + protected JobRepository featureJobRepository; + @Autowired + DataSourceRepository dataSourceRepo; + @Autowired + private ProjectRepository projectRepo; + @Autowired + private ProjectMemberService projectMemberService; + @Autowired + private ProjectDataSetService projectDataSetService; + + + public BloomFilterOutputModel findDataSetFromLocalOrUnion(String memberId, String dataSetId) throws StatusCodeWithException { + + if (memberId.equals(CacheObjects.getMemberId())) { + BloomFilterMysqlModel dataSet = repo.findById(dataSetId).orElse(null); + if (dataSet == null) { + return null; + } + return ModelMapper.map(dataSet, BloomFilterOutputModel.class); + } else { + return unionService.getDataResourceDetail(dataSetId, BloomFilterOutputModel.class); + } + } + + /** + * Get uploaded file + */ + public File getBloomfilterFile(BloomfilterAddMethod method, String filename) throws 
StatusCodeWithException { + File file = null; + switch (method) { + case HttpUpload: + file = WeFeFileSystem.getFilePath(DataResourceType.BloomFilter, filename).toFile(); + break; + case LocalFile: + file = new File(filename); + break; + case Database: + break; + default: + } + + if (null == file || !file.exists()) { + throw new StatusCodeWithException("未找到文件:" + filename, StatusCode.PARAMETER_VALUE_INVALID); + } + + return file; + } + + /** + * delete bloom_filter + */ + public void delete(BloomFilterDeleteApi.Input input) throws StatusCodeWithException { + BloomFilterMysqlModel model = repo.findById(input.getId()).orElse(null); + if (model == null) { + return; + } + + OnlineDemoBranchStrategy.hackOnDelete(input, model, "只能删除自己添加的数据集。"); + + delete(model); + } + + /** + * delete bloom_filter + */ + public void delete(String bloomFilterId) throws StatusCodeWithException { + BloomFilterMysqlModel model = repo.findById(bloomFilterId).orElse(null); + if (model == null) { + return; + } + + delete(model); + } + + /** + * delete bloom_filter + */ + public void delete(BloomFilterMysqlModel model) throws StatusCodeWithException { + + // delete bloom_filter from database + repo.deleteById(model.getId()); + + // delete bloom_filter from folder + bloomfilterStorageService.deleteBloomfilter(model.getId()); + + // Notify the union to do not public the bloom_filter + unionService.deleteDataResource(model); + + // Refresh the bloom_filter tag list + CacheObjects.refreshDataResourceTags(model.getDataResourceType()); + + + } + + /** + * Process the list of visible members + *

+ * When the scene is visible to the specified members, automatically add itself is also visible. + */ + public void handlePublicMemberList(BloomFilterMysqlModel model) { + + // When the PublicLevel is PublicWithMemberList, if list contains yourself, + // you will be removed, and union will handle the data that you must be visible. + if (model.getPublicLevel() == DataResourcePublicLevel.PublicWithMemberList) { + String memberId = CacheObjects.getMemberId(); + + + if (model.getPublicMemberList().contains(memberId)) { + String list = model.getPublicMemberList() + .replace(memberId, "") + .replace(",,", ","); + + model.setPublicMemberList(list); + } + } + + } + + + /** + * get data source by id + */ + public DataSourceMysqlModel getDataSourceById(String dataSourceId) { + return dataSourceRepo.findById(dataSourceId).orElse(null); + } + + + public BloomFilterMysqlModel findOne(String bloomFilterId) { + return repo.findById(bloomFilterId).orElse(null); + } + + /** + * Test whether SQL can be queried normally + */ + public boolean testSqlQuery(String dataSourceId, String sql) throws StatusCodeWithException { + DataSourceMysqlModel model = getDataSourceById(dataSourceId); + if (model == null) { + throw new StatusCodeWithException("dataSourceId在数据库不存在", StatusCode.DATA_NOT_FOUND); + } + + if (StringUtils.isEmpty(sql)) { + throw new StatusCodeWithException("请填入sql查询语句", StatusCode.PARAMETER_CAN_NOT_BE_EMPTY); + } + + Connection conn = JdbcManager.getConnection( + model.getDatabaseType(), + model.getHost(), + model.getPort(), + model.getUserName(), + model.getPassword(), + model.getDatabaseName() + ); + + return JdbcManager.testQuery(conn, sql, true); + } + + + public BloomFilterDataResourceListOutputModel query(BloomFilterDataResourceListApi.Input input) throws StatusCodeWithException { + ProjectMySqlModel project = projectRepo.findOne("projectId", input.getProjectId(), ProjectMySqlModel.class); + if (project == null) { + throw new StatusCodeWithException("未找到相应的项目!", 
StatusCode.ILLEGAL_REQUEST); + } + + BloomFilterDataResourceListOutputModel output = ModelMapper.map(project, BloomFilterDataResourceListOutputModel.class); + + + ProjectMemberMySqlModel memberMySqlModel = projectMemberService.findOneByMemberId(input.getProjectId(), input.getMemberId(), input.getRole()); + ProjectDetailMemberOutputModel memberOutputModel = ModelMapper.map(memberMySqlModel, ProjectDetailMemberOutputModel.class); + + List allDataSetList = projectDataSetService.listRawDataSet(input.getProjectId(), null, input.getMemberId(), input.getRole(), null); + memberOutputModel.setDataResourceList( + allDataSetList.stream().filter( + x -> StringUtils.isEmpty(input.getName()) ? true : x.getDataResource().getName().contains(input.getName())) + .collect(Collectors.toList()) + ); + + + output.setDataSetList(memberOutputModel.getDataResourceList()); + + return output; + } + + /** + * update bloom filter info + */ + public void update(BloomFilterUpdateInputModel input) throws StatusCodeWithException { + BloomFilterMysqlModel model = findOne(input.getId()); + if (model == null) { + return; + } + + model.setUpdatedBy(input); + model.setName(input.getName()); + model.setDescription(input.getDescription()); + model.setPublicMemberList(input.getPublicMemberList()); + model.setPublicLevel(input.getPublicLevel()); + model.setTags(standardizeTags(input.getTags())); + handlePublicMemberList(model); + + RepositoryManager.get(model.getClass()).save(model); + + + unionService.upsertDataResource(model); + CacheObjects.refreshDataResourceTags(model.getDataResourceType()); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterStorageService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterStorageService.java new file mode 100644 index 000000000..53cc80557 --- /dev/null +++ 
b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterStorageService.java @@ -0,0 +1,221 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.service.data_resource.bloom_filter; + +import com.alibaba.fastjson.JSON; +import com.welab.wefe.board.service.service.AbstractService; +import com.welab.wefe.common.data.storage.common.Constant; +import com.welab.wefe.common.data.storage.model.DataItemModel; +import com.welab.wefe.common.data.storage.model.PageInputModel; +import com.welab.wefe.common.data.storage.model.PageOutputModel; +import com.welab.wefe.common.data.storage.service.StorageService; +import com.welab.wefe.common.util.StringUtil; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; + +/** + * bloom_filter storage service read and write class + *

+ * + * @author jacky.jiang + */ +@Service +public class BloomFilterStorageService extends AbstractService { + public static final String DATABASE_NAME = Constant.DBName.WEFE_DATA; + + @Autowired + StorageService storageService; + + /** + * Determine whether the specified key exists + */ + public boolean containsKey(String dataSetId, String key) { + String table = createRawBloomfilterTableName(dataSetId); + boolean contains = storageService.getByKey(DATABASE_NAME, table, key) != null; + return contains; + } + + /** + * remove bloom_filter from storage + */ + public void deleteBloomfilter(String bloomfilterId) { + String table = createRawBloomfilterTableName(bloomfilterId); + storageService.dropTB(DATABASE_NAME, table); + } + + /** + * save data set header info to storage + */ + public void saveHeaderRow(String dataSetId, List row) { + String sid = null; + List header = new ArrayList<>(); + + for (String item : row) { + if (sid == null) { + sid = String.valueOf(item); + } else { + header.add(String.valueOf(item)); + } + } + + String tableName = createRawBloomfilterTableName(dataSetId) + ".meta"; + + // According to the convention, + // sid needs to be converted to json string so that double quotation marks are added before and after. + sid = JSON.toJSONString(sid); + save(tableName, "sid", sid); + + // According to the convention, + // the header needs to be converted to json string + // so that double quotation marks are added before and after it. 
+ String headerRow = JSON.toJSONString(StringUtil.join(header, ",")); + save(tableName, "header", headerRow); + + } + + /** + * save data row to storage + */ + public void saveDataRow(String dataSetId, Collection values) { + save(createRawBloomfilterTableName(dataSetId), buildDataItemModel(values)); + } + + /** + * save data rows to storage + */ + public void saveDataRows(String bloomfilterId, List> rows) { + + List> list = rows + .stream() + .map(x -> buildDataItemModel(x)) + .collect(Collectors.toList()); + + saveList(createRawBloomfilterTableName(bloomfilterId), list); + } + + /** + * Convert the data rows in the dataset to DataItemModel + */ + private DataItemModel buildDataItemModel(Collection values) { + String key = null; + List list = new ArrayList<>(); + + for (Object item : values) { + if (key == null) { + key = String.valueOf(item); + } else { + list.add(String.valueOf(item)); + } + } + return new DataItemModel<>(key, StringUtil.join(list, ",")); + } + + + /** + * view the bloom_filter data rows + */ + public List> previewBloomfilter(String dbName, String tableName, int limit) { + PageOutputModel page = storageService.getPage(dbName, tableName, new PageInputModel(0, limit)); + + List> data = page.getData(); + return data + .stream() + .map(x -> { + List list = new ArrayList<>(); + list.add(String.valueOf(x.getK())); + + Object value = x.getV(); + if (value != null) { + for (String item : String.valueOf(value).split(",")) { + list.add(item); + } + } + + return list; + }) + .collect(Collectors.toList()); + } + + /** + * save a record to storage + */ + private void save(String tableName, String key, String value) { + storageService.save(DATABASE_NAME, tableName, new DataItemModel<>(key, value)); + } + + /** + * save a record to storage + */ + private void save(String tableName, DataItemModel item) { + storageService.save(DATABASE_NAME, tableName, item); + } + + /** + * save multi records to storage + */ + public void saveList(String tableName, List> list) { 
+ storageService.saveList(DATABASE_NAME, tableName, list); + } + + /** + * read by pagination + */ + public PageOutputModel getListByPage(String namespace, String tableName, PageInputModel inputModel) { + return storageService.getPage(namespace, tableName, inputModel); + } + + /** + * real all record from storage table + */ + public List getList(String tableName) { + return storageService.getList(DATABASE_NAME, tableName); + } + + /** + * Generate the raw bloom_filter table name + */ + public String createRawBloomfilterTableName(String bloomfilterId) { + return "blommfilter_" + bloomfilterId; + } + + /** + * Get row count of table + */ + public int count(String tableName) { + return storageService.count(DATABASE_NAME, tableName); + } + + /** + * Get row count of table + */ + public int count(String databaseName, String tableName) { + return storageService.count(databaseName, tableName); + } + + /** + * Calculate the appropriate batch size based on the number of columns in the data set + */ + public int getAddBatchSize(int columns) { + return storageService.getAddBatchSize(columns); + } + +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterTaskService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterTaskService.java new file mode 100644 index 000000000..059260510 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/bloom_filter/BloomFilterTaskService.java @@ -0,0 +1,163 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.welab.wefe.board.service.service.data_resource.bloom_filter; + + +import com.welab.wefe.board.service.database.entity.fusion.bloomfilter.BloomFilterTaskMysqlModel; +import com.welab.wefe.board.service.database.repository.data_resource.BloomFilterRepository; +import com.welab.wefe.board.service.database.repository.fusion.BloomFilterTaskRepository; +import com.welab.wefe.board.service.service.AbstractService; +import com.welab.wefe.common.Convert; +import com.welab.wefe.common.data.mysql.Where; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.jpa.domain.Specification; +import org.springframework.stereotype.Service; + +import java.util.Date; +import java.util.function.Consumer; + + +/** + * @author jacky.jiang + */ +@Service +public class BloomFilterTaskService extends AbstractService { + + private static final Object LOCKER = new Object(); + @Autowired + protected BloomFilterRepository bloomfilterRepository; + @Autowired + private BloomFilterTaskRepository bloomfilterTaskRepository; + + + public BloomFilterTaskMysqlModel findByBloomfilterId(String bloomFilterId) { + Specification where = Where + .create() + .equal("bloomFilterId", bloomFilterId) + .build(BloomFilterTaskMysqlModel.class); + + return bloomfilterTaskRepository.findOne(where).orElse(null); + } + + /** + * Update upload progress + */ + public void updateProgress(String bloomfilterId, long totalDataRowCount, long readedDataRows, long repeatDataCount) { + // Since storing bloomfilters into storage is a concurrent operation, 
onerror, updateprogress, complete and other operations may occur simultaneously to update the same task. + // In order to avoid disordered update sequence, lock operation is required here. + synchronized (LOCKER) { + BloomFilterTaskMysqlModel bloomfilterTask = findByBloomfilterId(bloomfilterId); + + // Calculate progress + int progress = Convert.toInt(readedDataRows * 100L / totalDataRowCount); + + // When the early reading speed is slow, force progress++ + if (bloomfilterTask.getProgress() < 5 + && readedDataRows < 10000 + && readedDataRows > bloomfilterTask.getAddedRowCount() + && progress <= bloomfilterTask.getProgress() + ) { + progress = bloomfilterTask.getProgress() + 1; + } + + // Avoid dividing by 0 + if (progress == 0) { + progress = 1; + } + + // Because the bloom_filter has not been updated yet. The progress cannot be set to 100 temporarily, otherwise the front end will jump in advance. + if (progress == 100) { + progress = 99; + } + + // Calculate estimated time + long estimateTime = 0; + if (progress < 100) { + long spend = System.currentTimeMillis() - bloomfilterTask.getCreatedTime().getTime(); + estimateTime = spend / progress * (100 - progress); + } + + bloomfilterTask.setRepeatIdRowCount(repeatDataCount); + bloomfilterTask.setAddedRowCount(readedDataRows); + bloomfilterTask.setEstimateTime(estimateTime); + bloomfilterTask.setProgress(progress); + bloomfilterTask.setUpdatedTime(new Date()); + + bloomfilterTaskRepository.save(bloomfilterTask); + + LOG.info("过滤器任务进度:" + bloomfilterTask.getProgress() + " , " + readedDataRows + "/" + totalDataRowCount); + } + } + + /** + * Upload complete + */ + public void complete(String bloomFilterId) { + synchronized (LOCKER) { + BloomFilterTaskMysqlModel bloomfilterTask = findByBloomfilterId(bloomFilterId); + bloomfilterTask.setAddedRowCount(bloomfilterTask.getTotalRowCount()); + bloomfilterTask.setEstimateTime(0); + bloomfilterTask.setProgress(100); + bloomfilterTask.setUpdatedTime(new Date()); + + 
bloomfilterTaskRepository.save(bloomfilterTask); + } + } + + public BloomFilterTaskMysqlModel findById(String id) { + return bloomfilterTaskRepository.findById(id).orElse(null); + } + + public void update(BloomFilterTaskMysqlModel bloomfilterTask, Consumer func) { + if (bloomfilterTask == null) { + return; + } + + func.accept(bloomfilterTask); + bloomfilterTask.setUpdatedTime(new Date()); + bloomfilterTaskRepository.save(bloomfilterTask); + } + +// public PagingOutput query(QueryApi.Input input) { +// Specification where = Where +// .create() +// .greaterThan("progress", 0) +// .lessThan("progress", 100) +// .greaterThan("updatedTime", DateUtil.getDate(System.currentTimeMillis() - TimeSpan.fromMinute(10).toMs())) +// .build(BloomFilterTaskMysqlModel.class); +// +// return bloomfilterTaskRepository.paging(where, input, BloomFilterTaskOutputModel.class); +// } + + /** + * An exception occurred while saving the bloom_filter + */ + public void onError(String bloomfilterId, Exception e) { + synchronized (LOCKER) { + BloomFilterTaskMysqlModel bloomfilterTask = findByBloomfilterId(bloomfilterId); + if (bloomfilterTask == null) { + return; + } + + bloomfilterTask = findByBloomfilterId(bloomfilterTask.getBloomFilterId()); + bloomfilterTask.setErrorMessage(e.getMessage()); + bloomfilterTask.setUpdatedTime(new Date()); + + bloomfilterTaskRepository.save(bloomfilterTask); + } + } +} diff --git a/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/image_data_set/ImageDataSetSampleService.java b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/image_data_set/ImageDataSetSampleService.java new file mode 100644 index 000000000..b039c3636 --- /dev/null +++ b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/image_data_set/ImageDataSetSampleService.java @@ -0,0 +1,138 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. 
/*
 * Copyright 2021 Tianmian Tech. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.welab.wefe.board.service.service.data_resource.image_data_set;

import com.welab.wefe.board.service.api.data_resource.image_data_set.sample.ImageDataSetSampleQueryApi;
import com.welab.wefe.board.service.api.data_resource.image_data_set.sample.ImageDataSetSampleStatisticsApi;
import com.welab.wefe.board.service.api.data_resource.image_data_set.sample.ImageDataSetSampleUpdateApi;
import com.welab.wefe.board.service.database.entity.data_set.ImageDataSetSampleMysqlModel;
import com.welab.wefe.board.service.database.repository.ImageDataSetSampleRepository;
import com.welab.wefe.board.service.dto.base.PagingOutput;
import com.welab.wefe.board.service.dto.entity.data_set.ImageDataSetSampleOutputModel;
import com.welab.wefe.board.service.service.AbstractService;
import com.welab.wefe.common.StatusCode;
import com.welab.wefe.common.data.mysql.Where;
import com.welab.wefe.common.exception.StatusCodeWithException;
import com.welab.wefe.common.util.FileUtil;
import com.welab.wefe.common.util.JObject;
import com.welab.wefe.common.util.MapUtil;
import com.welab.wefe.common.util.StringUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

/**
 * Service for managing the samples of an image data set:
 * querying, updating label info, deleting, and computing label statistics.
 *
 * NOTE(review): generic type parameters were restored throughout this class;
 * the checked-in source used raw types (List, Map, Specification, ...).
 *
 * @author zane
 * @date 2021/11/10
 */
@Service
public class ImageDataSetSampleService extends AbstractService {
    @Autowired
    private ImageDataSetSampleRepository imageDataSetSampleRepository;
    @Autowired
    private ImageDataSetService imageDataSetService;

    /**
     * Get all labeled samples of the given data set.
     *
     * @param dataSetId id of the image data set
     * @return all samples whose {@code labeled} flag is true
     */
    public List<ImageDataSetSampleMysqlModel> allLabeled(String dataSetId) {
        Specification<ImageDataSetSampleMysqlModel> where = Where
                .create()
                .equal("dataSetId", dataSetId)
                .equal("labeled", true)
                .build(ImageDataSetSampleMysqlModel.class);

        return imageDataSetSampleRepository.findAll(where);
    }

    /**
     * Paged query of samples, optionally filtered by label.
     *
     * When {@code labelMatchWithContains} is false the label is wrapped in
     * commas so that it matches exactly one entry of the comma-joined
     * {@code labelList} column instead of any substring.
     */
    public PagingOutput<ImageDataSetSampleOutputModel> query(ImageDataSetSampleQueryApi.Input input) {

        Where where = Where
                .create()
                .equal("dataSetId", input.getDataSetId())
                .equal("labeled", input.getLabeled());

        if (StringUtil.isNotEmpty(input.getLabel())) {
            if (input.labelMatchWithContains) {
                where.contains("labelList", input.getLabel());
            } else {
                // Wrap with commas for an exact single-label match.
                where.contains("labelList", "," + input.getLabel() + ",");
            }
        }

        return imageDataSetSampleRepository.paging(
                where.build(ImageDataSetSampleMysqlModel.class),
                input,
                ImageDataSetSampleOutputModel.class
        );
    }

    /**
     * Update a sample's label info, then refresh the data set's aggregated
     * label statistics.
     *
     * @throws StatusCodeWithException if no sample exists for the given id
     */
    public void update(ImageDataSetSampleUpdateApi.Input input) throws StatusCodeWithException {
        ImageDataSetSampleMysqlModel sample = imageDataSetSampleRepository.findById(input.id).orElse(null);
        if (sample == null) {
            StatusCode.PARAMETER_VALUE_INVALID.throwException("id 对应的样本不存在:" + input.id);
        }
        sample.setLabeled(input.labelInfo.isLabeled());
        sample.setLabelInfo(JObject.create(input.labelInfo));
        sample.setLabelList(StringUtil.joinByComma(input.labelInfo.labelList()));
        sample.setUpdatedBy(input);

        imageDataSetSampleRepository.save(sample);

        imageDataSetService.updateLabelInfo(sample.getDataSetId());
    }

    /**
     * Delete a sample (DB record, aggregated label info, and image file on disk).
     *
     * @throws StatusCodeWithException if no sample exists for the given id
     */
    public void delete(String id) throws StatusCodeWithException {
        ImageDataSetSampleMysqlModel sample = imageDataSetSampleRepository.findById(id).orElse(null);
        if (sample == null) {
            StatusCode.PARAMETER_VALUE_INVALID.throwException("id 对应的样本不存在:" + id);
        }

        imageDataSetSampleRepository.delete(sample);
        imageDataSetService.updateLabelInfo(sample.getDataSetId());

        // Remove the image file only after the DB records are gone.
        FileUtil.deleteFileOrDir(sample.getFilePath());
    }

    /**
     * Sample distribution statistics.
     *
     * countByLabel: occurrences of each label across all annotations
     * (a sample annotated with the same label twice counts twice).
     * countBySample: number of samples containing each label
     * (each sample counted at most once per label).
     */
    public ImageDataSetSampleStatisticsApi.Output statistics(String dataSetId) {
        Map<String, Integer> countByLabel = new TreeMap<>();
        Map<String, Integer> countBySample = new TreeMap<>();

        imageDataSetSampleRepository.getAllLabelList(dataSetId)
                .stream()
                .filter(StringUtil::isNotEmpty)
                .forEach(x -> {
                    List<String> labelList = StringUtil.splitWithoutEmptyItem(x, ",");
                    labelList.forEach(label -> MapUtil.increment(countByLabel, label));

                    List<String> distinctLabelList = labelList.stream().distinct().collect(Collectors.toList());
                    distinctLabelList.forEach(label -> MapUtil.increment(countBySample, label));
                });

        return new ImageDataSetSampleStatisticsApi.Output(countByLabel, countBySample);
    }
}
/*
 * Copyright 2021 Tianmian Tech. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.welab.wefe.board.service.service.data_resource.image_data_set;

import com.welab.wefe.board.service.api.data_resource.image_data_set.ImageDataSetDeleteApi;
import com.welab.wefe.board.service.database.entity.data_resource.DataResourceMysqlModel;
import com.welab.wefe.board.service.database.entity.data_resource.ImageDataSetMysqlModel;
import com.welab.wefe.board.service.database.repository.ImageDataSetSampleRepository;
import com.welab.wefe.board.service.database.repository.data_resource.ImageDataSetRepository;
import com.welab.wefe.board.service.dto.entity.data_resource.output.ImageDataSetOutputModel;
import com.welab.wefe.board.service.dto.vo.data_resource.AbstractDataResourceUpdateInputModel;
import com.welab.wefe.board.service.dto.vo.data_resource.ImageDataSetUpdateInputModel;
import com.welab.wefe.board.service.onlinedemo.OnlineDemoBranchStrategy;
import com.welab.wefe.board.service.service.CacheObjects;
import com.welab.wefe.board.service.service.data_resource.DataResourceService;
import com.welab.wefe.board.service.service.data_resource.image_data_set.data_set_parser.AbstractImageDataSetParser;
import com.welab.wefe.common.StatusCode;
import com.welab.wefe.common.exception.StatusCodeWithException;
import com.welab.wefe.common.util.FileUtil;
import com.welab.wefe.common.util.StringUtil;
import com.welab.wefe.common.web.util.ModelMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.io.File;
import java.util.TreeSet;

/**
 * Service for image data sets: lookup (local or union), label aggregation,
 * deletion, and download of packaged data-set files.
 *
 * NOTE(review): generic type parameters were restored in this class; the
 * checked-in source used raw types (e.g. {@code TreeSet labelSet}).
 *
 * @author zane
 * @date 2021/11/8
 */
@Service
public class ImageDataSetService extends DataResourceService {

    @Autowired
    private ImageDataSetRepository imageDataSetRepository;
    @Autowired
    private ImageDataSetSampleRepository imageDataSetSampleRepository;

    /**
     * Get data set info from the local DB when it belongs to this member,
     * otherwise fetch it from the union service.
     *
     * @return the data set, or null when the local lookup finds nothing
     */
    public ImageDataSetOutputModel findDataSetFromLocalOrUnion(String memberId, String dataSetId) throws StatusCodeWithException {

        if (memberId.equals(CacheObjects.getMemberId())) {
            ImageDataSetMysqlModel dataSet = imageDataSetRepository.findById(dataSetId).orElse(null);
            if (dataSet == null) {
                return null;
            }
            return ModelMapper.map(dataSet, ImageDataSetOutputModel.class);
        } else {
            return unionService.getDataResourceDetail(dataSetId, ImageDataSetOutputModel.class);
        }
    }

    /**
     * Recompute the data set's aggregated label info (distinct label list,
     * total/labeled sample counts, completion flag) from its samples and
     * persist it, then lazily push the change to the union.
     *
     * synchronized: sample updates from concurrent requests funnel through
     * here; serializing avoids lost updates on the aggregate row.
     */
    public synchronized void updateLabelInfo(String dataSetId) throws StatusCodeWithException {
        ImageDataSetMysqlModel dataSet = findOneById(dataSetId);
        TreeSet<String> labelSet = new TreeSet<>();
        imageDataSetSampleRepository.getAllDistinctLabelList(dataSetId)
                .stream()
                .filter(StringUtil::isNotEmpty)
                .forEach(x ->
                        labelSet.addAll(StringUtil.splitWithoutEmptyItem(x, ","))
                );

        dataSet.setLabelList(StringUtil.joinByComma(labelSet));
        dataSet.setTotalDataCount(imageDataSetSampleRepository.getSampleCount(dataSetId));
        dataSet.setLabeledCount(imageDataSetSampleRepository.getLabeledCount(dataSetId));

        // The data set is fully labeled when every sample carries a label.
        dataSet.setLabelCompleted(dataSet.getTotalDataCount().equals(dataSet.getLabeledCount()));

        imageDataSetRepository.save(dataSet);

        unionService.lazyUpdateDataResource(dataSet);
    }

    @Override
    public ImageDataSetMysqlModel findOneById(String dataSetId) {
        return imageDataSetRepository.findById(dataSetId).orElse(null);
    }

    @Override
    protected void beforeUpdate(DataResourceMysqlModel m, AbstractDataResourceUpdateInputModel in) {
        // No image-data-set-specific pre-update handling yet; the casts are
        // kept to document the concrete types this hook receives.
        ImageDataSetMysqlModel model = (ImageDataSetMysqlModel) m;
        ImageDataSetUpdateInputModel input = (ImageDataSetUpdateInputModel) in;
    }

    /**
     * Delete an image data set via an API request; in the online-demo
     * environment only resources created by the caller may be deleted.
     */
    public void delete(ImageDataSetDeleteApi.Input input) throws StatusCodeWithException {
        ImageDataSetMysqlModel model = imageDataSetRepository.findById(input.getId()).orElse(null);
        if (model == null) {
            return;
        }

        OnlineDemoBranchStrategy.hackOnDelete(input, model, "只能删除自己添加的数据集。");

        delete(model);
    }

    /**
     * Delete an image data set by id; a no-op when the id is unknown.
     */
    public void delete(String dataSetId) throws StatusCodeWithException {
        ImageDataSetMysqlModel model = imageDataSetRepository.findById(dataSetId).orElse(null);
        if (model == null) {
            return;
        }

        delete(model);
    }

    /**
     * Delete the data set record, its samples, its files on disk, and
     * notify caches and the union service.
     */
    public void delete(ImageDataSetMysqlModel model) throws StatusCodeWithException {
        imageDataSetRepository.deleteById(model.getId());
        imageDataSetSampleRepository.deleteByDataSetId(model.getId());

        FileUtil.deleteFileOrDir(model.getStorageNamespace());
        CacheObjects.refreshDataResourceTags(model.getDataResourceType());

        unionService.deleteDataResource(model);
    }

    /**
     * Locate the packaged data-set file generated for a job.
     *
     * @throws StatusCodeWithException when the file has not been generated yet
     */
    public File download(String dataSetId, String jobId) throws StatusCodeWithException {
        ImageDataSetMysqlModel dataSet = findOneById(dataSetId);

        File file = AbstractImageDataSetParser.getDataSetFile(dataSet, jobId);
        if (!file.exists()) {
            StatusCode
                    .PARAMETER_VALUE_INVALID
                    .throwException("该 job 尚未生成数据集文件:" + jobId);
        }
        return file;
    }
}
/*
 * Copyright 2021 Tianmian Tech. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.welab.wefe.board.service.service.data_resource.image_data_set.data_set_parser;


import com.welab.wefe.board.service.database.entity.data_resource.ImageDataSetMysqlModel;
import com.welab.wefe.board.service.database.entity.data_set.ImageDataSetSampleMysqlModel;
import com.welab.wefe.board.service.service.AbstractService;
import com.welab.wefe.common.Convert;
import com.welab.wefe.common.StatusCode;
import com.welab.wefe.common.exception.StatusCodeWithException;
import com.welab.wefe.common.file.compression.impl.Zip;
import com.welab.wefe.common.util.FileUtil;
import com.welab.wefe.common.web.CurrentAccount;
import com.welab.wefe.common.wefe.enums.DeepLearningJobType;
import org.apache.commons.io.FileUtils;

import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Base class for image data-set parsers: turns uploaded files into sample
 * records, and packages samples back into a train/test data-set archive.
 *
 * NOTE(review): generic type parameters were restored in this class; the
 * checked-in source used raw types (List, Map, Set).
 *
 * @author zane
 * @date 2021/11/26
 */
public abstract class AbstractImageDataSetParser extends AbstractService {
    /**
     * Parse the uploaded file set into sample records.
     *
     * @param imageFiles image files keyed by file name without suffix
     * @param xmlFiles   annotation .xml files keyed by file name without suffix
     * @param txtFiles   annotation .txt files keyed by file name without suffix
     */
    protected abstract List<ImageDataSetSampleMysqlModel> parseFilesToSamples(ImageDataSetMysqlModel dataSet, Map<String, File> imageFiles, Map<String, File> xmlFiles, Map<String, File> txtFiles) throws Exception;

    /**
     * Write the samples into the data-set output directory, preparing the
     * layout that will be packaged into the data-set archive.
     */
    protected abstract void emitSamplesToDataSetFileDir(ImageDataSetMysqlModel dataSet, final List<ImageDataSetSampleMysqlModel> trainSamples, final List<ImageDataSetSampleMysqlModel> testSamples, Path outputDir) throws Exception;

    /**
     * Factory: choose the parser for the given deep-learning job type.
     * Returns null for unhandled types; callers must check.
     */
    public static AbstractImageDataSetParser getParser(DeepLearningJobType type) {
        switch (type) {
            case classify:
                return new ClassifyImageDataSetParser();
            case detection:
                return new DetectionImageDataSetParser();
            default:
                return null;
        }
    }

    /**
     * Location of the packaged data-set archive for a job:
     * {storageNamespace}/output/{jobId}.zip
     */
    public static File getDataSetFile(ImageDataSetMysqlModel dataSet, String jobId) {
        return Paths.get(
                dataSet.getStorageNamespace(),
                "output",
                jobId + ".zip"
        ).toFile();
    }

    /**
     * Package the samples into a data-set archive, split into train and test
     * subsets by the given ratio.
     *
     * @param trainTestSplitRatio percentage (0-100) of samples assigned to the train set
     */
    public File parseSamplesToDataSetFile(String jobId, ImageDataSetMysqlModel dataSet, final List<ImageDataSetSampleMysqlModel> samples, int trainTestSplitRatio) throws Exception {
        // Compute train/test sizes from the split ratio; at least one train sample.
        int trainCount = Convert.toInt(trainTestSplitRatio / 100D * samples.size());
        if (trainCount < 1) {
            trainCount = 1;
        }
        int testCount = samples.size() - trainCount;

        // Randomly assign each sample to train or test while honoring the quotas.
        Random rand = new Random();
        List<ImageDataSetSampleMysqlModel> trainList = new ArrayList<>();
        List<ImageDataSetSampleMysqlModel> testList = new ArrayList<>();
        for (ImageDataSetSampleMysqlModel sample : samples) {
            // Whether this sample is assigned to the train set.
            boolean isTrainSample = false;
            // The train quota is not filled yet...
            if (trainList.size() < trainCount) {
                // ...and either the test quota is already full, or a coin flip picks train.
                if (testList.size() >= testCount || rand.nextBoolean()) {
                    isTrainSample = true;
                }
            }

            if (isTrainSample) {
                trainList.add(sample);
            } else {
                testList.add(sample);
            }
        }

        // Build the packaging directory path.
        Path outputDir = Paths.get(
                dataSet.getStorageNamespace(),
                "output",
                jobId
        );

        // Delete any leftovers from a previous run.
        FileUtil.deleteFileOrDir(outputDir.toString());
        // Emit the samples into the packaging directory, then zip it.
        emitSamplesToDataSetFileDir(dataSet, trainList, testList, outputDir);
        return new Zip().compression(
                outputDir.toString(),
                getDataSetFile(dataSet, jobId).getAbsolutePath()
        );
    }

    /**
     * Pre-process an uploaded file set: drop hidden/OS-junk files, reject
     * duplicated file names, bucket files by kind, and delegate to the
     * subclass parser.
     *
     * @throws Exception when duplicated file names are found, or the subclass parser fails
     */
    public List<ImageDataSetSampleMysqlModel> parseFilesToSamples(ImageDataSetMysqlModel dataSet, final Set<File> allFiles) throws Exception {
        // Filter out hidden files and macOS zip metadata directories.
        List<File> files = allFiles.stream()
                .filter(x -> !x.getAbsolutePath().contains("/__MACOSX/"))
                .filter(x -> !x.isHidden())
                .collect(Collectors.toList());

        // Samples are addressed by file name, so names must be unique.
        Set<String> fileNameSet = new HashSet<>();
        for (File file : files) {
            String fileName = file.getName();
            if (!fileNameSet.add(fileName)) {
                StatusCode.PARAMETER_VALUE_INVALID.throwException("检测到多个文件名为:" + fileName + ",请删除或修改文件名后重试。");
            }
        }

        Map<String, File> imageFiles = files
                .stream()
                .filter(FileUtil::isImage)
                .collect(Collectors.toMap(
                        FileUtil::getFileNameWithoutSuffix,
                        x -> x
                ));

        Map<String, File> xmlFiles = files
                .stream()
                .filter(x -> "xml".equalsIgnoreCase(FileUtil.getFileSuffix(x)))
                .collect(Collectors.toMap(
                        FileUtil::getFileNameWithoutSuffix,
                        x -> x
                ));

        Map<String, File> txtFiles = files
                .stream()
                .filter(x -> "txt".equalsIgnoreCase(FileUtil.getFileSuffix(x)))
                .collect(Collectors.toMap(
                        FileUtil::getFileNameWithoutSuffix,
                        x -> x
                ));

        return parseFilesToSamples(dataSet, imageFiles, xmlFiles, txtFiles);
    }

    /**
     * Build the sample record for an image file and copy the file into the
     * data set's storage directory (overwriting any stale copy).
     */
    protected ImageDataSetSampleMysqlModel createSampleModel(ImageDataSetMysqlModel dataSet, File imageFile) throws StatusCodeWithException, IOException {
        ImageDataSetSampleMysqlModel sample = new ImageDataSetSampleMysqlModel();
        sample.setDataSetId(dataSet.getId());
        sample.setFileName(imageFile.getName());
        sample.setFilePath(
                Paths.get(dataSet.getStorageNamespace(), imageFile.getName()).toString()
        );
        sample.setFileSize(imageFile.length());
        sample.setCreatedBy(CurrentAccount.id());

        // Copy the image into the storage dir (the source file is left in place).
        File destFile = new File(sample.getFilePath());
        if (destFile.exists()) {
            destFile.delete();
        }
        FileUtils.copyFile(imageFile, destFile);

        return sample;
    }
}
b/board/board-service/src/main/java/com/welab/wefe/board/service/service/data_resource/image_data_set/data_set_parser/ClassifyImageDataSetParser.java @@ -0,0 +1,268 @@ +/* + * Copyright 2021 Tianmian Tech. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.welab.wefe.board.service.service.data_resource.image_data_set.data_set_parser; + + +import com.welab.wefe.board.service.database.entity.data_resource.ImageDataSetMysqlModel; +import com.welab.wefe.board.service.database.entity.data_set.ImageDataSetSampleMysqlModel; +import com.welab.wefe.common.Convert; +import com.welab.wefe.common.exception.StatusCodeWithException; +import com.welab.wefe.common.file.compression.impl.Tgz; +import com.welab.wefe.common.util.FileUtil; +import com.welab.wefe.common.util.ListUtil; +import com.welab.wefe.common.util.StringUtil; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * @author zane + * @date 2021/11/26 + */ +public class ClassifyImageDataSetParser extends AbstractImageDataSetParser { + + /** + * e.g: + * 0 pink primrose + * 1 hard-leaved pocket orchid + * 2 canterbury bells + * 3 sweet pea + */ + private static final String LABEL_LIST_FILE_NAME = "label_list.txt"; + private static final Pattern LABEL_LIST_PATTERN = Pattern.compile("^\\s*(?\\d+)\\s+(?