update readme
YeexiaoZheng committed Jul 14, 2022
1 parent 9ffb49f commit b4427c9
Showing 6 changed files with 186 additions and 6 deletions.
177 changes: 176 additions & 1 deletion README.md
@@ -1,2 +1,177 @@
# Multimodal-Sentiment-Analysis
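Multimodal sentiment analysis: multiple fusion methods based on BERT+ResNet
Multimodal sentiment analysis: multiple fusion methods based on BERT+ResNet50; code for the fifth lab of the School of Data's Artificial Intelligence course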

This project is built on Hugging Face and torchvision and provides five fusion methods (2 Naive, 3 Attention); see the Models folder.

## Project Structure

```
|-- Multimodal-Sentiment-Analysis
    |-- Config.py
    |-- main.py
    |-- README.md
    |-- requirements.txt
    |-- Trainer.py
    |-- data
    |   |-- .DS_Store
    |   |-- test.json
    |   |-- test_without_label.txt
    |   |-- train.json
    |   |-- train.txt
    |   |-- data
    |-- Models
    |   |-- CMACModel.py
    |   |-- HSTECModel.py
    |   |-- NaiveCatModel.py
    |   |-- NaiveCombineModel.py
    |   |-- OTEModel.py
    |   |-- __init__.py
    |-- src
    |   |-- CrossModalityAttentionCombineModel.png
    |   |-- HiddenStateTransformerEncoderCombineModel.png
    |   |-- OutputTransformerEncoderModel.png
    |-- utils
        |-- common.py
        |-- DataProcess.py
        |-- __init__.py
        |-- APIs
        |   |-- APIDataset.py
        |   |-- APIDecode.py
        |   |-- APIEncode.py
        |   |-- APIMetric.py
        |   |-- __init__.py
```

## Requirements

chardet==4.0.0
numpy==1.22.2
Pillow==9.2.0
scikit_learn==1.1.1
torch==1.8.2
torchvision==0.9.2
tqdm==4.63.0
transformers==4.18.0

```shell
pip install -r requirements.txt
```

## Model

The two Naive methods are not illustrated here.

**CrossModalityAttentionCombine**

![CrossModalityAttentionCombineModel](src/CrossModalityAttentionCombineModel.png)



**HiddenStateTransformerEncoder**

![HiddenStateTransformerEncoderCombineModel](src/HiddenStateTransformerEncoderCombineModel.png)

**OutputTransformerEncoder**

![OutputTransformerEncoderModel](src/OutputTransformerEncoderModel.png)

## Train

Download the dataset and unzip it into the data folder. Dataset URL: (to be updated)

```shell
# Append --text_only or --img_only for a single-modality run.
python main.py --do_train --epoch 10 --text_pretrained_model roberta-base --fuse_model_type OTE
```

Options for fuse_model_type: CMAC, HSTEC, OTE, NaiveCat, NaiveCombine

For text_pretrained_model, choose any suitable model from Hugging Face.
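
As an illustration of how these flags could be validated and copied onto the config, here is a minimal sketch. It assumes an argparse parser similar to the one in main.py; the `FUSE_MODEL_TYPES` tuple, the `build_parser` and `apply_args` helpers, the stand-in `Config` class, and the `"text"`/`"img"` values for `only` are hypothetical and are not taken from the repository.

```python
import argparse

# Model names listed in this README; the tuple itself is an assumption of this sketch.
FUSE_MODEL_TYPES = ("CMAC", "HSTEC", "OTE", "NaiveCat", "NaiveCombine")


class Config:
    """Stand-in for the project's config class (hypothetical subset of fields)."""
    fuse_model_type = "NaiveCombine"
    bert_name = "roberta-base"
    epoch = 20
    only = None


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--do_train", action="store_true")
    parser.add_argument("--text_pretrained_model", default="roberta-base", type=str)
    # `choices` makes argparse reject unknown model names before any training starts.
    parser.add_argument("--fuse_model_type", default="OTE", choices=FUSE_MODEL_TYPES)
    parser.add_argument("--epoch", default=10, type=int)
    parser.add_argument("--text_only", action="store_true")
    parser.add_argument("--img_only", action="store_true")
    return parser


def apply_args(args, config):
    """Copy parsed command-line values onto the config object."""
    if args.text_only and args.img_only:
        raise ValueError("--text_only and --img_only are mutually exclusive")
    config.fuse_model_type = args.fuse_model_type
    config.bert_name = args.text_pretrained_model
    config.epoch = args.epoch
    # Hypothetical encoding of the single-modality switch.
    config.only = "text" if args.text_only else "img" if args.img_only else None
    return config


args = build_parser().parse_args(["--do_train", "--fuse_model_type", "CMAC", "--text_only"])
config = apply_args(args, Config())
print(config.fuse_model_type, config.only)
```

Restricting `--fuse_model_type` with `choices` is a design choice of this sketch: a misspelled model name fails at argument parsing rather than later during model construction.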

## Test

```shell
# Append --text_only or --img_only for a single-modality run.
python main.py --do_test --text_pretrained_model roberta-base --fuse_model_type OTE --load_model_path $your_model_path$
```

## Config

```python
import os


class config:
    # Root directory
    root_path = os.getcwd()
    data_dir = os.path.join(root_path, './data/data/')
    train_data_path = os.path.join(root_path, 'data/train.json')
    test_data_path = os.path.join(root_path, 'data/test.json')
    output_path = os.path.join(root_path, 'output')
    output_test_path = os.path.join(output_path, 'test.txt')
    load_model_path = None

    # General hyperparameters
    epoch = 20
    learning_rate = 3e-5
    weight_decay = 0
    num_labels = 3
    loss_weight = [1.68, 9.3, 3.36]

    # Fusion settings
    fuse_model_type = 'NaiveCombine'
    only = None
    middle_hidden_size = 64
    attention_nhead = 8
    attention_dropout = 0.4
    fuse_dropout = 0.5
    out_hidden_size = 128

    # BERT settings
    fixed_text_model_params = False
    bert_name = 'roberta-base'
    bert_learning_rate = 5e-6
    bert_dropout = 0.2

    # ResNet settings
    fixed_img_model_params = False
    image_size = 224
    fixed_image_model_params = True
    resnet_learning_rate = 5e-6
    resnet_dropout = 0.2
    img_hidden_seq = 64

    # Dataloader params
    checkout_params = {'batch_size': 4, 'shuffle': False}
    train_params = {'batch_size': 16, 'shuffle': True, 'num_workers': 2}
    val_params = {'batch_size': 16, 'shuffle': False, 'num_workers': 2}
    test_params = {'batch_size': 8, 'shuffle': False, 'num_workers': 2}
```
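
The `loss_weight` list holds one value per label (`num_labels = 3`), which suggests a class-weighted loss for imbalanced labels. The README does not say how these values were chosen. One common way to derive such weights is inverse class frequency, sketched below; the `inverse_frequency_weights` helper and the label counts are hypothetical and do not reproduce the project's data.

```python
from collections import Counter


def inverse_frequency_weights(labels, num_labels):
    """Return total / count for each class id in range(num_labels)."""
    counts = Counter(labels)
    total = len(labels)
    weights = []
    for class_id in range(num_labels):
        if counts[class_id] == 0:
            raise ValueError(f"class {class_id} has no samples")
        weights.append(total / counts[class_id])
    return weights


# Hypothetical label list: 6 samples of class 0, 1 of class 1, 3 of class 2.
labels = [0] * 6 + [1] * 1 + [2] * 3
weights = inverse_frequency_weights(labels, num_labels=3)
print([round(w, 2) for w in weights])
```

Rare classes get larger weights, so mistakes on them cost more. In PyTorch, a list like this is typically converted to a tensor and passed as the `weight` argument of `torch.nn.CrossEntropyLoss`.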



## Result

| Model | Acc (%) |
| ----------------------------- | ---------- |
| NaiveCat | 71.25 |
| NaiveCombine | 73.625 |
| CrossModalityAttentionCombine | 67.1875 |
| HiddenStateTransformerEncoder | 73.125 |
| **OutputTransformerEncoder** | **74.625** |
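
For reference, accuracy figures like these are usually the percentage of samples whose predicted label matches the gold label. The sketch below assumes that definition; the `accuracy_percent` helper and the toy labels are hypothetical, and the README does not state which data split the table was measured on.

```python
def accuracy_percent(predictions, targets):
    """Percentage of positions where the predicted label equals the target."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if not targets:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(p == t for p, t in zip(predictions, targets))
    return 100.0 * correct / len(targets)


# Toy example: 6 of the 8 predictions match the gold labels.
preds = [0, 2, 1, 0, 2, 2, 0, 1]
gold = [0, 2, 1, 1, 2, 0, 0, 1]
print(accuracy_percent(preds, gold))
```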

#### Ablation Study

OutputTransformerEncoderModel result (the other modality's input is replaced with an empty string or a blank image):

| Feature | Acc (%) |
| ---------- | ------ |
| Text Only | 71.875 |
| Image Only | 63 |
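
The single-modality setup described above can be sketched as follows. This is a minimal illustration, not the project's data pipeline: the `mask_modality` helper, the dictionary keys, and the use of an all-zero (black) image as the "blank" image are assumptions.

```python
import copy


def blank_image(size):
    """All-zero RGB image as nested lists: size x size pixels of (0, 0, 0)."""
    return [[(0, 0, 0)] * size for _ in range(size)]


def mask_modality(sample, only, image_size=224):
    """Return a copy of `sample` with the unused modality blanked out."""
    if only not in ("text", "img"):
        raise ValueError("only must be 'text' or 'img'")
    masked = copy.copy(sample)  # shallow copy; the original dict is left untouched
    if only == "text":
        masked["image"] = blank_image(image_size)  # keep the text, drop the image
    else:
        masked["text"] = ""  # keep the image, drop the text
    return masked


# Hypothetical sample with a tiny 2x2 image.
sample = {"text": "what a great day", "image": [[(255, 128, 0)] * 2 for _ in range(2)], "label": 0}
text_only = mask_modality(sample, "text", image_size=2)
img_only = mask_modality(sample, "img", image_size=2)
print(text_only["image"][0][0], repr(img_only["text"]))
```

Because the model still receives both inputs, the same fused architecture can be evaluated without any structural change; only the content of one modality is neutralized.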

## Attribution

Transfer Learning with Joint Fine-Tuning for Multimodal Sentiment Analysis (Guilherme L. Toledo, Ricardo M. Marcacini; LXAI Research Workshop at ICML 2022): [guitld/Transfer-Learning-with-Joint-Fine-Tuning-for-Multimodal-Sentiment-Analysis](https://github.com/guitld/Transfer-Learning-with-Joint-Fine-Tuning-for-Multimodal-Sentiment-Analysis)

Is cross-attention preferable to self-attention for multi-modal emotion recognition? (IEEE ICASSP 2022): [smartcameras/SelfCrossAttn](https://github.com/smartcameras/SelfCrossAttn)

Multimodal Sentiment Analysis With Image-Text Interaction Network: [IEEE Xplore](https://ieeexplore.ieee.org/abstract/document/9736584/)

CLMLF: [Link-Li/CLMLF](https://github.com/Link-Li/CLMLF)
7 changes: 2 additions & 5 deletions main.py
Expand Up @@ -18,15 +18,12 @@
# args
parser = argparse.ArgumentParser()
parser.add_argument('--do_train', action='store_true', help='train the model')
parser.add_argument('--text_pretrained_model', default='bert-base-uncased', help='text analysis model', type=str)
parser.add_argument('--fuse_model_type', default='NaiveCombine', help='fusion model type', type=str)
parser.add_argument('--text_pretrained_model', default='roberta-base', help='text analysis model', type=str)
parser.add_argument('--fuse_model_type', default='OTE', help='fusion model type', type=str)
parser.add_argument('--lr', default=5e-5, help='set the learning rate', type=float)
parser.add_argument('--weight_decay', default=1e-2, help='set the weight decay', type=float)
parser.add_argument('--epoch', default=10, help='set the number of training epochs', type=int)

parser.add_argument('--do_valid_only_text', action='store_true', help='validate on training-set data (text only)')
parser.add_argument('--do_valid_only_img', action='store_true', help='validate on training-set data (image only)')

parser.add_argument('--do_test', action='store_true', help='predict on the test set')
parser.add_argument('--load_model_path', default=None, help='path to an already trained model', type=str)
parser.add_argument('--text_only', action='store_true', help='predict using text only')
8 changes: 8 additions & 0 deletions requirements.txt
@@ -0,0 +1,8 @@
chardet==4.0.0
numpy==1.22.2
Pillow==9.2.0
scikit_learn==1.1.1
torch==1.8.2
torchvision==0.9.2
tqdm==4.63.0
transformers==4.18.0
Binary file added src/CrossModalityAttentionCombineModel.png
Binary file added src/HiddenStateTransformerEncoderCombineModel.png
Binary file added src/OutputTransformerEncoderModel.png