From fb73c859a40372b144d178a37bb8f78c1c96c060 Mon Sep 17 00:00:00 2001
From: wyh <318450744@qq.com>
Date: Fri, 15 Dec 2023 12:09:15 +0800
Subject: [PATCH] Initial commit

---
 .idea/.gitignore                               |   8 +
 .idea/deployment.xml                           |  22 ++
 .idea/inspectionProfiles/Project_Default.xml   |  54 ++++
 .../inspectionProfiles/profiles_settings.xml   |   6 +
 .idea/misc.xml                                 |   4 +
 .idea/modules.xml                              |   8 +
 .idea/remote-mappings.xml                      |  10 +
 .idea/segment_snn.iml                          |   8 +
 basic.py                                       | 133 +++++++++
 dataset.py                                     | 259 ++++++++++++++++++
 fpn.py                                         | 106 +++++++
 load_data.py                                   | 202 ++++++++++++++
 model.py                                       |  37 +++
 train.py                                       | 129 +++++++++
 vgg16.py                                       |  80 ++++++
 15 files changed, 1066 insertions(+)
 create mode 100644 .idea/.gitignore
 create mode 100644 .idea/deployment.xml
 create mode 100644 .idea/inspectionProfiles/Project_Default.xml
 create mode 100644 .idea/inspectionProfiles/profiles_settings.xml
 create mode 100644 .idea/misc.xml
 create mode 100644 .idea/modules.xml
 create mode 100644 .idea/remote-mappings.xml
 create mode 100644 .idea/segment_snn.iml
 create mode 100644 basic.py
 create mode 100644 dataset.py
 create mode 100644 fpn.py
 create mode 100644 load_data.py
 create mode 100644 model.py
 create mode 100644 train.py
 create mode 100644 vgg16.py

diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..35410ca
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000..6cee02e
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,22 @@
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..9e40aad
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,54 @@
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..da16ce0
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..d108090
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
diff --git a/.idea/remote-mappings.xml b/.idea/remote-mappings.xml
new file mode 100644
index 0000000..5db6a93
--- /dev/null
+++ b/.idea/remote-mappings.xml
@@ -0,0 +1,10 @@
diff --git a/.idea/segment_snn.iml b/.idea/segment_snn.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/segment_snn.iml
@@ -0,0 +1,8 @@
diff --git a/basic.py b/basic.py
new file mode 100644
index 0000000..ed2b913
--- /dev/null
+++ b/basic.py
@@ -0,0 +1,133 @@
+from torch import nn
+from braincog.base.node.node import *
+from functools import partial
+
+
+class LayerWiseConvModule(nn.Module):
+    """
+    SNN convolution module
+    :param in_channels: number of input channels
+    :param out_channels: number of output channels
+    :param kernel_size: kernel size
+    :param stride: stride
+    :param padding: padding
+    :param bias: bias
+    :param node: neuron type
+    :param kwargs:
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=(1, 1),
+                 bias=False,
+                 node=BiasLIFNode,
+                 step=6,
+                 **kwargs):
+        super().__init__()
+
+        if node is None:
+            raise TypeError
+
+        self.groups = kwargs['groups'] if 'groups' in kwargs else 1
+        self.conv = nn.Conv2d(in_channels=in_channels * self.groups,
+                              out_channels=out_channels * self.groups,
+                              kernel_size=kernel_size,
+                              padding=padding,
+                              stride=stride,
+                              bias=bias)
+        self.gn = nn.GroupNorm(self.groups, out_channels * self.groups)
+        self.node = partial(node, **kwargs)()
+        self.step = step
+        self.activation = nn.Identity()
+
+    def forward(self, x):
+        x = rearrange(x, '(t b) c w h -> t b c w h', t=self.step)
+        outputs = []
+        for t in range(self.step):
+            outputs.append(self.gn(self.conv(x[t])))
+        outputs = torch.stack(outputs)  # t b c w h
+        outputs = rearrange(outputs, 't b c w h -> (t b) c w h')
+        outputs = self.node(outputs)
+        return outputs
+
+
+class TEP(nn.Module):
+    def __init__(self, step, channel, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super(TEP, self).__init__()
+        self.step = step
+        self.gn = nn.GroupNorm(channel, channel)
+
+    def forward(self, x):
+        x = rearrange(x, '(t b) c w h -> t b c w h', t=self.step)
+        fire_rate = torch.mean(x, dim=0)
+        fire_rate = self.gn(fire_rate) + 1
+
+        x = x * fire_rate
+        x = rearrange(x, 't b c w h -> (t b) c w h')
+
+        return x
+
+
+class LayerWiseLinearModule(nn.Module):
+    """
+    Linear module
+    :param in_features: input feature size
+    :param out_features: output feature size
+    :param bias: whether to add a bias, default ``True``
+    :param node: neuron type, default ``BiasLIFNode``
+    :param args:
+    :param kwargs:
+    """
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias=True,
+                 node=BiasLIFNode,
+                 step=6,
+                 spike=True,
+                 *args,
+                 **kwargs):
+        super().__init__()
+        if node is None:
+            raise TypeError
+
+        self.groups = kwargs['groups'] if 'groups' in kwargs else 1
+        if self.groups == 1:
+            self.fc = nn.Linear(in_features=in_features,
+                                out_features=out_features, bias=bias)
+        else:
+            self.fc = nn.ModuleList()
+            for i in range(self.groups):
+                self.fc.append(nn.Linear(
+                    in_features=in_features,
+                    out_features=out_features,
+                    bias=bias
+                ))
+        self.node = partial(node, **kwargs)()
+        self.step = step
+        self.spike = spike
+
+    def forward(self, x):
+        if self.groups == 1:  # (t b) c
+            x = rearrange(x, '(t b) c -> t b c', t=self.step)
+            outputs = []
+            for t in range(self.step):
+                outputs.append(self.fc(x[t]))
+            outputs = torch.stack(outputs)  # t b c
+            outputs = rearrange(outputs, 't b c -> (t b) c')
+        else:  # b (c t)
+            x = rearrange(x, 'b (c t) -> t b c', t=self.groups)
+            outputs = []
+            for i in range(self.groups):
+                outputs.append(self.fc[i](x[i]))
+            outputs = torch.stack(outputs)  # t b c
+            outputs = rearrange(outputs, 't b c -> b (c t)')
+        if self.spike:
+            return self.node(outputs)
+        else:
+            return outputs
\ No newline at end of file
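The modules in basic.py keep the time dimension folded into the batch dimension ((t*b, c, h, w)), unfold it internally per step, and TEP re-weights spikes by their mean firing rate over time. A minimal shape check, a sketch assuming the braincog package is installed (not part of the patch):

import torch
from braincog.base.node.node import BiasLIFNode
from basic import LayerWiseConvModule, TEP

step, batch = 6, 2
x = torch.rand(step * batch, 2, 32, 32)        # (t*b, c, h, w): time steps flattened into the batch dim
conv = LayerWiseConvModule(2, 64, kernel_size=(3, 3), node=BiasLIFNode, step=step)
tep = TEP(step=step, channel=64)
y = tep(conv(x))                               # spikes re-weighted by the mean firing rate over the step axis
print(y.shape)                                 # torch.Size([12, 64, 32, 32])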
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..f9b9e86
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,259 @@
+import os
+import pickle
+import torch
+import numpy as np
+from PIL import Image, ImageOps, ImageFilter
+from tqdm import trange
+from PIL import Image
+from torchvision import transforms
+import random
+from torchvision import transforms
+import torch.utils.data as data
+
+
+class SegmentationDataset(object):
+    """Segmentation Base Dataset"""
+
+    def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
+        super(SegmentationDataset, self).__init__()
+        self.root = root
self.transform = transform + self.split = split + self.mode = mode if mode is not None else split + self.base_size = base_size + self.crop_size = crop_size + + def _val_sync_transform(self, img, mask): + outsize = self.crop_size + short_size = outsize + w, h = img.size + if w > h: + oh = short_size + ow = int(1.0 * w * oh / h) + else: + ow = short_size + oh = int(1.0 * h * ow / w) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # center crop + w, h = img.size + x1 = int(round((w - outsize) / 2.)) + y1 = int(round((h - outsize) / 2.)) + img = img.crop((x1, y1, x1 + outsize, y1 + outsize)) + mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize)) + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform(self, img, mask): + # random mirror + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + crop_size = self.crop_size + # random scale (short edge) + short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) + w, h = img.size + if h > w: + ow = short_size + oh = int(1.0 * h * ow / w) + else: + oh = short_size + ow = int(1.0 * w * oh / h) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # pad crop + if short_size < crop_size: + padh = crop_size - oh if oh < crop_size else 0 + padw = crop_size - ow if ow < crop_size else 0 + img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) + mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) + # random crop crop_size + w, h = img.size + x1 = random.randint(0, w - crop_size) + y1 = random.randint(0, h - crop_size) + img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) + mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) + # gaussian blur as in PSP + if random.random() < 0.5: + img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _img_transform(self, img): + return np.array(img) + + def _mask_transform(self, mask): + return np.array(mask).astype('int32') + + @property + def num_class(self): + """Number of categories.""" + return self.NUM_CLASS + + @property + def pred_offset(self): + return 0 + +class COCOSegmentation(SegmentationDataset): + """COCO Semantic Segmentation Dataset for VOC Pre-training. + + Parameters + ---------- + root : string + Path to ADE20K folder. 
Default is './datasets/coco' + split: string + 'train', 'val' or 'test' + transform : callable, optional + A function that transforms the image + Examples + -------- + >>> from torchvision import transforms + >>> import torch.utils.data as data + >>> # Transforms for Normalization + >>> input_transform = transforms.Compose([ + >>> transforms.ToTensor(), + >>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)), + >>> ]) + >>> # Create Dataset + >>> trainset = COCOSegmentation(split='train', transform=input_transform) + >>> # Create Training Loader + >>> train_data = data.DataLoader( + >>> trainset, 4, shuffle=True, + >>> num_workers=4) + """ + CAT_LIST = [1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 72] + NUM_CLASS = len(CAT_LIST) + + def __init__(self, root='../datasets/coco', annotation_root='', split='train', mode=None, transform=None, stride=1, **kwargs): + super(COCOSegmentation, self).__init__(root, split, mode, transform, **kwargs) + # lazy import pycocotools + from pycocotools.coco import COCO + from pycocotools import mask + if split == 'train': + print('train set') + ann_file = os.path.join(annotation_root, 'annotations/instances_train2017.json') + ids_file = os.path.join(annotation_root, 'annotations/train_ids.mx') + self.root = os.path.join(annotation_root, 'train2017') + else: + print('val set') + ann_file = os.path.join(annotation_root, 'annotations/instances_val2017.json') + ids_file = os.path.join(annotation_root, 'annotations/val_ids.mx') + self.root = os.path.join(root, 'val2017') + self.coco = COCO(ann_file) + self.coco_mask = mask + if os.path.exists(ids_file): + with open(ids_file, 'rb') as f: + self.ids = pickle.load(f) + else: + ids = list(self.coco.imgs.keys()) + self.ids = self._preprocess(ids, ids_file) + self.transform = transform + self.stride = stride + + def __getitem__(self, index): + coco = self.coco + img_id = self.ids[index] + img_metadata = coco.loadImgs(img_id)[0] + path = img_metadata['file_name'] + img = Image.open(os.path.join(self.root, path)).convert('RGB') + cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id)) + mask = Image.fromarray(self._gen_seg_mask( + cocotarget, img_metadata['height'], img_metadata['width'])) + # synchrosized transform + if self.mode == 'train': + img, mask = self._sync_transform(img, mask) + elif self.mode == 'val': + img, mask = self._val_sync_transform(img, mask) + else: + assert self.mode == 'testval' + img, mask = self._img_transform(img), self._mask_transform(mask) + # general resize, normalize and toTensor + if self.transform is not None: + img = self.transform(img) + frames = torch.diff(self.generate_dynamic_translation(img), dim=0) + p_img = torch.zeros_like(frames) + n_img = torch.zeros_like(frames) + p_img[frames > 0] = frames[frames > 0] + n_img[frames < 0] = frames[frames < 0] + output = torch.concat([p_img, n_img], dim=1) + return output, mask + + def _mask_transform(self, mask): + return torch.LongTensor(np.array(mask).astype('int32')) + + def _gen_seg_mask(self, target, h, w): + mask = np.zeros((h, w), dtype=np.uint8) + coco_mask = self.coco_mask + for instance in target: + rle = coco_mask.frPyObjects(instance['segmentation'], h, w) + m = coco_mask.decode(rle) + cat = instance['category_id'] + if cat in self.CAT_LIST: + c = self.CAT_LIST.index(cat) + else: + continue + if len(m.shape) < 3: + mask[:, :] += (mask == 0) * (m * c) + else: + mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8) + return mask + + def generate_dynamic_translation(self, image): + tracex 
= self.stride * 2 * np.array([0, 2, 1, 0, 2, 1, 1, 2, 1])
+        tracey = self.stride * 2 * np.array([0, 1, 2, 1, 0, 2, 1, 1, 2])
+
+        num_frames = len(tracex)
+        channel = image.shape[0]
+        height = image.shape[1]
+        width = image.shape[2]
+
+        frames = torch.zeros((num_frames, channel, height, width))
+        for i in range(num_frames):
+            anchor_x = tracex[i]
+            anchor_y = tracey[i]
+            frames[i, :, anchor_y // 2: height - anchor_y // 2, anchor_x // 2: width - anchor_x // 2] = image[:, anchor_y:, anchor_x:]
+        return frames
+
+    def _preprocess(self, ids, ids_file):
+        print("Preprocessing mask, this will take a while. "
+              "But don't worry, it only runs once for each split.")
+        tbar = trange(len(ids))
+        new_ids = []
+        for i in tbar:
+            img_id = ids[i]
+            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
+            img_metadata = self.coco.loadImgs(img_id)[0]
+            mask = self._gen_seg_mask(cocotarget, img_metadata['height'], img_metadata['width'])
+            # more than 1k pixels
+            if (mask > 0).sum() > 1000:
+                new_ids.append(img_id)
+            tbar.set_description('Doing: {}/{}, got {} qualified images'.format(i, len(ids), len(new_ids)))
+        print('Found number of qualified images: ', len(new_ids))
+        with open(ids_file, 'wb') as f:
+            pickle.dump(new_ids, f)
+        return new_ids
+
+    @property
+    def classes(self):
+        """Category names."""
+        return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle',
+                'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+                'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
+                'tv')
+
+    def __len__(self):
+        return len(self.ids)  # one sample per selected COCO image id
+
+
+if __name__ == '__main__':
+    input_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
+    ])
+    # Create Dataset
+    trainset = COCOSegmentation(root='/kaggle/working/coco', annotation_root='coco_ann2017', split='val', transform=input_transform)
+    # Create Training Loader
+    train_data = data.DataLoader(
+        trainset, 4, shuffle=True,
+    )
\ No newline at end of file
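__getitem__ above turns a static RGB image into event-like input: it shifts the image along a short trace, differences consecutive frames, and splits positive and negative changes into separate channels. A standalone restatement for illustration only (translate_encode is a hypothetical helper, not part of the patch):

import torch

def translate_encode(image, stride=1):
    # same trace as generate_dynamic_translation, followed by the ON/OFF split from __getitem__
    tracex = stride * 2 * torch.tensor([0, 2, 1, 0, 2, 1, 1, 2, 1])
    tracey = stride * 2 * torch.tensor([0, 1, 2, 1, 0, 2, 1, 1, 2])
    c, h, w = image.shape
    frames = torch.zeros((len(tracex), c, h, w))
    for i, (ax, ay) in enumerate(zip(tracex.tolist(), tracey.tolist())):
        frames[i, :, ay // 2: h - ay // 2, ax // 2: w - ax // 2] = image[:, ay:, ax:]
    diffs = torch.diff(frames, dim=0)             # (8, C, H, W): frame-to-frame changes
    pos, neg = diffs.clamp(min=0), diffs.clamp(max=0)
    return torch.cat([pos, neg], dim=1)           # (8, 2C, H, W): ON/OFF-style channels

print(translate_encode(torch.rand(3, 64, 64)).shape)  # torch.Size([8, 6, 64, 64])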
diff --git a/fpn.py b/fpn.py
new file mode 100644
index 0000000..e11f2ac
--- /dev/null
+++ b/fpn.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from basic import *
+
+
+class FPNSegmentationHead(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 decode_intermediate_input=True,
+                 hidden_dim=256,
+                 shortcut_dims=[64, 128, 256, 512],
+                 node=BiasLIFNode,
+                 step=6,
+                 align_corners=True):
+        super().__init__()
+        self.align_corners = align_corners
+        self.node = node
+        self.step = step
+        self.decode_intermediate_input = decode_intermediate_input
+
+        # self.conv_in = ConvGN(in_dim, hidden_dim, 1)
+        self.conv_in = LayerWiseConvModule(in_dim, hidden_dim, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+
+        # self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3)
+        # self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3)
+        # self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3)
+        self.conv_16x = LayerWiseConvModule(hidden_dim, hidden_dim, 3, node=BiasLIFNode, step=self.step)
+        self.conv_8x = LayerWiseConvModule(hidden_dim, hidden_dim // 2, 3, node=BiasLIFNode, step=self.step)
+        self.conv_4x = LayerWiseConvModule(hidden_dim // 2, hidden_dim // 2, 3, node=BiasLIFNode, step=self.step)
+        self.conv_2x = LayerWiseConvModule(hidden_dim // 2, hidden_dim // 2, 3, node=BiasLIFNode, step=self.step)
+
+        # self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1)
+        # self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1)
+        # self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1)
+        self.in_tep = TEP(step=self.step, channel=hidden_dim, device=None, dtype=None)
+        self.adapter_16x = LayerWiseConvModule(shortcut_dims[-1], hidden_dim, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+        self.tep_16x = TEP(step=self.step, channel=hidden_dim, device=None, dtype=None)
+        self.adapter_8x = LayerWiseConvModule(shortcut_dims[-2], hidden_dim, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+        self.tep_8x = TEP(step=self.step, channel=hidden_dim // 2, device=None, dtype=None)
+        self.adapter_4x = LayerWiseConvModule(shortcut_dims[-3], hidden_dim // 2, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+        self.tep_4x = TEP(step=self.step, channel=hidden_dim // 2, device=None, dtype=None)
+        self.adapter_2x = LayerWiseConvModule(shortcut_dims[-4], hidden_dim // 2, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+        self.tep_2x = TEP(step=self.step, channel=hidden_dim // 2, device=None, dtype=None)
+
+        self.conv_out = LayerWiseConvModule(hidden_dim // 2, out_dim, 1, padding=(0, 0), node=BiasLIFNode, step=self.step)
+
+        self._init_weight()
+
+    def forward(self, inputs, shortcuts):
+        self.reset()
+
+        inputs = rearrange(inputs, 't b c w h -> (t b) c w h')
+        for i in range(len(shortcuts)):
+            shortcuts[i] = rearrange(shortcuts[i], 't b c w h -> (t b) c w h')
+
+        # if self.decode_intermediate_input:
+        #     x = torch.cat(inputs, dim=1)
+        # else:
+        #     x = inputs[-1]
+        if self.decode_intermediate_input:
+            x = self.in_tep(self.conv_in(inputs))
+        else:
+            x = inputs
+
+        x = F.interpolate(x,
+                          size=shortcuts[-1].size()[-2:],
+                          mode="bilinear",
+                          align_corners=self.align_corners)
+        x = self.tep_16x(self.conv_16x(self.adapter_16x(shortcuts[-1]) + x))
+
+        x = F.interpolate(x,
+                          size=shortcuts[-2].size()[-2:],
+                          mode="bilinear",
+                          align_corners=self.align_corners)
+        x = self.tep_8x(self.conv_8x(self.adapter_8x(shortcuts[-2]) + x))
+
+        x = F.interpolate(x,
+                          size=shortcuts[-3].size()[-2:],
+                          mode="bilinear",
+                          align_corners=self.align_corners)
+        x = self.tep_4x(self.conv_4x(self.adapter_4x(shortcuts[-3]) + x))
+
+        x = F.interpolate(x,
+                          size=shortcuts[-4].size()[-2:],
+                          mode="bilinear",
+                          align_corners=self.align_corners)
+        x = self.tep_2x(self.conv_2x(self.adapter_2x(shortcuts[-4]) + x))
+
+        x = self.conv_out(x)
+
+        return x
+
+    def _init_weight(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def reset(self):
+        """
+        Reset the membrane potential of all neurons.
+        :return:
+        """
+        for mod in self.modules():
+            if hasattr(mod, 'n_reset'):
+                mod.n_reset()
diff --git a/load_data.py b/load_data.py
new file mode 100644
index 0000000..d94fd0b
--- /dev/null
+++ b/load_data.py
@@ -0,0 +1,202 @@
+import numpy as np
+from torchvision import datasets, transforms
+import torch
+from torch.utils.data import Dataset
+import tonic
+from tonic import DiskCachedDataset
+import torch.nn.functional as F
+import os
+
+MNIST_MEAN = 0.1307
+MNIST_STD = 0.3081
+CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
+CIFAR10_STD_DEV = (0.2023, 0.1994, 0.2010)
+cifar100_mean = [0.5071, 0.4865, 0.4409]
+cifar100_std = [0.2673, 0.2563, 0.2761]
+DVSCIFAR10_MEAN_16 = [0.3290, 0.4507]
+DVSCIFAR10_STD_16 = [1.8398, 1.6549]
+
+DATA_DIR = '/data/datasets'
+
+
+class CustomDataset(Dataset):
+    """An abstract Dataset class wrapped around Pytorch Dataset class.
+ """ + + def __init__(self, dataset, indices): + self.dataset = dataset + self.indices = [int(i) for i in indices] + + def __len__(self): + return len(self.indices) + + def __getitem__(self, item): + x, y = self.dataset[self.indices[item]] + return x, y + + +def load_static_data(data_root, batch_size, dataset): + if dataset == 'cifar10': + transform_train = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV)]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV)]) + + train_data = datasets.CIFAR10(data_root, train=True, transform=transform_train, download=True) + test_data = datasets.CIFAR10(data_root, train=False, transform=transform_test, download=True) + + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=batch_size, + shuffle=True + ) + test_loader = torch.utils.data.DataLoader( + test_data, + batch_size=batch_size, + ) + elif dataset == 'MNIST': + transform_train = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(MNIST_MEAN, MNIST_STD)]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(MNIST_MEAN, MNIST_STD)]) + + train_data = datasets.MNIST(data_root, train=True, transform=transform_train, download=True) + test_data = datasets.MNIST(data_root, train=False, transform=transform_test, download=True) + + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=batch_size, + shuffle=True + ) + test_loader = torch.utils.data.DataLoader( + test_data, + batch_size=batch_size, + ) + elif dataset == 'FashionMNIST': + transform_train = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(MNIST_MEAN, MNIST_STD)]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(MNIST_MEAN, MNIST_STD)]) + + train_data = datasets.FashionMNIST(data_root, train=True, transform=transform_train, download=True) + test_data = datasets.FashionMNIST(data_root, train=False, transform=transform_test, download=True) + + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=batch_size, + shuffle=True + ) + test_loader = torch.utils.data.DataLoader( + test_data, + batch_size=batch_size, + ) + + return train_data, test_data, train_loader, test_loader + + +def load_dvs10_data(batch_size, step, **kwargs): + size = kwargs['size'] if 'size' in kwargs else 48 + sensor_size = tonic.datasets.CIFAR10DVS.sensor_size + train_transform = transforms.Compose([ + # tonic.transforms.Denoise(filter_time=10000), + # tonic.transforms.DropEvent(p=0.1), + tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ]) + test_transform = transforms.Compose([ + # tonic.transforms.Denoise(filter_time=10000), + tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ]) + train_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform) + test_dataset = tonic.datasets.CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=test_transform) + + train_transform = transforms.Compose([ + lambda x: torch.tensor(x, dtype=torch.float), + lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True), + ]) + test_transform = transforms.Compose([ + lambda x: torch.tensor(x, dtype=torch.float), + lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True), + ]) + + train_dataset = DiskCachedDataset(train_dataset, + 
cache_path=f'./dataset/dvs_cifar10/train_cache_{step}',
+                                      transform=train_transform)
+    test_dataset = DiskCachedDataset(test_dataset,
+                                     cache_path=f'./dataset/dvs_cifar10/test_cache_{step}',
+                                     transform=test_transform)
+
+    print(train_dataset)
+
+    num_train = len(train_dataset)
+    num_per_cls = num_train // 10
+    indices_train, indices_test = [], []
+    portion = kwargs['portion'] if 'portion' in kwargs else .9
+    for i in range(10):
+        indices_train.extend(
+            list(range(i * num_per_cls, round(i * num_per_cls + num_per_cls * portion))))
+        indices_test.extend(
+            list(range(round(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))
+    train_dataset = CustomDataset(train_dataset, np.array(indices_train))
+    test_dataset = CustomDataset(test_dataset, np.array(indices_test))
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=batch_size, shuffle=True,
+        pin_memory=True, drop_last=False, num_workers=1
+    )
+
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset, batch_size=batch_size,
+        pin_memory=True, drop_last=False, num_workers=1
+    )
+
+    return train_loader, test_loader, train_dataset, test_dataset
+
+
+def load_nmnist_data(batch_size, step, **kwargs):
+    size = kwargs['size'] if 'size' in kwargs else 28
+    sensor_size = tonic.datasets.NMNIST.sensor_size
+    train_transform = transforms.Compose([
+        # tonic.transforms.Denoise(filter_time=10000),
+        # tonic.transforms.DropEvent(p=0.1),
+        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
+    test_transform = transforms.Compose([
+        # tonic.transforms.Denoise(filter_time=10000),
+        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=step), ])
+    train_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=train_transform, train=True)
+    test_dataset = tonic.datasets.NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), transform=test_transform, train=False)
+
+    train_transform = transforms.Compose([
+        transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.float)),
+        transforms.Lambda(lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True)),
+    ])
+    test_transform = transforms.Compose([
+        transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.float)),
+        transforms.Lambda(lambda x: F.interpolate(x, size=[size, size], mode='bilinear', align_corners=True)),
+    ])
+
+    train_dataset = DiskCachedDataset(train_dataset,
+                                      cache_path=f'./dataset/NMNIST/train_cache_{step}',
+                                      transform=train_transform)
+    test_dataset = DiskCachedDataset(test_dataset,
+                                     cache_path=f'./dataset/NMNIST/test_cache_{step}',
+                                     transform=test_transform)
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=batch_size, shuffle=True,
+        pin_memory=True, drop_last=False, num_workers=0
+    )
+
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset, batch_size=batch_size,
+        pin_memory=True, drop_last=False, num_workers=0
+    )
+
+    return train_loader, test_loader, train_dataset, test_dataset
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..6d06a37
--- /dev/null
+++ b/model.py
@@ -0,0 +1,37 @@
+import torch
+
+from fpn import FPNSegmentationHead
+from vgg16 import VGG16
+from torch import nn
+from braincog.base.node.node import *
+
+
+class SegmentModel(nn.Module):
+    def __init__(self, output_size, node=BiasLIFNode, step=6):
+        super(SegmentModel, self).__init__()
+        self.output_size = output_size
+        self.node = node
+        self.step = step
+        self.encoder = VGG16()
+        self.decoder = FPNSegmentationHead(512, 13,
+                                           decode_intermediate_input=True,
shortcut_dims=[64, 128, 256, 512], + node=node, + step=step, + align_corners=True) + + def forward(self, x: torch.Tensor = None): + """ + x -> t b c w h + """ + embs = self.encoder(x) + for i in range(len(embs)): + embs[i] = rearrange(embs[i], '(t b) c w h -> t b c w h', t=self.step) + logits = self.decoder(embs[-1], embs[0:-1]) + logits = rearrange(logits, '(t b) c w h -> t b c w h', t=self.step) + out_logits = torch.mean(logits, dim=0) + out_logits = F.interpolate(out_logits, + size=self.output_size, + mode="bilinear", + align_corners=True) + return out_logits diff --git a/train.py b/train.py new file mode 100644 index 0000000..5cc4421 --- /dev/null +++ b/train.py @@ -0,0 +1,129 @@ +import sys + +import torch + +sys.path.append('../../..') +import torchvision.transforms as transforms +import torchvision.datasets as datasets +from torchvision.transforms import * +import time +from braincog.utils import setup_seed +from dataset import SegmentationDataset +from braincog.base.node.node import * +from model import SegmentModel + +device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu') +DATA_DIR = '/data/datasets' + + +def get_cifar10_loader(batch_size, train_batch=None, num_workers=4, conversion=False, distributed=False): + normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + # transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + # CIFAR10Policy(), + # transforms.ToTensor(), + # Cutout(n_holes=1, length=16), + # normalize]) + transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize]) + transform_test = transforms.Compose([transforms.ToTensor(), normalize]) + train_batch = batch_size if train_batch is None else train_batch + cifar10_train = datasets.CIFAR10(root=DATA_DIR, train=True, download=False, + transform=transform_test if conversion else transform_train) + cifar10_test = datasets.CIFAR10(root=DATA_DIR, train=False, download=False, transform=transform_test) + + if distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(cifar10_train) + val_sampler = torch.utils.data.distributed.DistributedSampler(cifar10_test, shuffle=False, drop_last=True) + train_iter = torch.utils.data.DataLoader(cifar10_train, batch_size=train_batch, shuffle=False, + num_workers=num_workers, pin_memory=True, sampler=train_sampler) + test_iter = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False, + num_workers=num_workers, pin_memory=True, sampler=val_sampler) + else: + train_iter = torch.utils.data.DataLoader(cifar10_train, batch_size=train_batch, shuffle=True, + num_workers=num_workers, pin_memory=True) + test_iter = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False, + num_workers=num_workers, pin_memory=True) + + return train_iter, test_iter + + +def train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='mse'): + best = 0 + net = net.to(device) + print("training on ", device) + if losstype == 'mse': + loss = torch.nn.MSELoss() + else: + loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1) + losses = [] + + for epoch in range(num_epochs): + for param_group in optimizer.param_groups: + learning_rate = param_group['lr'] + + losss = [] + train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time() + for X, y in train_iter: + optimizer.zero_grad() + # X = X.to(device) + # y = y.to(device) + X = 
torch.ones(6, 8, 2, 128, 128).to(device)  # hard-coded debug batch of shape (t, b, c, h, w) overriding the loader output
+            y = torch.ones(8, 13, 128, 128).to(device)  # matching debug target
+            y_hat = net(X)
+            label = y
+            l = loss(y_hat, label)
+            losss.append(l.cpu().item())
+            l.backward()
+            optimizer.step()
+            train_l_sum += l.cpu().item()
+            # train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
+            n += y.shape[0]
+            batch_count += 1
+        scheduler.step()
+        test_acc = evaluate_accuracy(test_iter, net)
+        losses.append(np.mean(losss))
+        print('epoch %d, lr %.6f, loss %.6f, train acc %.6f, test acc %.6f, time %.1f sec'
+              % (epoch + 1, learning_rate, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
+
+        if test_acc > best:
+            best = test_acc
+            torch.save(net.state_dict(), './checkpoints/CIFAR10_VGG16.pth')
+
+
+def evaluate_accuracy(data_iter, net, device=None, only_onebatch=False):
+    if device is None and isinstance(net, torch.nn.Module):
+        device = list(net.parameters())[0].device
+    acc_sum, n = 0.0, 0
+    with torch.no_grad():
+        for X, y in data_iter:
+            net.eval()
+            acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
+            net.train()
+            n += y.shape[0]
+
+            if only_onebatch: break
+    return acc_sum / n
+
+
+if __name__ == '__main__':
+    setup_seed(42)
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+
+    batch_size = 8
+    step = 6
+    from load_data import load_nmnist_data  # the N-MNIST loader is defined in load_data.py
+    train_iter, test_iter, _, _ = load_nmnist_data(batch_size, step)
+    # train_iter, test_iter = get_cifar10_loader(batch_size)
+    print('dataloader finished')
+
+    lr, num_epochs = 0.01, 300
+    net = SegmentModel(output_size=(128, 128), node=BiasLIFNode, step=step)
+    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=num_epochs)
+    train(net, train_iter, test_iter, optimizer, scheduler, device, num_epochs, losstype='crossentropy')
+
+    # net.load_state_dict(torch.load("./CIFAR10_VGG16.pth", map_location=device))
+    net = net.to(device)
+    acc = evaluate_accuracy(test_iter, net, device)
+    print(acc)
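evaluate_accuracy above compares an argmax over dim 1 directly against the target, which matches image classification rather than dense prediction. For segmentation logits of shape (B, C, H, W) and integer labels of shape (B, H, W), a per-pixel accuracy could look like this (a sketch, not part of the patch):

def pixel_accuracy(logits, target):
    """logits: (B, C, H, W) raw scores; target: (B, H, W) integer class ids."""
    pred = logits.argmax(dim=1)                  # (B, H, W) predicted class per pixel
    return (pred == target).float().mean().item()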
diff --git a/vgg16.py b/vgg16.py
new file mode 100644
index 0000000..de15ae2
--- /dev/null
+++ b/vgg16.py
@@ -0,0 +1,80 @@
+from torch import nn
+from basic import *
+
+
+class VGG16(nn.Module):
+    def __init__(self, node=BiasLIFNode, step=6, **kwargs):  # 1 3e38
+        super(VGG16, self).__init__()
+        self.step = step
+
+        self.downsample_2x = nn.Sequential(
+            LayerWiseConvModule(2, 64, 3, 1, 1, node=BiasLIFNode, step=self.step),
+            TEP(step=self.step, channel=64, device=None, dtype=None),
+            LayerWiseConvModule(64, 64, 3, 1, 1, node=BiasLIFNode, step=self.step),
+            TEP(step=self.step, channel=64, device=None, dtype=None),
+            nn.MaxPool2d(2, 2)
+        )
+
+        self.downsample_4x = nn.Sequential(
+            LayerWiseConvModule(64, 128, 3, 1, 1, node=BiasLIFNode, step=self.step),
+            TEP(step=self.step, channel=128, device=None, dtype=None),
+            LayerWiseConvModule(128, 128, 3, 1, 1, node=BiasLIFNode, step=self.step),
+            TEP(step=self.step, channel=128, device=None, dtype=None),
+            nn.MaxPool2d(2, 2)
+        )
+
+        self.downsample_8x = nn.Sequential(
+            LayerWiseConvModule(128, 256, 3, 1, 1, node=BiasLIFNode, step=self.step),
+            TEP(step=self.step, channel=256, device=None, dtype=None),
+            LayerWiseConvModule(256, 256, 3, 1, 1, node=BiasLIFNode, step=self.step),
+            TEP(step=self.step, channel=256, device=None, dtype=None),
+            nn.MaxPool2d(2, 2)
+        )
+
+        self.downsample_16x = nn.Sequential(
+            LayerWiseConvModule(256, 512, 3, 1, 1, node=BiasLIFNode, step=self.step),
+            TEP(step=self.step, channel=512, device=None, dtype=None),
+            LayerWiseConvModule(512, 512, 3, 1, 1, node=BiasLIFNode, step=self.step),
+            TEP(step=self.step, channel=512, device=None, dtype=None),
+            nn.MaxPool2d(2, 2)
+        )
+
+        self.downsample_32x = nn.Sequential(
+            LayerWiseConvModule(512, 512, 3, 1, 1, node=BiasLIFNode, step=self.step),
+            TEP(step=self.step, channel=512, device=None, dtype=None),
+            LayerWiseConvModule(512, 512, 3, 1, 1, node=BiasLIFNode, step=self.step),
+            TEP(step=self.step, channel=512, device=None, dtype=None),
+            nn.MaxPool2d(2, 2)
+        )
+
+        self.fc = LayerWiseLinearModule(512, 10, bias=True, node=BiasLIFNode, step=self.step)
+        # self.node = partial(node, **kwargs)()
+
+    def forward(self, input):
+        self.reset()
+        # input = input.permute(1, 0, 2, 3, 4)
+        input = rearrange(input, 't b c w h -> (t b) c w h')
+
+        # embedding
+        downsample_2x = self.downsample_2x(input)
+        downsample_4x = self.downsample_4x(downsample_2x)
+        downsample_8x = self.downsample_8x(downsample_4x)
+        downsample_16x = self.downsample_16x(downsample_8x)
+        downsample_32x = self.downsample_32x(downsample_16x)
+
+        shortcuts = [downsample_2x, downsample_4x, downsample_8x, downsample_16x, downsample_32x]
+
+        # x = downsample_32x.view(downsample_32x.shape[0], -1)
+        # output = self.fc(x)
+        # outputs = rearrange(output, '(t b) c -> t b c', t=self.step)
+
+        return shortcuts  # sum(outputs) / len(outputs)
+
+    def reset(self):
+        """
+        Reset the membrane potential of all neurons.
+        :return:
+        """
+        for mod in self.modules():
+            if hasattr(mod, 'n_reset'):
+                mod.n_reset()
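Putting the pieces together, a quick end-to-end shape check of the model defined in this patch (a sketch; it assumes the braincog package is installed and that basic.py, vgg16.py, fpn.py and model.py are importable from the working directory):

import torch
from braincog.base.node.node import BiasLIFNode
from model import SegmentModel

step, batch = 6, 2
net = SegmentModel(output_size=(128, 128), node=BiasLIFNode, step=step)
x = torch.rand(step, batch, 2, 128, 128)   # t, b, c, h, w with the two ON/OFF input channels VGG16 expects
with torch.no_grad():
    out = net(x)
print(out.shape)                           # expected: torch.Size([2, 13, 128, 128])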