release
Advocate99 committed Mar 15, 2023
1 parent 8600278 commit ab1147a
Showing 30 changed files with 4,490 additions and 1 deletion.
11 changes: 11 additions & 0 deletions .gitignore
@@ -0,0 +1,11 @@
*.pyc
data/
output/
events*
*.log
*.mp4
*.bin
*.txt
*.wav
*.pkl
*.jpg
2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
# Taming Diffusion Models for Audio-Driven Co-Speech Gesture Generation (CVPR 2023)

## Code coming soon.
We will update the README soon.

## Abstract

41 changes: 41 additions & 0 deletions config/pose_diffusion_expressive.yml
@@ -0,0 +1,41 @@
name: pose_diffusion

train_data_path: data/ted_expressive_dataset/train
val_data_path: data/ted_expressive_dataset/val
test_data_path: data/ted_expressive_dataset/test

wordembed_dim: 300
wordembed_path: data/fasttext/crawl-300d-2M-subword.bin

model_save_path: output/train_diffusion_expressive_0.1
random_seed: -1

pose_dim: 126
diff_hidden_dim: 512
block_depth: 8

# model params
model: pose_diffusion
mean_dir_vec: [-0.0737964, -0.9968923, -0.1082858, 0.9111595, 0.2399522, -0.102547 , -0.8936886, 0.3131501, -0.1039348, 0.2093927, 0.958293 , 0.0824881, -0.1689021, -0.0353824, -0.7588258, -0.2794763, -0.2495191, -0.614666 , -0.3877234, 0.005006 , -0.5301695, -0.5098616, 0.2257808, 0.0053111, -0.2393621, -0.1022204, -0.6583039, -0.4992898, 0.1228059, -0.3292085, -0.4753748, 0.2132857, 0.1742853, -0.2062069, 0.2305175, -0.5897119, -0.5452555, 0.1303197, -0.2181693, -0.5221036, 0.1211322, 0.1337591, -0.2164441, 0.0743345, -0.6464546, -0.5284583, 0.0457585, -0.319634 , -0.5074904, 0.1537192, 0.1365934, -0.4354402, -0.3836682, -0.3850554, -0.4927187, -0.2417618, -0.3054556, -0.3556116, -0.281753 , -0.5164358, -0.3064435, 0.9284261, -0.067134 , 0.2764367, 0.006997 , -0.7365526, 0.2421269, -0.225798 , -0.6387642, 0.3788997, 0.0283412, -0.5451686, 0.5753376, 0.1935219, 0.0632555, 0.2122412, -0.0624179, -0.6755542, 0.5212831, 0.1043523, -0.345288 , 0.5443628, 0.128029 , 0.2073687, 0.2197118, 0.2821399, -0.580695 , 0.573988 , 0.0786667, -0.2133071, 0.5532452, -0.0006157, 0.1598754, 0.2093099, 0.124119, -0.6504359, 0.5465003, 0.0114155, -0.3203954, 0.5512083, 0.0489287, 0.1676814, 0.4190787, -0.4018607, -0.3912126, 0.4841548, -0.2668508, -0.3557675, 0.3416916, -0.2419564, -0.5509825, 0.0485515, -0.6343101, -0.6817347, -0.4705639, -0.6380668, 0.4641643, 0.4540192, -0.6486361, 0.4604001, -0.3256226, 0.1883097, 0.8057457, 0.3257385, 0.1292366, 0.815372]
mean_pose: [-0.0046788, -0.5397806, 0.007695 , -0.0171913, -0.7060388,-0.0107034, 0.1550734, -0.6823077, -0.0303645, -0.1514748, -0.6819547, -0.0268262, 0.2094328, -0.469447 , -0.0096073, -0.2318253, -0.4680838, -0.0444074, 0.1667382, -0.4643363, -0.1895118, -0.1648597, -0.4552845, -0.2159728, 0.1387546, -0.4859474, -0.2506667, 0.1263615, -0.4856088, -0.2675801, 0.1149031, -0.4804542, -0.267329 , 0.1414847, -0.4727709, -0.2583424, 0.1262482, -0.4686185, -0.2682536, 0.1150217, -0.4633611, -0.2640182, 0.1475897, -0.4415648, -0.2438853, 0.1367996, -0.4383164, -0.248248 , 0.1267222, -0.435534 , -0.2455436, 0.1455485, -0.4557491, -0.2521977, 0.1305471, -0.4535603, -0.2611591, 0.1184687, -0.4495366, -0.257798 , 0.1451682, -0.4802511, -0.2081622, 0.1301337, -0.4865308, -0.2175783, 0.1208341, -0.4932623, -0.2311025, -0.1409241,-0.4742868, -0.2795303, -0.1287992, -0.4724431, -0.2963172,-0.1159225, -0.4676439, -0.2948754, -0.1427748, -0.4589126,-0.2861245, -0.126862 , -0.4547355, -0.2962466, -0.1140265,-0.451308 , -0.2913815, -0.1447202, -0.4260471, -0.2697673,-0.1333492, -0.4239912, -0.2738043, -0.1226859, -0.4238346,-0.2706725, -0.1446909, -0.440342 , -0.2789209, -0.1291436,-0.4391063, -0.2876539, -0.1160435, -0.4376317, -0.2836147,-0.1441438, -0.4729031, -0.2355619, -0.1293268, -0.4793807,-0.2468831, -0.1204146, -0.4847246, -0.2613876, -0.0056085,-0.9224338, -0.1677302, -0.0352157, -0.963936 , -0.1388849,0.0236298, -0.9650772, -0.1385154, -0.0697098, -0.9514691,-0.055632 , 0.0568838, -0.9565502, -0.0567985]

hidden_size: 300
input_context: audio

classifier_free: True
null_cond_prob: 0.1

# train params
epochs: 500
batch_size: 128
learning_rate: 0.0005

# eval params
eval_net_path: output/TED_Expressive_output/AE-cos1e-3/checkpoint_best.bin

# dataset params
motion_resampling_framerate: 15
n_poses: 34
n_pre_poses: 4
subdivision_stride: 10
loader_workers: 4
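
As a quick orientation, here is a minimal sketch of loading a config like the one above into attribute-style options. It assumes PyYAML and a plain SimpleNamespace; the repository's training entry point may parse arguments differently.

```python
# Minimal sketch, not the repository's own loader: assumes PyYAML is installed
# and that options are consumed as attributes (e.g. args.pose_dim).
from types import SimpleNamespace

import yaml


def load_config(path):
    """Read a YAML config such as config/pose_diffusion_expressive.yml."""
    with open(path, 'r') as f:
        return SimpleNamespace(**yaml.safe_load(f))


args = load_config('config/pose_diffusion_expressive.yml')
print(args.pose_dim, args.diff_hidden_dim, args.n_poses)  # 126 512 34
```
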
41 changes: 41 additions & 0 deletions config/pose_diffusion_ted.yml
@@ -0,0 +1,41 @@
name: pose_diffusion

train_data_path: data/ted_dataset/lmdb_train
val_data_path: data/ted_dataset/lmdb_val
test_data_path: data/ted_dataset/lmdb_test

wordembed_dim: 300
wordembed_path: data/fasttext/crawl-300d-2M-subword.bin

model_save_path: output/train_diffusion_ted
random_seed: -1

pose_dim: 27
diff_hidden_dim: 256
block_depth: 8

# model params
model: pose_diffusion
mean_dir_vec: [ 0.0154009, -0.9690125, -0.0884354, -0.0022264, -0.8655276, 0.4342174, -0.0035145, -0.8755367, -0.4121039, -0.9236511, 0.3061306, -0.0012415, -0.5155854, 0.8129665, 0.0871897, 0.2348464, 0.1846561, 0.8091402, 0.9271948, 0.2960011, -0.013189 , 0.5233978, 0.8092403, 0.0725451, -0.2037076, 0.1924306, 0.8196916]
mean_pose: [ 0.0000306, 0.0004946, 0.0008437, 0.0033759, -0.2051629, -0.0143453, 0.0031566, -0.3054764, 0.0411491, 0.0029072, -0.4254303, -0.001311 , -0.1458413, -0.1505532, -0.0138192, -0.2835603, 0.0670333, 0.0107002, -0.2280813, 0.112117 , 0.2087789, 0.1523502, -0.1521499, -0.0161503, 0.291909 , 0.0644232, 0.0040145, 0.2452035, 0.1115339, 0.2051307]

hidden_size: 300
input_context: audio

classifier_free: True
null_cond_prob: 0.1

# train params
epochs: 500
batch_size: 128
learning_rate: 0.0005

# eval params
eval_net_path: output/train_h36m_gesture_autoencoder/gesture_autoencoder_checkpoint_best.bin

# dataset params
motion_resampling_framerate: 15
n_poses: 34
n_pre_poses: 4
subdivision_stride: 10
loader_workers: 4
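
Both configs set `classifier_free: True` with `null_cond_prob: 0.1`, i.e. classifier-free guidance training in which the audio conditioning is dropped (replaced by a null embedding) for roughly 10% of training samples. Below is a hedged sketch of such a conditioning-dropout step; the tensor names and shapes are illustrative assumptions, not the repository's code.

```python
# Illustrative conditioning dropout for classifier-free guidance training.
# `audio_feat` (B, T, C) and `null_embedding` (C,) are assumed shapes, not the repo's.
import torch


def drop_condition(audio_feat, null_embedding, null_cond_prob=0.1):
    """With probability null_cond_prob per sample, replace the audio features
    by the learned null embedding so the model also learns an unconditional path."""
    batch = audio_feat.size(0)
    drop = torch.rand(batch, device=audio_feat.device) < null_cond_prob  # (B,)
    drop = drop.view(batch, 1, 1)                                        # broadcast over T, C
    return torch.where(drop, null_embedding.expand_as(audio_feat), audio_feat)
```
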
212 changes: 212 additions & 0 deletions scripts/data_loader/data_preprocessor.py
@@ -0,0 +1,212 @@
""" create data samples """
from collections import defaultdict

import lmdb
import math
import numpy as np
import pyarrow

import utils.data_utils
from data_loader.motion_preprocessor import MotionPreprocessor


class DataPreprocessor:
    def __init__(self, clip_lmdb_dir, out_lmdb_dir, n_poses, subdivision_stride,
                 pose_resampling_fps, mean_pose, mean_dir_vec, disable_filtering=False):
        self.n_poses = n_poses
        self.subdivision_stride = subdivision_stride
        self.skeleton_resampling_fps = pose_resampling_fps
        self.mean_pose = mean_pose
        self.mean_dir_vec = mean_dir_vec
        self.disable_filtering = disable_filtering

        self.src_lmdb_env = lmdb.open(clip_lmdb_dir, readonly=True, lock=False)
        with self.src_lmdb_env.begin() as txn:
            self.n_videos = txn.stat()['entries']

        self.spectrogram_sample_length = utils.data_utils.calc_spectrogram_length_from_motion_length(self.n_poses, self.skeleton_resampling_fps)
        self.audio_sample_length = int(self.n_poses / self.skeleton_resampling_fps * 16000)

        # create db for samples
        map_size = 1024 * 50  # in MB
        map_size <<= 20  # in B
        self.dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size=map_size)
        self.n_out_samples = 0

    def run(self):
        n_filtered_out = defaultdict(int)
        src_txn = self.src_lmdb_env.begin(write=False)

        # sampling and normalization
        cursor = src_txn.cursor()
        for key, value in cursor:
            video = pyarrow.deserialize(value)
            vid = video['vid']
            clips = video['clips']
            for clip_idx, clip in enumerate(clips):
                filtered_result = self._sample_from_clip(vid, clip)
                for type in filtered_result.keys():
                    n_filtered_out[type] += filtered_result[type]

        # print stats
        with self.dst_lmdb_env.begin() as txn:
            print('no. of samples: ', txn.stat()['entries'])
            n_total_filtered = 0
            for type, n_filtered in n_filtered_out.items():
                print('{}: {}'.format(type, n_filtered))
                n_total_filtered += n_filtered
            print('no. of excluded samples: {} ({:.1f}%)'.format(
                n_total_filtered, 100 * n_total_filtered / (txn.stat()['entries'] + n_total_filtered)))

        # close db
        self.src_lmdb_env.close()
        self.dst_lmdb_env.sync()
        self.dst_lmdb_env.close()

    def _sample_from_clip(self, vid, clip):
        clip_skeleton = clip['skeletons_3d']
        clip_audio = clip['audio_feat']
        clip_audio_raw = clip['audio_raw']
        clip_word_list = clip['words']
        clip_s_f, clip_e_f = clip['start_frame_no'], clip['end_frame_no']
        clip_s_t, clip_e_t = clip['start_time'], clip['end_time']

        n_filtered_out = defaultdict(int)

        # skeleton resampling
        clip_skeleton = utils.data_utils.resample_pose_seq(clip_skeleton, clip_e_t - clip_s_t, self.skeleton_resampling_fps)

        # divide
        aux_info = []
        sample_skeletons_list = []
        sample_words_list = []
        sample_audio_list = []
        sample_spectrogram_list = []

        num_subdivision = math.floor(
            (len(clip_skeleton) - self.n_poses)
            / self.subdivision_stride) + 1  # floor((K - (N+M)) / S) + 1
        expected_audio_length = utils.data_utils.calc_spectrogram_length_from_motion_length(len(clip_skeleton), self.skeleton_resampling_fps)
        assert abs(expected_audio_length - clip_audio.shape[1]) <= 5, 'audio and skeleton lengths are different'

        for i in range(num_subdivision):
            start_idx = i * self.subdivision_stride
            fin_idx = start_idx + self.n_poses

            sample_skeletons = clip_skeleton[start_idx:fin_idx]
            subdivision_start_time = clip_s_t + start_idx / self.skeleton_resampling_fps
            subdivision_end_time = clip_s_t + fin_idx / self.skeleton_resampling_fps
            sample_words = self.get_words_in_time_range(word_list=clip_word_list,
                                                        start_time=subdivision_start_time,
                                                        end_time=subdivision_end_time)

            # spectrogram
            audio_start = math.floor(start_idx / len(clip_skeleton) * clip_audio.shape[1])
            audio_end = audio_start + self.spectrogram_sample_length
            if audio_end > clip_audio.shape[1]:  # correct size mismatch between poses and audio
                # logging.info('expanding audio array, audio start={}, end={}, clip_length={}'.format(
                #     audio_start, audio_end, clip_audio.shape[1]))
                n_padding = audio_end - clip_audio.shape[1]
                padded_data = np.pad(clip_audio, ((0, 0), (0, n_padding)), mode='symmetric')
                sample_spectrogram = padded_data[:, audio_start:audio_end]
            else:
                sample_spectrogram = clip_audio[:, audio_start:audio_end]

            # raw audio
            audio_start = math.floor(start_idx / len(clip_skeleton) * len(clip_audio_raw))
            audio_end = audio_start + self.audio_sample_length
            if audio_end > len(clip_audio_raw):  # correct size mismatch between poses and audio
                # logging.info('expanding audio array, audio start={}, end={}, clip_length={}'.format(
                #     audio_start, audio_end, len(clip_audio_raw)))
                n_padding = audio_end - len(clip_audio_raw)
                padded_data = np.pad(clip_audio_raw, (0, n_padding), mode='symmetric')
                sample_audio = padded_data[audio_start:audio_end]
            else:
                sample_audio = clip_audio_raw[audio_start:audio_end]

            if len(sample_words) >= 2:
                # filtering motion skeleton data
                sample_skeletons, filtering_message = MotionPreprocessor(sample_skeletons, self.mean_pose).get()
                is_correct_motion = (sample_skeletons != [])
                motion_info = {'vid': vid,
                               'start_frame_no': clip_s_f + start_idx,
                               'end_frame_no': clip_s_f + fin_idx,
                               'start_time': subdivision_start_time,
                               'end_time': subdivision_end_time,
                               'is_correct_motion': is_correct_motion, 'filtering_message': filtering_message}

                if is_correct_motion or self.disable_filtering:
                    sample_skeletons_list.append(sample_skeletons)
                    sample_words_list.append(sample_words)
                    sample_audio_list.append(sample_audio)
                    sample_spectrogram_list.append(sample_spectrogram)
                    aux_info.append(motion_info)
                else:
                    n_filtered_out[filtering_message] += 1

        if len(sample_skeletons_list) > 0:
            with self.dst_lmdb_env.begin(write=True) as txn:
                for words, poses, audio, spectrogram, aux in zip(sample_words_list, sample_skeletons_list,
                                                                 sample_audio_list, sample_spectrogram_list,
                                                                 aux_info):
                    # preprocessing for poses
                    poses = np.asarray(poses)
                    dir_vec = utils.data_utils.convert_pose_seq_to_dir_vec(poses)
                    normalized_dir_vec = self.normalize_dir_vec(dir_vec, self.mean_dir_vec)

                    # save
                    k = '{:010}'.format(self.n_out_samples).encode('ascii')
                    v = [words, poses, normalized_dir_vec, audio, spectrogram, aux]
                    v = pyarrow.serialize(v).to_buffer()
                    txn.put(k, v)
                    self.n_out_samples += 1

        return n_filtered_out

    @staticmethod
    def normalize_dir_vec(dir_vec, mean_dir_vec):
        return dir_vec - mean_dir_vec

    @staticmethod
    def get_words_in_time_range(word_list, start_time, end_time):
        words = []

        for word in word_list:
            _, word_s, word_e = word[0], word[1], word[2]

            if word_s >= end_time:
                break

            if word_e <= start_time:
                continue

            words.append(word)

        return words

    @staticmethod
    def unnormalize_data(normalized_data, data_mean, data_std, dimensions_to_ignore):
        """
        this method is from https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/generateMotionData.py#L12
        """
        T = normalized_data.shape[0]
        D = data_mean.shape[0]

        origData = np.zeros((T, D), dtype=np.float32)
        dimensions_to_use = []
        for i in range(D):
            if i in dimensions_to_ignore:
                continue
            dimensions_to_use.append(i)
        dimensions_to_use = np.array(dimensions_to_use)

        origData[:, dimensions_to_use] = normalized_data

        # potentially inefficient, but only done once per experiment
        stdMat = data_std.reshape((1, D))
        stdMat = np.repeat(stdMat, T, axis=0)
        meanMat = data_mean.reshape((1, D))
        meanMat = np.repeat(meanMat, T, axis=0)
        origData = np.multiply(origData, stdMat) + meanMat

        return origData
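
For completeness, a hedged usage sketch of the class above with the TED config values; the output path and the placeholder statistics are assumptions for illustration, and in practice the mean_pose / mean_dir_vec vectors come from the config. With n_poses=34 at 15 fps, each sample spans about 2.27 s of motion and int(34 / 15 * 16000) = 36266 raw audio samples at 16 kHz.

```python
# Usage sketch under assumed paths; the real preprocessing call may live inside the
# repository's dataset class rather than being invoked directly like this.
import numpy as np

from data_loader.data_preprocessor import DataPreprocessor

# Placeholder statistics: substitute the mean_pose / mean_dir_vec values from
# config/pose_diffusion_ted.yml in real use.
mean_pose = np.zeros(30)
mean_dir_vec = np.zeros(27)

preprocessor = DataPreprocessor(
    clip_lmdb_dir='data/ted_dataset/lmdb_train',        # from the TED config
    out_lmdb_dir='data/ted_dataset/lmdb_train_cache',   # assumed output location
    n_poses=34,
    subdivision_stride=10,
    pose_resampling_fps=15,
    mean_pose=mean_pose,
    mean_dir_vec=mean_dir_vec,
)
preprocessor.run()
```
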