diff --git a/correct_CNN_GRU.ipynb b/correct_CNN_GRU.ipynb new file mode 100644 index 0000000..be6ae43 --- /dev/null +++ b/correct_CNN_GRU.ipynb @@ -0,0 +1,1118 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 解压文件\n", + "# 运行先 pip install mne" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:26.479920Z", + "iopub.status.busy": "2021-10-31T13:40:26.479534Z", + "iopub.status.idle": "2021-10-31T13:40:26.483138Z", + "shell.execute_reply": "2021-10-31T13:40:26.482344Z", + "shell.execute_reply.started": "2021-10-31T13:40:26.479850Z" + } + }, + "outputs": [], + "source": [ + "# !unzip \"data/data112870/BCICIV_2a_mat.zip\" -d data/" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:26.484620Z", + "iopub.status.busy": "2021-10-31T13:40:26.484256Z", + "iopub.status.idle": "2021-10-31T13:40:27.261856Z", + "shell.execute_reply": "2021-10-31T13:40:27.261104Z", + "shell.execute_reply.started": "2021-10-31T13:40:26.484517Z" + } + }, + "outputs": [], + "source": [ + "import mne\n", + "import scipy.io as scio\n", + "import numpy as np\n", + "\n", + "def read_data(data_path,i):\n", + " data_path = data_path\n", + " DATA = scio.loadmat(data_path)\n", + " DATA = DATA[\"data\"][0]\n", + " DATA = DATA[i] #主程序倒叙读6个\n", + " EEG_DATA = DATA[\"X\"][0][0].transpose(1,0)\n", + " EEG_label = DATA['y'][0][0][:,0]\n", + "\n", + " ch_names = ['Fz','Fp1','Fp2','AF3','AF4','AF7','AF8','C3','POz','Cz','PO3','C4','PO4','PO5','PO6','PO7','PO8','Oz','O1','Pz','P6','P7','EOG-left','EOG-central','EOG-right'] #Fz是对应,其余随便写的不影响\n", + " ch_types = ['eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg',\n", + " 'eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg',\n", + " 'eeg','eeg','eog','eog','eog']\n", + "\n", + " info = mne.create_info(ch_names = ch_names,\n", + " ch_types=ch_types,\n", + " sfreq=250)\n", + " info.set_montage('standard_1020')\n", + " raw = mne.io.RawArray(EEG_DATA,info)\n", + "\n", + " n_times = DATA[\"trial\"][0][0][:,0] #时间戳\n", + " event = np.zeros((4,12),int)\n", + " v,b,n,m=0,0,0,0\n", + " for i in range (0,n_times.shape[0]):\n", + " if EEG_label[i]==1:\n", + " event[0,v]=n_times[i]\n", + " v+=1\n", + " if EEG_label[i]==2:\n", + " event[1,b]=n_times[i]\n", + " b+=1\n", + " if EEG_label[i]==3:\n", + " event[2,n]=n_times[i]\n", + " n+=1\n", + " if EEG_label[i]==4:\n", + " event[3,m]=n_times[i]\n", + " m+=1\n", + "\n", + "\n", + " events = np.zeros((48,3),int)\n", + " j=0\n", + " for i in range(events.shape[0]):\n", + " if i<12:\n", + " events[i][0]=event[0][j]\n", + " events[i][2]=1\n", + " elif i<24:\n", + " events[i][0]=event[1][j]\n", + " events[i][2]=2\n", + " elif i<36:\n", + " events[i][0]=event[2][j]\n", + " events[i][2]=3\n", + " elif i<48:\n", + " events[i][0]=event[3][j]\n", + " events[i][2]=4\n", + " j+=1\n", + " if j>=12:\n", + " j=0\n", + " events = sorted(events, key = lambda events: events[0])\n", + " event_id = dict(lefthand=1,righthand=2,feet=3,tongue=4)\n", + " picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, ecg=False,\n", + " exclude='bads')\n", + " epochs = mne.Epochs(raw, events, event_id, tmin=3, tmax=6, proj=True,baseline=(None, None), picks=picks,preload=True)\n", + "\n", + "\n", + " data = epochs.get_data()\n", + " # data =np.array(data).reshape((48*22,751))\n", + " label = EEG_label.repeat(22)\n", + " label=label.reshape(48,22)\n", + " return data,label\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Trian Data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:27.262965Z", + "iopub.status.busy": "2021-10-31T13:40:27.262749Z", + "iopub.status.idle": "2021-10-31T13:40:31.787255Z", + "shell.execute_reply": "2021-10-31T13:40:31.786531Z", + "shell.execute_reply.started": "2021-10-31T13:40:27.262928Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n" + ] + } + ], + "source": [ + "train_path = r\"data/A01T.mat\"\n", + "train_data=np.zeros((6,48,22,751),dtype=np.float32)\n", + "train_label=np.zeros((6,48,22),dtype=np.int64)\n", + "for i in range(6):\n", + " data,label = read_data(train_path,-(i+1))\n", + " train_data[i]=data\n", + " label = np.array(label[:])\n", + " train_label[i]=label\n", + "train_data=train_data.reshape((288,22,751))\n", + "train_label=train_label.reshape((288,22))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test Data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:31.788415Z", + "iopub.status.busy": "2021-10-31T13:40:31.788185Z", + "iopub.status.idle": "2021-10-31T13:40:36.315715Z", + "shell.execute_reply": "2021-10-31T13:40:36.314988Z", + "shell.execute_reply.started": "2021-10-31T13:40:31.788373Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n" + ] + } + ], + "source": [ + "test_path = r\"data/A01E.mat\"\n", + "test_data=np.zeros((6,48,22,751),dtype=np.float32)\n", + "test_label=np.zeros((6,48,22),dtype=np.int64)\n", + "for i in range(6):\n", + " data,label = read_data(test_path,-(i+1))\n", + " test_data[i]=data\n", + " label = np.array(label[:])\n", + " test_label[i]=label\n", + "test_data=test_data.reshape((288,22,751))\n", + "test_label=test_label.reshape((288,22))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:36.316726Z", + "iopub.status.busy": "2021-10-31T13:40:36.316526Z", + "iopub.status.idle": "2021-10-31T13:40:37.506724Z", + "shell.execute_reply": "2021-10-31T13:40:37.505797Z", + "shell.execute_reply.started": "2021-10-31T13:40:36.316688Z" + } + }, + "outputs": [], + "source": [ + "import paddle\n", + "import paddle.nn.functional as F\n", + "import paddle.nn as nn\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:37.509529Z", + "iopub.status.busy": "2021-10-31T13:40:37.509287Z", + "iopub.status.idle": "2021-10-31T13:40:37.520149Z", + "shell.execute_reply": "2021-10-31T13:40:37.519484Z", + "shell.execute_reply.started": "2021-10-31T13:40:37.509483Z" + } + }, + "outputs": [], + "source": [ + "class CNN_GRU(nn.Layer):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv1D(1,32,9,1,padding='same',data_format='NCL')\n", + " self.padding1 = nn.MaxPool1D(kernel_size=2,stride=2,padding='valid')\n", + " self.conv2 = nn.Conv1D(32,32,9,1,padding='same',data_format='NCL')\n", + " self.padding2 = nn.MaxPool1D(kernel_size=2,stride=2,padding='valid')\n", + " self.conv3 = nn.Conv1D(32,32,9,1,padding='same',\n", + " data_format='NCL')\n", + " self.padding3 = nn.MaxPool1D(kernel_size=2,stride=2,padding='valid')\n", + " self.flatten1 = nn.Flatten()\n", + " self.flatten2 = nn.Flatten()\n", + " # self.gru = nn.GRU(1,64,1)\n", + " self.gru = nn.GRU(2976,64,1)\n", + " self.dense1 = nn.Linear(64,64)\n", + " self.dense2 = nn.Linear(64,4)\n", + " self.relu1 = nn.ReLU()\n", + " self.relu2 = nn.ReLU()\n", + " self.relu3 = nn.ReLU()\n", + " self.relu4 = nn.ReLU()\n", + " self.dropout = nn.Dropout(p=0.5)\n", + " def forward(self,x):\n", + " x=self.conv1(x)\n", + " x=self.relu1(x)\n", + " x=self.padding1(x)\n", + " x=self.conv2(x)\n", + " x=self.relu2(x)\n", + " x=self.padding2(x)\n", + " x=self.conv3(x)\n", + " x=self.relu3(x)\n", + " x=self.padding3(x)\n", + " x=self.flatten1(x)\n", + " # x = x.unsqueeze(-1)\n", + " x = x.unsqueeze(1)\n", + " x,h=self.gru(x)\n", + " x=self.dense1(x)\n", + " x=self.relu4(x)\n", + " x=self.dropout(x)\n", + " x=self.dense2(x)\n", + " x=self.flatten2(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:37.521519Z", + "iopub.status.busy": "2021-10-31T13:40:37.521316Z", + "iopub.status.idle": "2021-10-31T13:40:40.385538Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-------------------------------------------------------------------------------\n", + " Layer (type) Input Shape Output Shape Param # \n", + "===============================================================================\n", + " Conv1D-1 [[1, 1, 751]] [1, 32, 751] 320 \n", + " ReLU-1 [[1, 32, 751]] [1, 32, 751] 0 \n", + " MaxPool1D-1 [[1, 32, 751]] [1, 32, 375] 0 \n", + " Conv1D-2 [[1, 32, 375]] [1, 32, 375] 9,248 \n", + " ReLU-2 [[1, 32, 375]] [1, 32, 375] 0 \n", + " MaxPool1D-2 [[1, 32, 375]] [1, 32, 187] 0 \n", + " Conv1D-3 [[1, 32, 187]] [1, 32, 187] 9,248 \n", + " ReLU-3 [[1, 32, 187]] [1, 32, 187] 0 \n", + " MaxPool1D-3 [[1, 32, 187]] [1, 32, 93] 0 \n", + " Flatten-1 [[1, 32, 93]] [1, 2976] 0 \n", + " GRU-1 [[1, 1, 2976]] [[1, 1, 64], [1, 1, 64]] 584,064 \n", + " Linear-1 [[1, 1, 64]] [1, 1, 64] 4,160 \n", + " ReLU-4 [[1, 1, 64]] [1, 1, 64] 0 \n", + " Dropout-1 [[1, 1, 64]] [1, 1, 64] 0 \n", + " Linear-2 [[1, 1, 64]] [1, 1, 4] 260 \n", + " Flatten-2 [[1, 1, 4]] [1, 4] 0 \n", + "===============================================================================\n", + "Total params: 607,300\n", + "Trainable params: 607,300\n", + "Non-trainable params: 0\n", + "-------------------------------------------------------------------------------\n", + "Input size (MB): 0.00\n", + "Forward/backward pass size (MB): 0.83\n", + "Params size (MB): 2.32\n", + "Estimated Total Size (MB): 3.15\n", + "-------------------------------------------------------------------------------\n", + "\n", + "[1, 4]\n" + ] + } + ], + "source": [ + "input_spec = paddle.static.InputSpec(\n", + " shape=(-1,1,751),\n", + " dtype='float32',\n", + " name='x'\n", + ")\n", + "model = CNN_GRU()\n", + "\n", + "paddle.summary(model,input_spec)\n", + "out = model(paddle.randn((1,1,750)))\n", + "print(out.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:40.386705Z", + "iopub.status.busy": "2021-10-31T13:40:40.386425Z", + "iopub.status.idle": "2021-10-31T13:40:40.390461Z" + } + }, + "outputs": [], + "source": [ + "loss_function = paddle.nn.CrossEntropyLoss()\n", + "# lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(\n", + "# learning_rate=LEARNING_RATE,\n", + "# T_max=NUM_EPOCHS\n", + "# )\n", + "optimizer = paddle.optimizer.Adam(\n", + " learning_rate=1e-4,\n", + " parameters=model.parameters()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 8:2 T" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:40.391554Z", + "iopub.status.busy": "2021-10-31T13:40:40.391337Z", + "iopub.status.idle": "2021-10-31T13:40:40.395338Z" + } + }, + "outputs": [], + "source": [ + "# data = train_data\n", + "# label = train_label-1\n", + "# # print(data.shape,label.shape)\n", + "\n", + "# from sklearn.preprocessing import StandardScaler #标准化\n", + "# from sklearn.preprocessing import MinMaxScaler\n", + "# sk = StandardScaler()\n", + "# # data = sk.fit_transform(data)\n", + "\n", + "# from sklearn.model_selection import train_test_split\n", + "# # data=np.expand_dims(data,1)\n", + "# # label=np.expand_dims(label,1)\n", + "# train_data,test_data,train_label,test_label = train_test_split(data,label,train_size=0.8)\n", + "# print(train_data.shape,train_label.shape,test_data.shape,test_label.shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:40.396313Z", + "iopub.status.busy": "2021-10-31T13:40:40.396108Z", + "iopub.status.idle": "2021-10-31T13:40:40.399308Z" + } + }, + "outputs": [], + "source": [ + "# train_data=train_data.reshape((5060,751))\n", + "# test_data=test_data.reshape((1276,751))\n", + "# train_label=train_label.reshape((5060))\n", + "# test_label=test_label.reshape((1276))\n", + "# from sklearn.preprocessing import StandardScaler #标准化\n", + "# from sklearn.preprocessing import MinMaxScaler\n", + "# # sk = StandardScaler()\n", + "# sk = MinMaxScaler()\n", + "# train_data = sk.fit_transform(train_data)\n", + "# test_data = sk.fit_transform(test_data)\n", + "# train_data=np.expand_dims(train_data,1)\n", + "# test_data=np.expand_dims(test_data,1)\n", + "# train_label=np.expand_dims(train_label,1)\n", + "# test_label=np.expand_dims(test_label,1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 8:2 E" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:40.400221Z", + "iopub.status.busy": "2021-10-31T13:40:40.400000Z", + "iopub.status.idle": "2021-10-31T13:40:40.770716Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(288, 22, 751) (288, 22)\n", + "(230, 22, 751) (230, 22) (58, 22, 751) (58, 22)\n" + ] + } + ], + "source": [ + "data = test_data\n", + "label = test_label-1\n", + "print(data.shape,label.shape)\n", + "\n", + "from sklearn.preprocessing import StandardScaler #标准化\n", + "from sklearn.preprocessing import MinMaxScaler\n", + "sk = StandardScaler()\n", + "# data = sk.fit_transform(data)\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "# data=np.expand_dims(data,1)\n", + "# label=np.expand_dims(label,1)\n", + "train_data,test_data,train_label,test_label = train_test_split(data,label,train_size=0.8)\n", + "print(train_data.shape,train_label.shape,test_data.shape,test_label.shape)\n", + "train_data=train_data.reshape((5060,751))\n", + "test_data=test_data.reshape((1276,751))\n", + "train_label=train_label.reshape((5060))\n", + "test_label=test_label.reshape((1276))\n", + "from sklearn.preprocessing import StandardScaler #标准化\n", + "from sklearn.preprocessing import MinMaxScaler\n", + "# sk = StandardScaler()\n", + "sk = MinMaxScaler()\n", + "train_data = sk.fit_transform(train_data)\n", + "test_data = sk.fit_transform(test_data)\n", + "train_data=np.expand_dims(train_data,1)\n", + "test_data=np.expand_dims(test_data,1)\n", + "train_label=np.expand_dims(train_label,1)\n", + "test_label=np.expand_dims(test_label,1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:40.771963Z", + "iopub.status.busy": "2021-10-31T13:40:40.771701Z", + "iopub.status.idle": "2021-10-31T13:40:40.774979Z" + } + }, + "outputs": [], + "source": [ + "NUM_EPOCHS=100\n", + "TRAIN_BATCH_SIZE=28" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 五折交叉验证" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:40.776017Z", + "iopub.status.busy": "2021-10-31T13:40:40.775781Z", + "iopub.status.idle": "2021-10-31T13:40:40.780996Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from paddle.io import Dataset\n", + "\n", + "\n", + "class MIDataset(Dataset):\n", + " def __init__(self, data,label):\n", + " self.data = data\n", + " self.label = label\n", + "\n", + " def __getitem__(self, idx):\n", + " data = self.data[idx]\n", + " label = self.label[idx]\n", + " return data, label\n", + "\n", + " def __len__(self):\n", + " return len(self.label)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T13:40:40.781962Z", + "iopub.status.busy": "2021-10-31T13:40:40.781756Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The loss value printed in the log is the current step, and the metric is the average value of previous steps.\n", + "Epoch 1/100\n", + "step 20/180 [==>...........................] - loss: 1.3685 - acc: 0.2196 - ETA: 2s - 15ms/step" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", + " return (isinstance(seq, collections.Sequence) and\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 180/180 [==============================] - loss: 1.3689 - acc: 0.2867 - 8ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 1.3724 - acc: 0.2508 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 2/100\n", + "step 180/180 [==============================] - loss: 1.2532 - acc: 0.4012 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 1.4486 - acc: 0.2202 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 3/100\n", + "step 180/180 [==============================] - loss: 1.1388 - acc: 0.5438 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 1.3594 - acc: 0.3245 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 4/100\n", + "step 180/180 [==============================] - loss: 0.9651 - acc: 0.6917 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 1.8689 - acc: 0.2939 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 5/100\n", + "step 180/180 [==============================] - loss: 0.5273 - acc: 0.7786 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 1.9893 - acc: 0.2861 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 6/100\n", + "step 180/180 [==============================] - loss: 0.4574 - acc: 0.8623 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 2.2041 - acc: 0.2672 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 7/100\n", + "step 180/180 [==============================] - loss: 0.2863 - acc: 0.9022 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 2.3970 - acc: 0.2461 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 8/100\n", + "step 180/180 [==============================] - loss: 0.4334 - acc: 0.9313 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 2.9350 - acc: 0.2320 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 9/100\n", + "step 180/180 [==============================] - loss: 0.1537 - acc: 0.9571 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 2.5006 - acc: 0.2414 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 10/100\n", + "step 180/180 [==============================] - loss: 0.1345 - acc: 0.9667 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 3.3403 - acc: 0.2500 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 11/100\n", + "step 180/180 [==============================] - loss: 0.0554 - acc: 0.9772 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 2.6429 - acc: 0.2500 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 12/100\n", + "step 180/180 [==============================] - loss: 0.1005 - acc: 0.9796 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 3.2488 - acc: 0.2531 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 13/100\n", + "step 180/180 [==============================] - loss: 0.0705 - acc: 0.9879 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 4.0855 - acc: 0.2524 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 14/100\n", + "step 180/180 [==============================] - loss: 0.0748 - acc: 0.9901 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 4.3394 - acc: 0.2610 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 15/100\n", + "step 180/180 [==============================] - loss: 0.0311 - acc: 0.9917 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 4.1119 - acc: 0.2539 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 16/100\n", + "step 180/180 [==============================] - loss: 0.0730 - acc: 0.9940 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 3.7751 - acc: 0.2524 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 17/100\n", + "step 180/180 [==============================] - loss: 0.0077 - acc: 0.9952 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 4.8625 - acc: 0.2633 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 18/100\n", + "step 180/180 [==============================] - loss: 0.0357 - acc: 0.9948 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.4577 - acc: 0.2571 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 19/100\n", + "step 180/180 [==============================] - loss: 0.0235 - acc: 0.9970 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 3.3805 - acc: 0.2516 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 20/100\n", + "step 180/180 [==============================] - loss: 0.0179 - acc: 0.9960 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 6.6039 - acc: 0.2476 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 21/100\n", + "step 180/180 [==============================] - loss: 0.0194 - acc: 0.9972 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.7909 - acc: 0.2484 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 22/100\n", + "step 180/180 [==============================] - loss: 0.0086 - acc: 0.9976 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 4.0409 - acc: 0.2563 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 23/100\n", + "step 180/180 [==============================] - loss: 0.0059 - acc: 0.9976 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 3.6320 - acc: 0.2578 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 24/100\n", + "step 180/180 [==============================] - loss: 0.0043 - acc: 0.9988 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 4.2748 - acc: 0.2610 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 25/100\n", + "step 180/180 [==============================] - loss: 0.0013 - acc: 0.9986 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.3310 - acc: 0.2578 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 26/100\n", + "step 180/180 [==============================] - loss: 0.0154 - acc: 0.9990 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 6.0418 - acc: 0.2571 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 27/100\n", + "step 180/180 [==============================] - loss: 0.0032 - acc: 0.9990 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.3031 - acc: 0.2665 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 28/100\n", + "step 180/180 [==============================] - loss: 0.0125 - acc: 0.9988 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 4.9294 - acc: 0.2445 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 29/100\n", + "step 180/180 [==============================] - loss: 0.0201 - acc: 0.9986 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.8028 - acc: 0.2594 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 30/100\n", + "step 180/180 [==============================] - loss: 0.0025 - acc: 0.9992 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 3.8357 - acc: 0.2602 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 31/100\n", + "step 180/180 [==============================] - loss: 0.0039 - acc: 0.9980 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.4144 - acc: 0.2500 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 32/100\n", + "step 180/180 [==============================] - loss: 0.0035 - acc: 0.9996 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 6.5052 - acc: 0.2500 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 33/100\n", + "step 180/180 [==============================] - loss: 0.0017 - acc: 0.9984 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 4.8183 - acc: 0.2508 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 34/100\n", + "step 180/180 [==============================] - loss: 7.3235e-04 - acc: 0.9994 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 4.7631 - acc: 0.2563 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 35/100\n", + "step 180/180 [==============================] - loss: 0.0121 - acc: 0.9994 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.1270 - acc: 0.2547 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 36/100\n", + "step 180/180 [==============================] - loss: 7.2671e-04 - acc: 0.9984 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.6140 - acc: 0.2484 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 37/100\n", + "step 180/180 [==============================] - loss: 0.0114 - acc: 0.9998 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.3074 - acc: 0.2610 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 38/100\n", + "step 180/180 [==============================] - loss: 0.0034 - acc: 0.9996 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 9.1694 - acc: 0.2555 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 39/100\n", + "step 180/180 [==============================] - loss: 0.0017 - acc: 0.9996 - 8ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.5135 - acc: 0.2657 - 6ms/step \n", + "Eval samples: 1276\n", + "Epoch 40/100\n", + "step 180/180 [==============================] - loss: 0.0106 - acc: 0.9998 - 8ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.0482 - acc: 0.2649 - 5ms/step \n", + "Eval samples: 1276\n", + "Epoch 41/100\n", + "step 180/180 [==============================] - loss: 0.0183 - acc: 0.9996 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.5344 - acc: 0.2516 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 42/100\n", + "step 180/180 [==============================] - loss: 0.0055 - acc: 0.9980 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.9336 - acc: 0.2578 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 43/100\n", + "step 180/180 [==============================] - loss: 0.0024 - acc: 0.9998 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.1236 - acc: 0.2571 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 44/100\n", + "step 180/180 [==============================] - loss: 0.0020 - acc: 0.9992 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 6.4246 - acc: 0.2633 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 45/100\n", + "step 180/180 [==============================] - loss: 4.0084e-04 - acc: 0.9994 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.8445 - acc: 0.2586 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 46/100\n", + "step 180/180 [==============================] - loss: 9.7069e-04 - acc: 0.9996 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 6.2878 - acc: 0.2633 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 47/100\n", + "step 180/180 [==============================] - loss: 2.9551e-04 - acc: 0.9998 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 8.7915 - acc: 0.2539 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 48/100\n", + "step 180/180 [==============================] - loss: 2.7383e-04 - acc: 0.9998 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.0248 - acc: 0.2719 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 49/100\n", + "step 180/180 [==============================] - loss: 0.0045 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 6.7661 - acc: 0.2657 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 50/100\n", + "step 180/180 [==============================] - loss: 1.6638e-04 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 8.4058 - acc: 0.2618 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 51/100\n", + "step 180/180 [==============================] - loss: 2.6548e-04 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.8900 - acc: 0.2563 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 52/100\n", + "step 180/180 [==============================] - loss: 0.0011 - acc: 0.9990 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.7754 - acc: 0.2727 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 53/100\n", + "step 180/180 [==============================] - loss: 2.5093e-04 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 6.7492 - acc: 0.2602 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 54/100\n", + "step 180/180 [==============================] - loss: 5.5084e-04 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 8.5714 - acc: 0.2531 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 55/100\n", + "step 180/180 [==============================] - loss: 0.0224 - acc: 0.9968 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.9686 - acc: 0.2868 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 56/100\n", + "step 180/180 [==============================] - loss: 0.0046 - acc: 0.9982 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 8.2194 - acc: 0.2516 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 57/100\n", + "step 180/180 [==============================] - loss: 8.8604e-04 - acc: 0.9998 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.5505 - acc: 0.2625 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 58/100\n", + "step 180/180 [==============================] - loss: 0.0011 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 10.1895 - acc: 0.2633 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 59/100\n", + "step 180/180 [==============================] - loss: 0.0021 - acc: 0.9996 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.0318 - acc: 0.2602 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 60/100\n", + "step 180/180 [==============================] - loss: 2.6387e-04 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 8.7751 - acc: 0.2539 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 61/100\n", + "step 180/180 [==============================] - loss: 6.4277e-05 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.8332 - acc: 0.2704 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 62/100\n", + "step 180/180 [==============================] - loss: 9.8329e-05 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.5896 - acc: 0.2719 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 63/100\n", + "step 180/180 [==============================] - loss: 1.3659e-04 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 9.1747 - acc: 0.2696 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 64/100\n", + "step 180/180 [==============================] - loss: 3.7494e-04 - acc: 0.9998 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 9.4773 - acc: 0.2641 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 65/100\n", + "step 180/180 [==============================] - loss: 6.0358e-04 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.6718 - acc: 0.2680 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 66/100\n", + "step 180/180 [==============================] - loss: 0.0068 - acc: 0.9964 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 8.4776 - acc: 0.2884 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 67/100\n", + "step 180/180 [==============================] - loss: 2.9123e-05 - acc: 0.9998 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 5.0846 - acc: 0.2743 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 68/100\n", + "step 180/180 [==============================] - loss: 1.1887e-04 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.3466 - acc: 0.2696 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 69/100\n", + "step 180/180 [==============================] - loss: 4.5624e-04 - acc: 0.9998 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 6.5277 - acc: 0.2751 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 70/100\n", + "step 180/180 [==============================] - loss: 0.0017 - acc: 0.9994 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 10.1823 - acc: 0.2727 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 71/100\n", + "step 180/180 [==============================] - loss: 0.0030 - acc: 0.9986 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 7.4184 - acc: 0.2719 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 72/100\n", + "step 180/180 [==============================] - loss: 4.5291e-04 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 6.8726 - acc: 0.2735 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 73/100\n", + "step 180/180 [==============================] - loss: 2.2295e-05 - acc: 1.0000 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 4.9751 - acc: 0.2712 - 4ms/step \n", + "Eval samples: 1276\n", + "Epoch 74/100\n", + "step 50/180 [=======>......................] - loss: 6.0246e-05 - acc: 1.0000 - ETA: 0s - 7ms/step" + ] + } + ], + "source": [ + "train_dataset = MIDataset(train_data,train_label)\n", + "test_dataset = MIDataset(test_data,test_label)\n", + "train_loader = paddle.io.DataLoader(\n", + " train_dataset, \n", + " batch_size=28, \n", + " shuffle=True,\n", + " drop_last=True)\n", + "test_loader = paddle.io.DataLoader(\n", + " test_dataset, \n", + " batch_size=28,\n", + " shuffle=True\n", + ")\n", + "model = CNN_GRU()\n", + "loss_function = paddle.nn.CrossEntropyLoss()\n", + "optimizer = paddle.optimizer.Adam(\n", + "learning_rate=1e-4,\n", + "parameters=model.parameters())\n", + "model_handler = paddle.Model(model)\n", + "model_handler.prepare(optimizer=optimizer,\n", + " loss=loss_function,\n", + " metrics=paddle.metric.Accuracy())\n", + "model_handler.fit(train_loader,\n", + " test_loader,\n", + " epochs=100,\n", + " batch_size=28,\n", + " verbose=1)\n", + "result=(model_handler.evaluate(test_dataset, batch_size=test_data.shape[0]))\n", + "result" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "0e2cd77eede4fb534e37c318dcb84244d460c76d4a4178970183713b6b86de86" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "py35-paddle1.2.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/original_CNN_GRU.ipynb b/original_CNN_GRU.ipynb new file mode 100644 index 0000000..ea0e2a3 --- /dev/null +++ b/original_CNN_GRU.ipynb @@ -0,0 +1,718 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 运行先安MNE,pip install mne\n", + "# 解压文件" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:33.437387Z", + "iopub.status.busy": "2021-10-31T14:04:33.437174Z", + "iopub.status.idle": "2021-10-31T14:04:33.439881Z", + "shell.execute_reply": "2021-10-31T14:04:33.439341Z", + "shell.execute_reply.started": "2021-10-31T14:04:33.437340Z" + } + }, + "outputs": [], + "source": [ + "# !unzip \"data/data112870/BCICIV_2a_mat.zip\" -d data/" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:33.440707Z", + "iopub.status.busy": "2021-10-31T14:04:33.440507Z", + "iopub.status.idle": "2021-10-31T14:04:34.257824Z", + "shell.execute_reply": "2021-10-31T14:04:34.256948Z", + "shell.execute_reply.started": "2021-10-31T14:04:33.440665Z" + } + }, + "outputs": [], + "source": [ + "from read_data import read_data\r\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:34.259049Z", + "iopub.status.busy": "2021-10-31T14:04:34.258711Z", + "iopub.status.idle": "2021-10-31T14:04:34.262600Z", + "shell.execute_reply": "2021-10-31T14:04:34.261629Z", + "shell.execute_reply.started": "2021-10-31T14:04:34.258894Z" + } + }, + "outputs": [], + "source": [ + "train_path = r\"data/A01T.mat\"\n", + "test_path = r\"data/A01E.mat\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Trian Data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:34.264059Z", + "iopub.status.busy": "2021-10-31T14:04:34.263681Z", + "iopub.status.idle": "2021-10-31T14:04:38.791645Z", + "shell.execute_reply": "2021-10-31T14:04:38.790940Z", + "shell.execute_reply.started": "2021-10-31T14:04:34.263956Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n", + "Creating RawArray with float64 data, n_channels=25, n_times=96735\n", + " Range : 0 ... 96734 = 0.000 ... 386.936 secs\n", + "Ready.\n", + "Not setting metadata\n", + "Not setting metadata\n", + "48 matching events found\n", + "Setting baseline interval to [3.0, 6.0] sec\n", + "Applying baseline correction (mode: mean)\n", + "0 projection items activated\n", + "Loading data for 48 events and 751 original time points ...\n", + "0 bad epochs dropped\n" + ] + } + ], + "source": [ + "train_data=np.zeros((6,1056,751),dtype=np.float32)\n", + "train_label=np.zeros((6,1056),dtype=np.int64)\n", + "for i in range(6):\n", + " data,label = read_data(train_path,-(i+1))\n", + " train_data[i]=data\n", + " label = np.array(label[:])\n", + " train_label[i]=label\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test Data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:38.792783Z", + "iopub.status.busy": "2021-10-31T14:04:38.792562Z", + "iopub.status.idle": "2021-10-31T14:04:38.795508Z", + "shell.execute_reply": "2021-10-31T14:04:38.794844Z", + "shell.execute_reply.started": "2021-10-31T14:04:38.792741Z" + } + }, + "outputs": [], + "source": [ + "# test_data=np.zeros((6,1056,751),dtype=np.float32)\n", + "# test_label=np.zeros((6,1056),dtype=np.int64)\n", + "# for i in range(6):\n", + "# data,label = read_data(test_path,-(i+1))\n", + "# test_data[i]=data\n", + "# label = np.array(label[:])\n", + "# test_label[i]=label\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data_Reshape" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:38.797797Z", + "iopub.status.busy": "2021-10-31T14:04:38.797560Z", + "iopub.status.idle": "2021-10-31T14:04:38.805968Z", + "shell.execute_reply": "2021-10-31T14:04:38.804683Z", + "shell.execute_reply.started": "2021-10-31T14:04:38.797752Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-4.29707\n" + ] + } + ], + "source": [ + "train_data = train_data.reshape(6336,751)\n", + "train_label = train_label.reshape(6336,)\n", + "# test_data = test_data.reshape(6336,751)\n", + "# test_label = test_label.reshape(6336,)\n", + "print(train_data[0][0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 验证标签的对错" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:38.809554Z", + "iopub.status.busy": "2021-10-31T14:04:38.809343Z", + "iopub.status.idle": "2021-10-31T14:04:40.014017Z" + } + }, + "outputs": [], + "source": [ + "import paddle\n", + "import paddle.nn.functional as F\n", + "import paddle.nn as nn\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:40.015362Z", + "iopub.status.busy": "2021-10-31T14:04:40.015130Z", + "iopub.status.idle": "2021-10-31T14:04:40.026178Z" + } + }, + "outputs": [], + "source": [ + "class CNN_GRU(nn.Layer):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv1D(1,32,9,1,padding='same',data_format='NCL')\n", + " self.padding1 = nn.MaxPool1D(kernel_size=2,stride=2,padding='valid')\n", + " self.conv2 = nn.Conv1D(32,32,9,1,padding='same',data_format='NCL')\n", + " self.padding2 = nn.MaxPool1D(kernel_size=2,stride=2,padding='valid')\n", + " self.conv3 = nn.Conv1D(32,32,9,1,padding='same',\n", + " data_format='NCL')\n", + " self.padding3 = nn.MaxPool1D(kernel_size=2,stride=2,padding='valid')\n", + " self.flatten1 = nn.Flatten()\n", + " self.flatten2 = nn.Flatten()\n", + " # self.gru = nn.GRU(1,64,1)\n", + " self.gru = nn.GRU(2976,64,1)\n", + " self.dense1 = nn.Linear(64,64)\n", + " self.dense2 = nn.Linear(64,4)\n", + " self.relu1 = nn.ReLU()\n", + " self.relu2 = nn.ReLU()\n", + " self.relu3 = nn.ReLU()\n", + " self.relu4 = nn.ReLU()\n", + " self.dropout = nn.Dropout(p=0.5)\n", + " def forward(self,x):\n", + " x=self.conv1(x)\n", + " x=self.relu1(x)\n", + " x=self.padding1(x)\n", + " x=self.conv2(x)\n", + " x=self.relu2(x)\n", + " x=self.padding2(x)\n", + " x=self.conv3(x)\n", + " x=self.relu3(x)\n", + " x=self.padding3(x)\n", + " x=self.flatten1(x)\n", + " # x = x.unsqueeze(-1)\n", + " x = x.unsqueeze(1)\n", + " x,h=self.gru(x)\n", + " x=self.dense1(x)\n", + " x=self.relu4(x)\n", + " x=self.dropout(x)\n", + " x=self.dense2(x)\n", + " x=self.flatten2(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:40.027091Z", + "iopub.status.busy": "2021-10-31T14:04:40.026898Z", + "iopub.status.idle": "2021-10-31T14:04:42.900923Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-------------------------------------------------------------------------------\n", + " Layer (type) Input Shape Output Shape Param # \n", + "===============================================================================\n", + " Conv1D-1 [[1, 1, 751]] [1, 32, 751] 320 \n", + " ReLU-1 [[1, 32, 751]] [1, 32, 751] 0 \n", + " MaxPool1D-1 [[1, 32, 751]] [1, 32, 375] 0 \n", + " Conv1D-2 [[1, 32, 375]] [1, 32, 375] 9,248 \n", + " ReLU-2 [[1, 32, 375]] [1, 32, 375] 0 \n", + " MaxPool1D-2 [[1, 32, 375]] [1, 32, 187] 0 \n", + " Conv1D-3 [[1, 32, 187]] [1, 32, 187] 9,248 \n", + " ReLU-3 [[1, 32, 187]] [1, 32, 187] 0 \n", + " MaxPool1D-3 [[1, 32, 187]] [1, 32, 93] 0 \n", + " Flatten-1 [[1, 32, 93]] [1, 2976] 0 \n", + " GRU-1 [[1, 1, 2976]] [[1, 1, 64], [1, 1, 64]] 584,064 \n", + " Linear-1 [[1, 1, 64]] [1, 1, 64] 4,160 \n", + " ReLU-4 [[1, 1, 64]] [1, 1, 64] 0 \n", + " Dropout-1 [[1, 1, 64]] [1, 1, 64] 0 \n", + " Linear-2 [[1, 1, 64]] [1, 1, 4] 260 \n", + " Flatten-2 [[1, 1, 4]] [1, 4] 0 \n", + "===============================================================================\n", + "Total params: 607,300\n", + "Trainable params: 607,300\n", + "Non-trainable params: 0\n", + "-------------------------------------------------------------------------------\n", + "Input size (MB): 0.00\n", + "Forward/backward pass size (MB): 0.83\n", + "Params size (MB): 2.32\n", + "Estimated Total Size (MB): 3.15\n", + "-------------------------------------------------------------------------------\n", + "\n", + "Tensor(shape=[1, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [[-0.36932158, -0.30767888, 1.12718070, 0.78158110]])\n" + ] + } + ], + "source": [ + "input_spec = paddle.static.InputSpec(\n", + " shape=(-1,1,751),\n", + " dtype='float32',\n", + " name='x'\n", + ")\n", + "model = CNN_GRU()\n", + "\n", + "paddle.summary(model,input_spec)\n", + "out = model(paddle.randn((1,1,750)))\n", + "print(out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:42.902024Z", + "iopub.status.busy": "2021-10-31T14:04:42.901792Z", + "iopub.status.idle": "2021-10-31T14:04:42.905556Z" + } + }, + "outputs": [], + "source": [ + "loss_function = paddle.nn.CrossEntropyLoss()\n", + "# lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(\n", + "# learning_rate=LEARNING_RATE,\n", + "# T_max=NUM_EPOCHS\n", + "# )\n", + "optimizer = paddle.optimizer.Adam(\n", + " learning_rate=1e-4,\n", + " parameters=model.parameters()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 8:2 T" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:42.906486Z", + "iopub.status.busy": "2021-10-31T14:04:42.906292Z", + "iopub.status.idle": "2021-10-31T14:04:43.315228Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(6336, 751) (6336,)\n" + ] + } + ], + "source": [ + "data = train_data\n", + "label = train_label-1\n", + "print(data.shape,label.shape)\n", + "\n", + "from sklearn.preprocessing import StandardScaler #标准化\n", + "from sklearn.preprocessing import MinMaxScaler\n", + "sk = StandardScaler()\n", + "data = sk.fit_transform(data)\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "data=np.expand_dims(data,1)\n", + "label=np.expand_dims(label,1)\n", + "# train_data,test_data,train_label,test_label = train_test_split(data,label,train_size=0.8)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 8:2 E" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:43.316343Z", + "iopub.status.busy": "2021-10-31T14:04:43.316132Z", + "iopub.status.idle": "2021-10-31T14:04:43.319119Z" + } + }, + "outputs": [], + "source": [ + "# data = test_data\n", + "# label = test_label-1\n", + "# print(data.shape,label.shape)\n", + "\n", + "# from sklearn.preprocessing import StandardScaler #标准化\n", + "# from sklearn.preprocessing import MinMaxScaler\n", + "# sk = StandardScaler()\n", + "# data = sk.fit_transform(data)\n", + "# from sklearn.model_selection import train_test_split\n", + "# label=label\n", + "# data=np.expand_dims(data,1)\n", + "# label=np.expand_dims(label,1)\n", + "## train_data,test_data,train_label,test_label = train_test_split(data,label,train_size=0.8)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E.T" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:43.320050Z", + "iopub.status.busy": "2021-10-31T14:04:43.319843Z", + "iopub.status.idle": "2021-10-31T14:04:43.323972Z" + } + }, + "outputs": [], + "source": [ + "# train_label=train_label-1\n", + "# test_label=test_label-1\n", + "# from sklearn.preprocessing import StandardScaler #标准化\n", + "# from sklearn.preprocessing import MinMaxScaler\n", + "# sk = StandardScaler()\n", + "# train_data = sk.fit_transform(train_data)\n", + "# test_data = sk.fit_transform(test_data)\n", + "# train_data=np.expand_dims(train_data,1)\n", + "# train_label=np.expand_dims(train_label,1)\n", + "# test_data=np.expand_dims(test_data,1)\n", + "# test_label=np.expand_dims(test_label,1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:43.324844Z", + "iopub.status.busy": "2021-10-31T14:04:43.324658Z", + "iopub.status.idle": "2021-10-31T14:04:43.328132Z" + } + }, + "outputs": [], + "source": [ + "NUM_EPOCHS=100\n", + "TRAIN_BATCH_SIZE=28\n", + "# print(train_label.shape,test_label.shape)\n", + "# print(train_data.shape,test_data.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 五折交叉验证" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:43.329003Z", + "iopub.status.busy": "2021-10-31T14:04:43.328800Z", + "iopub.status.idle": "2021-10-31T14:04:43.337231Z" + } + }, + "outputs": [], + "source": [ + "def k_fold_indices(num_samples, k=5):\n", + " indices = np.random.permutation(num_samples)\n", + " if k <= 1:\n", + " print(f\"Expect k>1, but received k={k}\")\n", + " n_samples_per_fold = num_samples // k\n", + "\n", + " anchors = [n_samples_per_fold for _ in range(k)]\n", + " for i in range(0, num_samples - k * n_samples_per_fold):\n", + " anchors[i] += 1\n", + " for i in range(1, len(anchors)):\n", + " anchors[i] = anchors[i - 1] + anchors[i]\n", + " assert anchors[-1] == num_samples\n", + " anchors.pop(-1)\n", + " assert len(anchors) == k - 1\n", + " \n", + " folds = np.split(indices, anchors)\n", + " train_indices = [np.empty(0, dtype=np.int64) for _ in range(k)]\n", + " valid_indices = []\n", + "\n", + " for i, fold in enumerate(folds):\n", + " valid_indices.append(fold)\n", + " for j in range(k):\n", + " if j == i:\n", + " continue\n", + " train_indices[j] = np.concatenate((train_indices[j], fold))\n", + "\n", + " return list(zip(train_indices, valid_indices))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:43.338052Z", + "iopub.status.busy": "2021-10-31T14:04:43.337872Z", + "iopub.status.idle": "2021-10-31T14:04:43.342351Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from paddle.io import Dataset\n", + "\n", + "\n", + "class MIDataset(Dataset):\n", + " def __init__(self, data,label):\n", + " self.data = data\n", + " self.label = label\n", + "\n", + " def __getitem__(self, idx):\n", + " data = self.data[idx]\n", + " label = self.label[idx]\n", + " return data, label\n", + "\n", + " def __len__(self):\n", + " return len(self.label)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-10-31T14:04:43.343163Z", + "iopub.status.busy": "2021-10-31T14:04:43.342981Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The loss value printed in the log is the current step, and the metric is the average value of previous steps.\n", + "Epoch 1/100\n", + "step 20/181 [==>...........................] - loss: 1.3826 - acc: 0.2786 - ETA: 2s - 15ms/step" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", + " return (isinstance(seq, collections.Sequence) and\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 181/181 [==============================] - loss: 1.2563 - acc: 0.3254 - 8ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 1.0529 - acc: 0.5055 - 4ms/step \n", + "Eval samples: 1268\n", + "Epoch 2/100\n", + "step 181/181 [==============================] - loss: 1.1867 - acc: 0.4773 - 7ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 1.1042 - acc: 0.6569 - 5ms/step \n", + "Eval samples: 1268\n", + "Epoch 3/100\n", + "step 181/181 [==============================] - loss: 0.8814 - acc: 0.6150 - 8ms/step \n", + "Eval begin...\n", + "step 46/46 [==============================] - loss: 0.8707 - acc: 0.8021 - 4ms/step \n", + "Eval samples: 1268\n", + "Epoch 4/100\n", + "step 150/181 [=======================>......] - loss: 0.5813 - acc: 0.7350 - ETA: 0s - 7ms/step" + ] + } + ], + "source": [ + "result=[]\n", + "for train_indices, valid_indices in k_fold_indices(data.shape[0], 5):\n", + " train_data = data[train_indices]\n", + " train_label = label[train_indices]\n", + " test_data = data[valid_indices]\n", + " test_label = label[valid_indices]\n", + " train_dataset = MIDataset(train_data,train_label)\n", + " test_dataset = MIDataset(test_data,test_label)\n", + " train_loader = paddle.io.DataLoader(\n", + " train_dataset, \n", + " batch_size=28, \n", + " shuffle=True,\n", + " drop_last=True)\n", + " test_loader = paddle.io.DataLoader(\n", + " test_dataset, \n", + " batch_size=28,\n", + " shuffle=True\n", + " )\n", + " model = CNN_GRU()\n", + " loss_function = paddle.nn.CrossEntropyLoss()\n", + " optimizer = paddle.optimizer.Adam(\n", + " learning_rate=1e-4,\n", + " parameters=model.parameters())\n", + " model_handler = paddle.Model(model)\n", + " model_handler.prepare(optimizer=optimizer,\n", + " loss=loss_function,\n", + " metrics=paddle.metric.Accuracy())\n", + " model_handler.fit(train_loader,\n", + " test_loader,\n", + " epochs=100,\n", + " batch_size=28,\n", + " verbose=1)\n", + " result.append(model_handler.evaluate(test_dataset, batch_size=test_data.shape[0]))\n", + "a=0\n", + "for i in range(5):\n", + " a=result[i]['acc']/5+a\n", + "a" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "0e2cd77eede4fb534e37c318dcb84244d460c76d4a4178970183713b6b86de86" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "py35-paddle1.2.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/read_data.py b/read_data.py new file mode 100644 index 0000000..afda2db --- /dev/null +++ b/read_data.py @@ -0,0 +1,70 @@ +import mne +import scipy.io as scio +import numpy as np + +def read_data(data_path,i): + data_path = data_path + DATA = scio.loadmat(data_path) + DATA = DATA["data"][0] + DATA = DATA[i] #第i次实验,从3到8 + EEG_DATA = DATA["X"][0][0].transpose(1,0) + EEG_label = DATA['y'][0][0][:,0] + + ch_names = ['Fz','Fp1','Fp2','AF3','AF4','AF7','AF8','C3','POz','Cz','PO3','C4','PO4','PO5','PO6','PO7','PO8','Oz','O1','Pz','P6','P7','EOG-left','EOG-central','EOG-right'] + ch_types = ['eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg', + 'eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg', + 'eeg','eeg','eog','eog','eog'] + + info = mne.create_info(ch_names = ch_names, + ch_types=ch_types, + sfreq=250) + info.set_montage('standard_1020') + raw = mne.io.RawArray(EEG_DATA,info) + + n_times = DATA["trial"][0][0][:,0] #时间戳 + event = np.zeros((4,12),int) + v,b,n,m=0,0,0,0 + for i in range (0,n_times.shape[0]): + if EEG_label[i]==1: + event[0,v]=n_times[i] + v+=1 + if EEG_label[i]==2: + event[1,b]=n_times[i] + b+=1 + if EEG_label[i]==3: + event[2,n]=n_times[i] + n+=1 + if EEG_label[i]==4: + event[3,m]=n_times[i] + m+=1 + + + events = np.zeros((48,3),int) + j=0 + for i in range(events.shape[0]): + if i<12: + events[i][0]=event[0][j] + events[i][2]=1 + elif i<24: + events[i][0]=event[1][j] + events[i][2]=2 + elif i<36: + events[i][0]=event[2][j] + events[i][2]=3 + elif i<48: + events[i][0]=event[3][j] + events[i][2]=4 + j+=1 + if j>=12: + j=0 + events = sorted(events, key = lambda events: events[0]) + event_id = dict(lefthand=1,righthand=2,feet=3,tongue=4) + picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, ecg=False, + exclude='bads') + epochs = mne.Epochs(raw, events, event_id, tmin=3, tmax=6, proj=True,baseline=(None, None), picks=picks,preload=True) + + + data = epochs.get_data() + data =np.array(data).reshape((48*22,751)) + label = EEG_label.repeat(22) + return data,label