From fb73c859a40372b144d178a37bb8f78c1c96c060 Mon Sep 17 00:00:00 2001
From: wyh <318450744@qq.com>
Date: Fri, 15 Dec 2023 12:09:15 +0800
Subject: [PATCH] Initial commit
---
.idea/.gitignore | 8 +
.idea/deployment.xml | 22 ++
.idea/inspectionProfiles/Project_Default.xml | 54 ++++
.../inspectionProfiles/profiles_settings.xml | 6 +
.idea/misc.xml | 4 +
.idea/modules.xml | 8 +
.idea/remote-mappings.xml | 10 +
.idea/segment_snn.iml | 8 +
basic.py | 133 +++++++++
dataset.py | 259 ++++++++++++++++++
fpn.py | 106 +++++++
load_data.py | 202 ++++++++++++++
model.py | 37 +++
train.py | 129 +++++++++
vgg16.py | 80 ++++++
15 files changed, 1066 insertions(+)
create mode 100644 .idea/.gitignore
create mode 100644 .idea/deployment.xml
create mode 100644 .idea/inspectionProfiles/Project_Default.xml
create mode 100644 .idea/inspectionProfiles/profiles_settings.xml
create mode 100644 .idea/misc.xml
create mode 100644 .idea/modules.xml
create mode 100644 .idea/remote-mappings.xml
create mode 100644 .idea/segment_snn.iml
create mode 100644 basic.py
create mode 100644 dataset.py
create mode 100644 fpn.py
create mode 100644 load_data.py
create mode 100644 model.py
create mode 100644 train.py
create mode 100644 vgg16.py
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..35410ca
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000..6cee02e
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..9e40aad
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,54 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..da16ce0
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..d108090
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/remote-mappings.xml b/.idea/remote-mappings.xml
new file mode 100644
index 0000000..5db6a93
--- /dev/null
+++ b/.idea/remote-mappings.xml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/segment_snn.iml b/.idea/segment_snn.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/segment_snn.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/basic.py b/basic.py
new file mode 100644
index 0000000..ed2b913
--- /dev/null
+++ b/basic.py
@@ -0,0 +1,133 @@
+import torch
+from torch import nn
+from einops import rearrange
+from braincog.base.node.node import *
+from functools import partial
+
+class LayerWiseConvModule(nn.Module):
+ """
+ SNN卷积模块
+ :param in_channels: 输入通道数
+ :param out_channels: 输出通道数
+ :param kernel_size: kernel size
+ :param stride: stride
+ :param padding: padding
+ :param bias: Bias
+ :param node: 神经元类型
+ :param kwargs:
+ """
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False,
+ node=BiasLIFNode,
+ step=6,
+ **kwargs):
+
+ super().__init__()
+
+ if node is None:
+ raise TypeError
+
+ self.groups = kwargs['groups'] if 'groups' in kwargs else 1
+ self.conv = nn.Conv2d(in_channels=in_channels * self.groups,
+ out_channels=out_channels * self.groups,
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ bias=bias)
+ self.gn = nn.GroupNorm(self.groups, out_channels * self.groups)
+ self.node = partial(node, **kwargs)()
+ self.step = step
+ self.activation = nn.Identity()
+
+ def forward(self, x):
+ x = rearrange(x, '(t b) c w h -> t b c w h', t=self.step)
+ outputs = []
+ for t in range(self.step):
+ outputs.append(self.gn(self.conv(x[t])))
+ outputs = torch.stack(outputs) # t b c w h
+ outputs = rearrange(outputs, 't b c w h -> (t b) c w h')
+ outputs = self.node(outputs)
+ return outputs
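+
+    # Usage sketch (illustrative only; shapes assume the input is folded as (step * batch, C, H, W)):
+    #   conv = LayerWiseConvModule(3, 8, node=BiasLIFNode, step=6)
+    #   spikes = conv(torch.rand(6 * 2, 3, 32, 32))  # -> spike maps of shape (12, 8, 32, 32)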
+
+
+class TEP(nn.Module):
+ def __init__(self, step, channel, device=None, dtype=None):
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super(TEP, self).__init__()
+ self.step = step
+ self.gn = nn.GroupNorm(channel, channel)
+
+ def forward(self, x):
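+        # Average the per-step spike maps over time into a firing-rate map, normalize it
+        # with GroupNorm and shift by +1, then rescale every time step's feature map with it.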
+ x = rearrange(x, '(t b) c w h -> t b c w h', t=self.step)
+ fire_rate = torch.mean(x, dim=0)
+ fire_rate = self.gn(fire_rate) + 1
+
+ x = x * fire_rate
+ x = rearrange(x, 't b c w h -> (t b) c w h')
+
+ return x
+
+
+class LayerWiseLinearModule(nn.Module):
+ """
+ 线性模块
+ :param in_features: 输入尺寸
+ :param out_features: 输出尺寸
+ :param bias: 是否有Bias, 默认 ``False``
+ :param node: 神经元类型, 默认 ``LIFNode``
+ :param args:
+ :param kwargs:
+ """
+
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ bias=True,
+ node=BiasLIFNode,
+ step=6,
+ spike=True,
+ *args,
+ **kwargs):
+ super().__init__()
+ if node is None:
+ raise TypeError
+
+ self.groups = kwargs['groups'] if 'groups' in kwargs else 1
+ if self.groups == 1:
+ self.fc = nn.Linear(in_features=in_features,
+ out_features=out_features, bias=bias)
+ else:
+ self.fc = nn.ModuleList()
+ for i in range(self.groups):
+ self.fc.append(nn.Linear(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias
+ ))
+ self.node = partial(node, **kwargs)()
+ self.step = step
+ self.spike = spike
+
+ def forward(self, x):
+ if self.groups == 1: # (t b) c
+ x = rearrange(x, '(t b) c -> t b c', t=self.step)
+ outputs = []
+ for t in range(self.step):
+ outputs.append(self.fc(x[t]))
+ outputs = torch.stack(outputs) # t b c
+ outputs = rearrange(outputs, 't b c -> (t b) c')
+ else: # b (c t)
+ x = rearrange(x, 'b (c t) -> t b c', t=self.groups)
+ outputs = []
+ for i in range(self.groups):
+ outputs.append(self.fc[i](x[i]))
+ outputs = torch.stack(outputs) # t b c
+ outputs = rearrange(outputs, 't b c -> b (c t)')
+ if self.spike:
+ return self.node(outputs)
+ else:
+ return outputs
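+
+
+if __name__ == '__main__':
+    # Minimal shape check, not part of the original training pipeline; assumes the
+    # braincog package is installed so that BiasLIFNode is importable.
+    step, batch = 6, 2
+    conv = LayerWiseConvModule(3, 8, node=BiasLIFNode, step=step)
+    print(conv(torch.rand(step * batch, 3, 32, 32)).shape)  # torch.Size([12, 8, 32, 32])
+    fc = LayerWiseLinearModule(16, 10, node=BiasLIFNode, step=step)
+    print(fc(torch.rand(step * batch, 16)).shape)           # torch.Size([12, 10])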
\ No newline at end of file
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..f9b9e86
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,259 @@
+import os
+import pickle
+import torch
+import numpy as np
+from PIL import Image, ImageOps, ImageFilter
+from tqdm import trange
+from torchvision import transforms
+import random
+import torch.utils.data as data
+
+class SegmentationDataset(object):
+ """Segmentation Base Dataset"""
+
+ def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
+ super(SegmentationDataset, self).__init__()
+ self.root = root
+ self.transform = transform
+ self.split = split
+ self.mode = mode if mode is not None else split
+ self.base_size = base_size
+ self.crop_size = crop_size
+
+ def _val_sync_transform(self, img, mask):
+ outsize = self.crop_size
+ short_size = outsize
+ w, h = img.size
+ if w > h:
+ oh = short_size
+ ow = int(1.0 * w * oh / h)
+ else:
+ ow = short_size
+ oh = int(1.0 * h * ow / w)
+ img = img.resize((ow, oh), Image.BILINEAR)
+ mask = mask.resize((ow, oh), Image.NEAREST)
+ # center crop
+ w, h = img.size
+ x1 = int(round((w - outsize) / 2.))
+ y1 = int(round((h - outsize) / 2.))
+ img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
+ mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
+ # final transform
+ img, mask = self._img_transform(img), self._mask_transform(mask)
+ return img, mask
+
+ def _sync_transform(self, img, mask):
+ # random mirror
+ if random.random() < 0.5:
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
+ crop_size = self.crop_size
+ # random scale (short edge)
+ short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
+ w, h = img.size
+ if h > w:
+ ow = short_size
+ oh = int(1.0 * h * ow / w)
+ else:
+ oh = short_size
+ ow = int(1.0 * w * oh / h)
+ img = img.resize((ow, oh), Image.BILINEAR)
+ mask = mask.resize((ow, oh), Image.NEAREST)
+ # pad crop
+ if short_size < crop_size:
+ padh = crop_size - oh if oh < crop_size else 0
+ padw = crop_size - ow if ow < crop_size else 0
+ img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
+ mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
+ # random crop crop_size
+ w, h = img.size
+ x1 = random.randint(0, w - crop_size)
+ y1 = random.randint(0, h - crop_size)
+ img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
+ mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
+ # gaussian blur as in PSP
+ if random.random() < 0.5:
+ img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
+ # final transform
+ img, mask = self._img_transform(img), self._mask_transform(mask)
+ return img, mask
+
+ def _img_transform(self, img):
+ return np.array(img)
+
+ def _mask_transform(self, mask):
+ return np.array(mask).astype('int32')
+
+ @property
+ def num_class(self):
+ """Number of categories."""
+ return self.NUM_CLASS
+
+ @property
+ def pred_offset(self):
+ return 0
+
+class COCOSegmentation(SegmentationDataset):
+ """COCO Semantic Segmentation Dataset for VOC Pre-training.
+
+ Parameters
+ ----------
+ root : string
+        Path to the COCO folder. Default is '../datasets/coco'
+ split: string
+ 'train', 'val' or 'test'
+ transform : callable, optional
+ A function that transforms the image
+ Examples
+ --------
+ >>> from torchvision import transforms
+ >>> import torch.utils.data as data
+ >>> # Transforms for Normalization
+ >>> input_transform = transforms.Compose([
+ >>> transforms.ToTensor(),
+ >>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
+ >>> ])
+ >>> # Create Dataset
+ >>> trainset = COCOSegmentation(split='train', transform=input_transform)
+ >>> # Create Training Loader
+ >>> train_data = data.DataLoader(
+ >>> trainset, 4, shuffle=True,
+ >>> num_workers=4)
+ """
+ CAT_LIST = [1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 72]
+ NUM_CLASS = len(CAT_LIST)
+
+ def __init__(self, root='../datasets/coco', annotation_root='', split='train', mode=None, transform=None, stride=1, **kwargs):
+ super(COCOSegmentation, self).__init__(root, split, mode, transform, **kwargs)
+ # lazy import pycocotools
+ from pycocotools.coco import COCO
+ from pycocotools import mask
+ if split == 'train':
+ print('train set')
+ ann_file = os.path.join(annotation_root, 'annotations/instances_train2017.json')
+ ids_file = os.path.join(annotation_root, 'annotations/train_ids.mx')
+ self.root = os.path.join(annotation_root, 'train2017')
+ else:
+ print('val set')
+ ann_file = os.path.join(annotation_root, 'annotations/instances_val2017.json')
+ ids_file = os.path.join(annotation_root, 'annotations/val_ids.mx')
+ self.root = os.path.join(root, 'val2017')
+ self.coco = COCO(ann_file)
+ self.coco_mask = mask
+ if os.path.exists(ids_file):
+ with open(ids_file, 'rb') as f:
+ self.ids = pickle.load(f)
+ else:
+ ids = list(self.coco.imgs.keys())
+ self.ids = self._preprocess(ids, ids_file)
+ self.transform = transform
+ self.stride = stride
+
+ def __getitem__(self, index):
+ coco = self.coco
+ img_id = self.ids[index]
+ img_metadata = coco.loadImgs(img_id)[0]
+ path = img_metadata['file_name']
+ img = Image.open(os.path.join(self.root, path)).convert('RGB')
+ cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
+ mask = Image.fromarray(self._gen_seg_mask(
+ cocotarget, img_metadata['height'], img_metadata['width']))
+        # synchronized transform
+ if self.mode == 'train':
+ img, mask = self._sync_transform(img, mask)
+ elif self.mode == 'val':
+ img, mask = self._val_sync_transform(img, mask)
+ else:
+ assert self.mode == 'testval'
+ img, mask = self._img_transform(img), self._mask_transform(mask)
+ # general resize, normalize and toTensor
+ if self.transform is not None:
+ img = self.transform(img)
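+        # Convert the static image into an event-like tensor: build a short sequence of
+        # slightly translated frames, take temporal differences, and split the result into
+        # positive and negative channels, similar to ON/OFF event polarities.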
+ frames = torch.diff(self.generate_dynamic_translation(img), dim=0)
+ p_img = torch.zeros_like(frames)
+ n_img = torch.zeros_like(frames)
+ p_img[frames > 0] = frames[frames > 0]
+ n_img[frames < 0] = frames[frames < 0]
+ output = torch.concat([p_img, n_img], dim=1)
+ return output, mask
+
+ def _mask_transform(self, mask):
+ return torch.LongTensor(np.array(mask).astype('int32'))
+
+ def _gen_seg_mask(self, target, h, w):
+ mask = np.zeros((h, w), dtype=np.uint8)
+ coco_mask = self.coco_mask
+ for instance in target:
+ rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
+ m = coco_mask.decode(rle)
+ cat = instance['category_id']
+ if cat in self.CAT_LIST:
+ c = self.CAT_LIST.index(cat)
+ else:
+ continue
+ if len(m.shape) < 3:
+ mask[:, :] += (mask == 0) * (m * c)
+ else:
+ mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
+ return mask
+
+ def generate_dynamic_translation(self, image):
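+        # Build a short sequence of translated copies of the image: `tracex`/`tracey` define
+        # a fixed zig-zag of pixel offsets (scaled by `self.stride`), and each frame pastes
+        # the shifted image into an otherwise zero canvas. `__getitem__` later takes temporal
+        # differences of these frames to obtain the event-like network input.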
+ tracex = self.stride * 2 * np.array([0, 2, 1, 0, 2, 1, 1, 2, 1])
+ tracey = self.stride * 2 * np.array([0, 1, 2, 1, 0, 2, 1, 1, 2])
+
+ num_frames = len(tracex)
+ channel = image.shape[0]
+ height = image.shape[1]
+ width = image.shape[2]
+
+ frames = torch.zeros((num_frames, channel, height, width))
+ for i in range(num_frames):
+ anchor_x = tracex[i]
+ anchor_y = tracey[i]
+ frames[i, :, anchor_y // 2: height - anchor_y // 2, anchor_x // 2: width - anchor_x // 2] = image[:, anchor_y:, anchor_x:]
+ return frames
+
+ def _preprocess(self, ids, ids_file):
+ print("Preprocessing mask, this will take a while." + \
+ "But don't worry, it only run once for each split.")
+ tbar = trange(len(ids))
+ new_ids = []
+ for i in tbar:
+ img_id = ids[i]
+ cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
+ img_metadata = self.coco.loadImgs(img_id)[0]
+ mask = self._gen_seg_mask(cocotarget, img_metadata['height'], img_metadata['width'])
+ # more than 1k pixels
+ if (mask > 0).sum() > 1000:
+ new_ids.append(img_id)
+ tbar.set_description('Doing: {}/{}, got {} qualified images'. \
+ format(i, len(ids), len(new_ids)))
+ print('Found number of qualified images: ', len(new_ids))
+ with open(ids_file, 'wb') as f:
+ pickle.dump(new_ids, f)
+ return new_ids
+
+ @property
+ def classes(self):
+ """Category names."""
+ return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle',
+ 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+ 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
+ 'tv')
+
+ def __len__(self):
+        return len(self.ids)
+
+if __name__ == '__main__':
+ input_transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
+ ])
+ # Create Dataset
+ trainset = COCOSegmentation(root='/kaggle/working/coco', annotation_root='coco_ann2017',split='val', transform=input_transform)
+ # Create Training Loader
+ train_data = data.DataLoader(
+ trainset, 4, shuffle=True,
+ )
\ No newline at end of file
diff --git a/fpn.py b/fpn.py
new file mode 100644
index 0000000..e11f2ac
--- /dev/null
+++ b/fpn.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from basic import *
+
+class FPNSegmentationHead(nn.Module):
+ def __init__(self,
+ in_dim,
+ out_dim,
+ decode_intermediate_input=True,
+ hidden_dim=256,
+ shortcut_dims=[64, 128, 256, 512],
+ node=BiasLIFNode,
+ step=6,
+ align_corners=True):
+ super().__init__()
+ self.align_corners = align_corners
+ self.node = node
+ self.step = step
+ self.decode_intermediate_input = decode_intermediate_input
+
+
+ # self.conv_in = ConvGN(in_dim, hidden_dim, 1)
+ self.conv_in = LayerWiseConvModule(in_dim, hidden_dim, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+
+ # self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3)
+ # self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3)
+ # self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3)
+ self.conv_16x = LayerWiseConvModule(hidden_dim, hidden_dim, 3, node=BiasLIFNode, step=self.step)
+ self.conv_8x = LayerWiseConvModule(hidden_dim, hidden_dim // 2, 3, node=BiasLIFNode, step=self.step)
+ self.conv_4x = LayerWiseConvModule(hidden_dim // 2, hidden_dim // 2, 3, node=BiasLIFNode, step=self.step)
+ self.conv_2x = LayerWiseConvModule(hidden_dim // 2, hidden_dim // 2, 3, node=BiasLIFNode, step=self.step)
+ # self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1)
+ # self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1)
+ # self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1)
+ self.in_tep = TEP(step=self.step, channel=hidden_dim, device=None, dtype=None)
+ self.adapter_16x = LayerWiseConvModule(shortcut_dims[-1], hidden_dim, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+ self.tep_16x = TEP(step=self.step, channel=hidden_dim, device=None, dtype=None)
+ self.adapter_8x = LayerWiseConvModule(shortcut_dims[-2], hidden_dim, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+ self.tep_8x = TEP(step=self.step, channel=hidden_dim // 2, device=None, dtype=None)
+ self.adapter_4x = LayerWiseConvModule(shortcut_dims[-3], hidden_dim // 2, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+ self.tep_4x = TEP(step=self.step, channel=hidden_dim // 2, device=None, dtype=None)
+ self.adapter_2x = LayerWiseConvModule(shortcut_dims[-4], hidden_dim // 2, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+ self.tep_2x = TEP(step=self.step, channel=hidden_dim // 2, device=None, dtype=None)
+
+ self.conv_out = LayerWiseConvModule(hidden_dim // 2, out_dim, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+
+ self._init_weight()
+
+ def forward(self, inputs, shortcuts):
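+        # Top-down FPN decoding: start from the coarsest features (`inputs`); at every level,
+        # upsample to the next shortcut's spatial size, add the 1x1-conv adapted shortcut,
+        # refine with a 3x3 spiking conv block and apply TEP. `shortcuts` is ordered from
+        # fine to coarse, so it is consumed from the end of the list.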
+ self.reset()
+
+ inputs = rearrange(inputs, 't b c w h -> (t b) c w h')
+ for i in range(len(shortcuts)):
+ shortcuts[i] = rearrange(shortcuts[i], 't b c w h -> (t b) c w h')
+
+ # if self.decode_intermediate_input:
+ # x = torch.cat(inputs, dim=1)
+ # else:
+ # x = inputs[-1]
+ if self.decode_intermediate_input:
+ x = self.in_tep(self.conv_in(inputs))
+ else:
+ x = inputs
+
+ x = F.interpolate(x,
+ size=shortcuts[-1].size()[-2:],
+ mode="bilinear",
+ align_corners=self.align_corners)
+ x = self.tep_16x(self.conv_16x(self.adapter_16x(shortcuts[-1]) + x))
+
+ x = F.interpolate(x,
+ size=shortcuts[-2].size()[-2:],
+ mode="bilinear",
+ align_corners=self.align_corners)
+ x = self.tep_8x(self.conv_8x(self.adapter_8x(shortcuts[-2]) + x))
+
+ x = F.interpolate(x,
+ size=shortcuts[-3].size()[-2:],
+ mode="bilinear",
+ align_corners=self.align_corners)
+ x = self.tep_4x(self.conv_4x(self.adapter_4x(shortcuts[-3]) + x))
+
+ x = F.interpolate(x,
+ size=shortcuts[-4].size()[-2:],
+ mode="bilinear",
+ align_corners=self.align_corners)
+ x = self.tep_2x(self.conv_2x(self.adapter_2x(shortcuts[-4]) + x))
+
+ x = self.conv_out(x)
+
+ return x
+
+ def _init_weight(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def reset(self):
+ """
+ 重置所有神经元的膜电位
+ :return:
+ """
+ for mod in self.modules():
+ if hasattr(mod, 'n_reset'):
+ mod.n_reset()
diff --git a/load_data.py b/load_data.py
new file mode 100644
index 0000000..d94fd0b
--- /dev/null
+++ b/load_data.py
@@ -0,0 +1,202 @@
+import numpy as np
+from torchvision import datasets, transforms
+import torch
+from torch.utils.data import Dataset
+import tonic
+from tonic import DiskCachedDataset
+import torch.nn.functional as F
+import os
+
+MNIST_MEAN = 0.1307
+MNIST_STD = 0.3081
+CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
+CIFAR10_STD_DEV = (0.2023, 0.1994, 0.2010)
+cifar100_mean = [0.5071, 0.4865, 0.4409]
+cifar100_std = [0.2673, 0.2563, 0.2761]
+DVSCIFAR10_MEAN_16 = [0.3290, 0.4507]
+DVSCIFAR10_STD_16 = [1.8398, 1.6549]
+
+DATA_DIR = '/data/datasets'
+
+
+class CustomDataset(Dataset):
+ """An abstract Dataset class wrapped around Pytorch Dataset class.
+ """
+
+ def __init__(self, dataset, indices):
+ self.dataset = dataset
+ self.indices = [int(i) for i in indices]
+
+ def __len__(self):
+ return len(self.indices)
+
+ def __getitem__(self, item):
+ x, y = self.dataset[self.indices[item]]
+ return x, y
+
+
+def load_static_data(data_root, batch_size, dataset):
+ if dataset == 'cifar10':
+ transform_train = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV)])
+
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV)])
+
+ train_data = datasets.CIFAR10(data_root, train=True, transform=transform_train, download=True)
+ test_data = datasets.CIFAR10(data_root, train=False, transform=transform_test, download=True)
+
+ train_loader = torch.utils.data.DataLoader(
+ train_data,
+ batch_size=batch_size,
+ shuffle=True
+ )
+ test_loader = torch.utils.data.DataLoader(
+ test_data,
+ batch_size=batch_size,
+ )
+ elif dataset == 'MNIST':
+ transform_train = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(MNIST_MEAN, MNIST_STD)])
+
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(MNIST_MEAN, MNIST_STD)])
+
+ train_data = datasets.MNIST(data_root, train=True, transform=transform_train, download=True)
+ test_data = datasets.MNIST(data_root, train=False, transform=transform_test, download=True)
+
+ train_loader = torch.utils.data.DataLoader(
+ train_data,
+ batch_size=batch_size,
+ shuffle=True
+ )
+ test_loader = torch.utils.data.DataLoader(
+ test_data,
+ batch_size=batch_size,
+ )
+ elif dataset == 'FashionMNIST':
+ transform_train = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(MNIST_MEAN, MNIST_STD)])
+
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(MNIST_MEAN, MNIST_STD)])
+
+ train_data = datasets.FashionMNIST(data_root, train=True, transform=transform_train, download=True)
+ test_data = datasets.FashionMNIST(data_root, train=False, transform=transform_test, download=True)
+
+ train_loader = torch.utils.data.DataLoader(
+ train_data,
+ batch_size=batch_size,
+ shuffle=True
+ )
+ test_loader = torch.utils.data.DataLoader(
+ test_data,
+ batch_size=batch_size,
+        )
+    else:
+        raise ValueError(f'Unsupported dataset: {dataset}')
+
+    return train_data, test_data, train_loader, test_loader
+
+
+def load_dvs10_data(batch_size, step, **kwargs):
+ size = kwargs['size'] if 'size' in kwargs else 48
+ sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
+ train_transform = transforms.Compose([
+ # tonic.transforms.Denoise(filter_time=10000),
+ # tonic.transforms.DropEvent(p=0.1),
+ tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
+ test_transform = transforms.Compose([
+ # tonic.transforms.Denoise(filter_time=10000),
+ tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
+ train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform)
+ test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform)
+
+ train_transform = transforms.Compose([
+ lambda x: torch.tensor(x, dtype=torch.float),
+ lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
+ ])
+ test_transform = transforms.Compose([
+ lambda x: torch.tensor(x, dtype=torch.float),
+ lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True),
+ ])
+
+ train_dataset = DiskCachedDataset(train_dataset,
+ cache_path=f'./dataset/dvs_cifar10/train_cache_{step}',
+ transform=train_transform)
+    test_dataset = DiskCachedDataset(test_dataset,
+ cache_path=f'./dataset/dvs_cifar10/test_cache_{step}',
+ transform=test_transform)
+
+ print(train_dataset)
+
+ num_train = len(train_dataset)
+ num_per_cls = num_train // 10
+ indices_train, indices_test = [], []
+ portion = kwargs['portion'] if 'portion' in kwargs else .9
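+    # CIFAR10-DVS has no official train/test split; assuming the samples of each of the 10
+    # classes are stored contiguously, the first `portion` of every class is used for
+    # training and the remainder for evaluation.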
+ for i in range(10):
+ indices_train.extend(
+ list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))
+ indices_test.extend(
+ list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))
+ train_dataset = CustomDataset(train_dataset, np.array(indices_train))
+ test_dataset = CustomDataset(test_dataset, np.array(indices_test))
+
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True,
+ pin_memory=True, drop_last=False, num_workers=1
+ )
+
+ test_loader = torch.utils.data.DataLoader(
+ test_dataset, batch_size=batch_size,
+ pin_memory=True, drop_last=False, num_workers=1
+ )
+
+ return train_loader, test_loader, train_dataset, test_dataset
+
+
+def load_nmnist_data(batch_size, step, **kwargs):
+ size = kwargs['size'] if 'size' in kwargs else 28
+ sensor_size = tonic.datasets.NMNIST.sensor_size
+ train_transform = transforms.Compose([
+ # tonic.transforms.Denoise(filter_time=10000),
+ # tonic.transforms.DropEvent(p=0.1),
+ tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
+ test_transform = transforms.Compose([
+ # tonic.transforms.Denoise(filter_time=10000),
+ tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
+ train_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=train_transform, train=True)
+ test_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=test_transform, train=False)
+
+ train_transform = transforms.Compose([
+ transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.float)),
+ transforms.Lambda(lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True)),
+
+ ])
+ test_transform = transforms.Compose([
+ transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.float)),
+ transforms.Lambda(lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True)),
+ ])
+
+ train_dataset = DiskCachedDataset(train_dataset,
+ cache_path=f'./dataset/NMNIST/train_cache_{step}',
+ transform=train_transform)
+ test_dataset = DiskCachedDataset(test_dataset,
+ cache_path=f'./dataset/NMNIST/test_cache_{step}',
+ transform=test_transform)
+
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True,
+ pin_memory=True, drop_last=False, num_workers=0
+ )
+
+ test_loader = torch.utils.data.DataLoader(
+ test_dataset, batch_size=batch_size,
+ pin_memory=True, drop_last=False, num_workers=0
+ )
+
+ return train_loader, test_loader, train_dataset, test_dataset
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..6d06a37
--- /dev/null
+++ b/model.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+from fpn import FPNSegmentationHead
+from vgg16 import VGG16
+from torch import nn
+from braincog.base.node.node import *
+
+
+class SegmentModel(nn.Module):
+ def __init__(self, output_size, node=BiasLIFNode, step=6):
+ super(SegmentModel, self).__init__()
+ self.output_size = output_size
+ self.node = node
+ self.step = step
+ self.encoder = VGG16()
+ self.decoder = FPNSegmentationHead(512, 13,
+ decode_intermediate_input=True,
+ shortcut_dims=[64, 128, 256, 512],
+ node=node,
+ step=step,
+ align_corners=True)
+
+ def forward(self, x: torch.Tensor = None):
+ """
+ x -> t b c w h
+ """
+ embs = self.encoder(x)
+ for i in range(len(embs)):
+ embs[i] = rearrange(embs[i], '(t b) c w h -> t b c w h', t=self.step)
+ logits = self.decoder(embs[-1], embs[0:-1])
+ logits = rearrange(logits, '(t b) c w h -> t b c w h', t=self.step)
+ out_logits = torch.mean(logits, dim=0)
+ out_logits = F.interpolate(out_logits,
+ size=self.output_size,
+ mode="bilinear",
+ align_corners=True)
+ return out_logits
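+
+
+if __name__ == '__main__':
+    # Minimal end-to-end shape check (a sketch, not part of the original patch): feeds a
+    # random two-channel event-style input through the encoder/decoder. Assumes braincog
+    # is installed; the input size is illustrative only.
+    step, batch = 6, 1
+    net = SegmentModel(output_size=(128, 128), node=BiasLIFNode, step=step)
+    x = torch.rand(step, batch, 2, 128, 128)  # (t, b, c, h, w)
+    with torch.no_grad():
+        logits = net(x)
+    print(logits.shape)  # expected: torch.Size([1, 13, 128, 128])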
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..5cc4421
--- /dev/null
+++ b/train.py
@@ -0,0 +1,129 @@
+import sys
+
+import torch
+
+sys.path.append('../../..')
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+from torchvision.transforms import *
+import time
+from braincog.utils import setup_seed
+from dataset import SegmentationDataset
+from braincog.base.node.node import *
+from model import SegmentModel
+from load_data import load_nmnist_data
+
+device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')
+DATA_DIR = '/data/datasets'
+
+
+def get_cifar10_loader(batch_size, train_batch=None, num_workers=4, conversion=False, distributed=False):
+ normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
+ # transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ # CIFAR10Policy(),
+ # transforms.ToTensor(),
+ # Cutout(n_holes=1, length=16),
+ # normalize])
+ transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ normalize])
+ transform_test = transforms.Compose([transforms.ToTensor(), normalize])
+ train_batch = batch_size if train_batch is None else train_batch
+ cifar10_train = datasets.CIFAR10(root=DATA_DIR, train=True, download=False,
+ transform=transform_test if conversion else transform_train)
+ cifar10_test = datasets.CIFAR10(root=DATA_DIR, train=False, download=False, transform=transform_test)
+
+ if distributed:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(cifar10_train)
+ val_sampler = torch.utils.data.distributed.DistributedSampler(cifar10_test, shuffle=False, drop_last=True)
+ train_iter = torch.utils.data.DataLoader(cifar10_train, batch_size=train_batch, shuffle=False,
+ num_workers=num_workers, pin_memory=True, sampler=train_sampler)
+ test_iter = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False,
+ num_workers=num_workers, pin_memory=True, sampler=val_sampler)
+ else:
+ train_iter = torch.utils.data.DataLoader(cifar10_train, batch_size=train_batch, shuffle=True,
+ num_workers=num_workers, pin_memory=True)
+ test_iter = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False,
+ num_workers=num_workers, pin_memory=True)
+
+ return train_iter, test_iter
+
+
+def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='mse'):
+ best = 0
+ net = net.to(device)
+ print("training on ", device)
+ if losstype == 'mse':
+ loss = torch.nn.MSELoss()
+ else:
+ loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
+ losses = []
+
+ for epoch in range(num_epochs):
+ for param_group in optimizer.param_groups:
+ learning_rate = param_group['lr']
+
+ losss = []
+ train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
+ for X, y in train_iter:
+ optimizer.zero_grad()
+ # X = X.to(device)
+ # y = y.to(device)
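+            # NOTE: the two lines below overwrite the loaded batch with fixed dummy tensors
+            # of shape (step, batch, 2, H, W) and (batch, num_classes, H, W); this looks like
+            # a shape/pipeline smoke test. Restore the commented-out lines above to train on
+            # the real data from `train_iter`.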
+ X = torch.ones(6, 8, 2, 128, 128).to(device)
+ y = torch.ones(8, 13, 128, 128).to(device)
+ y_hat = net(X)
+ label = y
+ l = loss(y_hat, label)
+ losss.append(l.cpu().item())
+ l.backward()
+ optimizer.step()
+ train_l_sum += l.cpu().item()
+ # train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
+ n += y.shape[0]
+ batch_count += 1
+ scheduler.step()
+ test_acc = evaluate_accuracy(test_iter, net)
+ losses.append(np.mean(losss))
+ print('epoch %d, lr %.6f, loss %.6f, train acc %.6f, test acc %.6f, time %.1f sec'
+ % (epoch + 1, learning_rate, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
+
+ if test_acc > best:
+ best = test_acc
+ torch.save(net.state_dict(), './checkpoints/CIFAR10_VGG16.pth')
+
+
+def evaluate_accuracy(data_iter, net, device=None, only_onebatch=False):
+ if device is None and isinstance(net, torch.nn.Module):
+ device = list(net.parameters())[0].device
+ acc_sum, n = 0.0, 0
+ with torch.no_grad():
+ for X, y in data_iter:
+ net.eval()
+ acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
+ net.train()
+ n += y.shape[0]
+
+ if only_onebatch: break
+ return acc_sum / n
+
+
+if __name__ == '__main__':
+ setup_seed(42)
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+
+ batch_size = 8
+ step = 6
+    train_iter, test_iter, _, _ = load_nmnist_data(batch_size, step)
+ # train_iter, test_iter = get_cifar10_loader(batch_size)
+ print('dataloader finished')
+
+ lr, num_epochs = 0.01, 300
+ net = SegmentModel(output_size=(128, 128), node=BiasLIFNode, step=step)
+ optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=num_epochs)
+    train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='crossentropy')
+
+ # net.load_state_dict(torch.load("./CIFAR10_VGG16.pth", map_location=device))
+ net = net.to(device)
+ acc = evaluate_accuracy(test_iter, net, device)
+ print(acc)
diff --git a/vgg16.py b/vgg16.py
new file mode 100644
index 0000000..de15ae2
--- /dev/null
+++ b/vgg16.py
@@ -0,0 +1,80 @@
+from torch import nn
+from basic import *
+
+
+class VGG16(nn.Module):
+ def __init__(self, node=BiasLIFNode, step=6, **kwargs): # 1 3e38
+ super(VGG16, self).__init__()
+ self.step = step
+
+ self.downsample_2x = nn.Sequential(
+ LayerWiseConvModule(2, 64, 3, 1, 1, node=BiasLIFNode, step=self.step),
+ TEP(step=self.step, channel=64, device=None, dtype=None),
+ LayerWiseConvModule(64, 64, 3, 1, 1, node=BiasLIFNode, step=self.step),
+ TEP(step=self.step, channel=64, device=None, dtype=None),
+ nn.MaxPool2d(2, 2)
+ )
+
+ self.downsample_4x = nn.Sequential(
+ LayerWiseConvModule(64, 128, 3, 1, 1, node=BiasLIFNode, step=self.step),
+ TEP(step=self.step, channel=128, device=None, dtype=None),
+ LayerWiseConvModule(128, 128, 3, 1, 1, node=BiasLIFNode, step=self.step),
+ TEP(step=self.step, channel=128, device=None, dtype=None),
+ nn.MaxPool2d(2, 2)
+ )
+
+ self.downsample_8x = nn.Sequential(
+ LayerWiseConvModule(128, 256, 3, 1, 1, node=BiasLIFNode, step=self.step),
+ TEP(step=self.step, channel=256, device=None, dtype=None),
+ LayerWiseConvModule(256, 256, 3, 1, 1, node=BiasLIFNode, step=self.step),
+ TEP(step=self.step, channel=256, device=None, dtype=None),
+ nn.MaxPool2d(2, 2)
+ )
+
+ self.downsample_16x = nn.Sequential(
+ LayerWiseConvModule(256, 512, 3, 1, 1, node=BiasLIFNode, step=self.step),
+ TEP(step=self.step, channel=512, device=None, dtype=None),
+ LayerWiseConvModule(512, 512, 3, 1, 1, node=BiasLIFNode, step=self.step),
+ TEP(step=self.step, channel=512, device=None, dtype=None),
+ nn.MaxPool2d(2, 2)
+ )
+
+ self.downsample_32x = nn.Sequential(
+ LayerWiseConvModule(512, 512, 3, 1, 1, node=BiasLIFNode, step=self.step),
+ TEP(step=self.step, channel=512, device=None, dtype=None),
+ LayerWiseConvModule(512, 512, 3, 1, 1, node=BiasLIFNode, step=self.step),
+ TEP(step=self.step, channel=512, device=None, dtype=None),
+ nn.MaxPool2d(2, 2)
+ )
+
+ self.fc = LayerWiseLinearModule(512, 10, bias=True, node=BiasLIFNode, step=self.step)
+ # self.node = partial(node, **kwargs)()
+
+ def forward(self, input):
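+        # The backbone is used as a feature-pyramid encoder rather than a classifier: it
+        # returns five feature maps at 1/2, 1/4, 1/8, 1/16 and 1/32 of the input resolution,
+        # which the FPN decoder consumes as shortcuts. The fully connected head (`self.fc`)
+        # is defined but not used here.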
+ self.reset()
+ # input = input.permute(1, 0, 2, 3, 4)
+ input = rearrange(input, 't b c w h -> (t b) c w h')
+
+ # embedding
+ downsample_2x = self.downsample_2x(input)
+ downsample_4x = self.downsample_4x(downsample_2x)
+ downsample_8x = self.downsample_8x(downsample_4x)
+ downsample_16x = self.downsample_16x(downsample_8x)
+ downsample_32x = self.downsample_32x(downsample_16x)
+
+ shortcuts = [downsample_2x, downsample_4x, downsample_8x, downsample_16x, downsample_32x]
+
+ # x = downsample_32x.view(downsample_32x.shape[0], -1)
+ # output = self.fc(x)
+ # outputs = rearrange(output, '(t b) c -> t b c', t=self.step)
+
+ return shortcuts # sum(outputs) / len(outputs)
+
+ def reset(self):
+ """
+ 重置所有神经元的膜电位
+ :return:
+ """
+ for mod in self.modules():
+ if hasattr(mod, 'n_reset'):
+ mod.n_reset()