Skip to content

Commit

Permalink
[MuseCoco] Update data processing code
Browse files Browse the repository at this point in the history
  • Loading branch information
btyu committed Sep 21, 2023
1 parent d967a51 commit 5b3890d
Show file tree
Hide file tree
Showing 51 changed files with 5,220 additions and 11 deletions.
1 change: 0 additions & 1 deletion musecoco/2-attribute2music_dataprepare/config.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
# Directory that holds the MidiDataExtractor code.
midi_data_extractor_path = 'MidiDataExtractor' # path to MidiDataExtractor
# Attribute labels used by extract_data.py.
# NOTE(review): presumably the set of attributes to extract -- confirm against extract_data.py.
attribute_list = ['I1s2', 'R1', 'R3', 'S2s1', 'S4', 'B1s1', 'TS1s1', 'K1', 'T1s1', 'P4', 'EM1', 'TM1']
4 changes: 1 addition & 3 deletions musecoco/2-attribute2music_dataprepare/extract_data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import argparse
import os
import sys
from tqdm.auto import tqdm
import msgpack
import json

from file_list import generate_file_list
from config import midi_data_extractor_path, attribute_list
from config import attribute_list

sys.path.append(os.path.abspath(midi_data_extractor_path))
import midi_data_extractor as mde
# import midiprocessor as mp

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .data_extractor import DataExtractor
from . import attribute_unit
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import importlib
from .unit_base import UnitBase


def load_unit_class(attribute_label):
    """
    Resolve the unit class that belongs to an attribute label.

    The characters of the label that come before its first ASCII digit are
    lower-cased and used as the module suffix. For example, 'I1s2' is looked
    up in module '.attribute_unit.unit_i' of the 'midi_data_extractor'
    package under the attribute name 'UnitI1s2'.

    :param attribute_label: attribute label, e.g. 'I1s2' or 'TS1s1'
    :return: the attribute named 'Unit<attribute_label>' of the imported module
    """
    # Index of the first digit; falls back to the full length when the label has no digit.
    first_digit_at = next(
        (idx for idx, char in enumerate(attribute_label) if char in '0123456789'),
        len(attribute_label),
    )
    module_suffix = ''.join(char.lower() for char in attribute_label[:first_digit_at])
    module_name = '.attribute_unit.unit_%s' % module_suffix
    unit_module = importlib.import_module(module_name, package='midi_data_extractor')
    return getattr(unit_module, 'Unit%s' % attribute_label)


def load_raw_unit_class(raw_attribute_label):
    """
    Resolve the raw unit class that belongs to a raw attribute label.

    The characters of the label that come before its first ASCII digit are
    lower-cased and used as the module suffix. For example, 'TS1' is looked
    up in module '.attribute_unit.raw_unit_ts' of the 'midi_data_extractor'
    package under the attribute name 'RawUnitTS1'.

    :param raw_attribute_label: raw attribute label, e.g. 'B1' or 'EM1'
    :return: the attribute named 'RawUnit<raw_attribute_label>' of the imported module
    """
    # Index of the first digit; falls back to the full length when the label has no digit.
    first_digit_at = next(
        (idx for idx, char in enumerate(raw_attribute_label) if char in '0123456789'),
        len(raw_attribute_label),
    )
    module_suffix = ''.join(char.lower() for char in raw_attribute_label[:first_digit_at])
    module_name = '.attribute_unit.raw_unit_%s' % module_suffix
    raw_unit_module = importlib.import_module(module_name, package='midi_data_extractor')
    return getattr(raw_unit_module, 'RawUnit%s' % raw_attribute_label)


def convert_value_into_unit(attribute_label, attribute_value, encoder=None):
    """
    Wrap a single attribute value in its unit object.

    :param attribute_label: attribute label used to look up the unit class
    :param attribute_value: value handed to the unit class as first argument
    :param encoder: forwarded to the unit class as keyword argument ``encoder``
    :return: the object returned by calling the unit class
    """
    return load_unit_class(attribute_label)(attribute_value, encoder=encoder)


def convert_value_dict_into_unit_dict(value_dict, encoder=None):
    """
    Convert a mapping of attribute values into a mapping of unit objects.

    :param value_dict: mapping from attribute label to attribute value
    :param encoder: forwarded to ``convert_value_into_unit`` for every entry
    :return: dict with the same keys, each value converted by
        ``convert_value_into_unit``
    """
    return {
        label: convert_value_into_unit(label, value_dict[label], encoder=encoder)
        for label in value_dict
    }
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .raw_unit_base import RawUnitBase


class RawUnitB1(RawUnitBase):
    """
    Extracts the number of bars.
    """

    @classmethod
    def extract(
        cls,
        encoder,
        midi_dir,
        midi_path,
        pos_info,
        bars_positions,
        bars_chords,
        bars_insts,
        bar_begin,
        bar_end,
        **kwargs
    ):
        """
        :return:
            int, the number of bars (``bar_end - bar_begin``)
        """
        num_bars = bar_end - bar_begin
        return num_bars
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from abc import ABC


class RawUnitBase(ABC):
    """
    Base class of the raw attribute units.
    """

    @classmethod
    def extract(
        cls,
        encoder,
        midi_dir,
        midi_path,
        pos_info,
        bars_positions,
        bars_chords,
        bars_insts,
        bar_begin,
        bar_end,
        **kwargs
    ):
        """
        Obtain the information of this attribute from the function inputs.
        The returned information should equal the required information, or be
        a superset of it. An overriding implementation should clearly state
        the format and content of what it returns.

        :param encoder: mp.MidiEncoder instance
        :param midi_dir: full path of the dataset
        :param midi_path: path of the MIDI file relative to the dataset path
        :param pos_info: pos_info; ts and tempo are filled in at the start
            position of every bar for convenience
        :param bars_positions: dict, start and end positions of the bars in pos_info
        :param bars_chords: chord information of the bar sequence, two entries
            per bar. May be None, in which case chord information cannot be
            extracted for this MIDI.
        :param bars_insts: instrument ids used in every bar; a list whose
            items are sets
        :param bar_begin: first bar of the range to extract (0-based)
        :param bar_end: end bar of the range to extract (exclusive)
        :param kwargs: other information, an empty dict by default
        :return:
        """
        raise NotImplementedError

    @classmethod
    def repr_value(cls, value):
        """
        Return ``value`` unchanged.
        """
        return value

    @classmethod
    def derepr_value(cls, rep_value):
        """
        Return ``rep_value`` unchanged.
        """
        return rep_value


class RawUnitForExistedValue(RawUnitBase):
    """
    Raw unit that picks already-known values out of the keyword arguments
    passed to ``extract``.
    """

    @classmethod
    def get_fields(cls):
        """
        :return: a field name (str) or an iterable of field names that
            ``extract`` looks up in its keyword arguments
        """
        raise NotImplementedError

    @classmethod
    def extract(
        cls,
        encoder,
        midi_dir,
        midi_path,
        pos_info,
        bars_positions,
        bars_chords,
        bars_insts,
        bar_begin,
        bar_end,
        **kwargs
    ):
        """
        :return:
            dict mapping every field name from ``get_fields`` to the keyword
            argument of that name, or to None when it was not passed
        """
        field_names = cls.get_fields()
        # A single name is accepted as shorthand for a one-element collection.
        if isinstance(field_names, str):
            field_names = (field_names,)
        return {name: kwargs.get(name) for name in field_names}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from .raw_unit_base import RawUnitBase


class RawUnitC1(RawUnitBase):
    """
    Chord sequence of the segment, two chords per bar.
    """

    @classmethod
    def extract(
        cls,
        encoder,
        midi_dir,
        midi_path,
        pos_info,
        bars_positions,
        bars_chords,
        bars_insts,
        bar_begin,
        bar_end,
        **kwargs
    ):
        """
        :return:
            - list: chord sequence of the segment, two chords per bar.
              None is returned when ``bars_chords`` is None, i.e. when the
              chords of the MIDI could not be detected.
        """
        if bars_chords is None:
            return None

        # bars_chords must carry exactly two entries for every bar.
        assert len(bars_positions) * 2 == len(bars_chords)
        return bars_chords[2 * bar_begin: 2 * bar_end]
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .raw_unit_base import RawUnitForExistedValue


class RawUnitEM1(RawUnitForExistedValue):
    """
    Raw unit whose single field is 'emotion'.
    """

    @classmethod
    def get_fields(cls):
        """
        :return: str, the field name 'emotion'
        """
        # NOTE(review): presumably RawUnitForExistedValue.extract uses this name
        # to read the 'emotion' keyword argument -- confirm against the base class.
        return 'emotion'
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from .raw_unit_base import RawUnitBase


class RawUnitI1(RawUnitBase):
    """
    Instruments used in the segment.
    """

    @classmethod
    def extract(
        cls,
        encoder,
        midi_dir,
        midi_path,
        pos_info,
        bars_positions,
        bars_chords,
        bars_insts,
        bar_begin,
        bar_end,
        **kwargs
    ):
        """
        Extract the instruments that are used.

        :return:
            - tuple, the ids of the instruments used in bars
              ``bar_begin`` (inclusive) to ``bar_end`` (exclusive). Never None.
        """
        used_inst_ids = {
            inst_id
            for bar_inst_ids in bars_insts[bar_begin: bar_end]
            for inst_id in bar_inst_ids
        }
        return tuple(used_inst_ids)


class RawUnitI2(RawUnitBase):
    """
    - tuple, instruments used in the first half; None when the number of
      bars is not a positive even number
    - tuple, instruments used in the second half; None when the number of
      bars is not a positive even number
    """

    @classmethod
    def extract(
        cls,
        encoder,
        midi_dir,
        midi_path,
        pos_info,
        bars_positions,
        bars_chords,
        bars_insts,
        bar_begin,
        bar_end,
        **kwargs
    ):
        """
        :return:
            (first_half_insts, second_half_insts), each a tuple of instrument
            ids, or (None, None) when the bar count is not a positive even number
        """
        num_bars = bar_end - bar_begin
        is_positive_even = num_bars > 0 and num_bars % 2 == 0
        if not is_positive_even:
            return None, None

        # First bar index that belongs to the second half.
        middle = bar_begin + num_bars // 2

        first_half = {
            inst_id
            for bar_inst_ids in bars_insts[bar_begin: middle]
            for inst_id in bar_inst_ids
        }
        second_half = {
            inst_id
            for bar_inst_ids in bars_insts[middle: bar_end]
            for inst_id in bar_inst_ids
        }
        return tuple(first_half), tuple(second_half)
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from .raw_unit_base import RawUnitBase


class RawUnitK1(RawUnitBase):
    """
    Major / minor key of the piece, taken from the ``is_major`` keyword argument.
    """

    @classmethod
    def extract(
        cls,
        encoder,
        midi_dir,
        midi_path,
        pos_info,
        bars_positions,
        bars_chords,
        bars_insts,
        bar_begin,
        bar_end,
        **kwargs
    ):
        """
        Major or minor key.

        :return:
            - str: 'major' for a major key, 'minor' for a minor key.
              May be None, meaning the key is unknown.
        :raises ValueError: when ``is_major`` is passed but is not True,
            False or None
        """
        if 'is_major' not in kwargs:
            return None

        is_major = kwargs['is_major']
        # Identity checks on purpose: only the exact singletons are accepted.
        if is_major is True:
            return 'major'
        if is_major is False:
            return 'minor'
        if is_major is None:
            return None
        raise ValueError('is_major argument is set to a wrong value:', is_major)
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from .raw_unit_base import RawUnitBase


class RawUnitM1(RawUnitBase):
    """
    SSM of every track.
    """

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        Cut the per-instrument SSM down to the requested bar range.

        Requires the keyword argument ``ssm``, which is iterated to obtain
        the instrument ids and indexed by those ids.

        :return:
            - dict: key is inst_id, value is
              ``ssm[inst_id][bar_begin: bar_end, bar_begin: bar_end]``.
        :raises KeyError: when ``ssm`` is not passed in ``kwargs``
        """
        # NOTE(review): the two-slice index presumably expects every ssm[inst_id]
        # to be a 2-D array (e.g. a NumPy matrix of bar-to-bar similarities).
        # An earlier docstring described a dict keyed by (i, j) instead, which
        # would not support this indexing -- confirm against the code that builds ssm.
        ssm = kwargs['ssm']
        r = {}
        for inst_id in ssm:
            # Index the matrix of this instrument, not the container itself:
            # subscripting the inst_id-keyed container with a tuple of slices fails.
            r[inst_id] = ssm[inst_id][bar_begin: bar_end, bar_begin: bar_end]
        return r
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from .raw_unit_base import RawUnitBase

from ..utils.data import convert_dict_key_to_str, convert_dict_key_to_int


class RawUnitN2(RawUnitBase):
    """
    Number of notes of every instrument in the segment.
    """

    @classmethod
    def extract(
        cls,
        encoder,
        midi_dir,
        midi_path,
        pos_info,
        bars_positions,
        bars_chords,
        bars_insts,
        bar_begin,
        bar_end,
        **kwargs
    ):
        """
        :return:
            - dict, the number of notes of every instrument
        """
        # Range in pos_info: from the start of the first bar to the end of the last bar.
        first_pos = bars_positions[bar_begin][0]
        last_pos = bars_positions[bar_end - 1][1]

        note_counts = {}
        for pos_idx in range(first_pos, last_pos):
            # Element 4 of a pos_info item holds the notes grouped by instrument id (or None).
            notes_by_inst = pos_info[pos_idx][4]
            if notes_by_inst is None:
                continue
            for inst_id in notes_by_inst:
                note_counts[inst_id] = note_counts.get(inst_id, 0) + len(notes_by_inst[inst_id])

        return note_counts

    @classmethod
    def repr_value(cls, value):
        """
        Return ``value`` passed through ``convert_dict_key_to_str``.
        """
        return convert_dict_key_to_str(value)

    @classmethod
    def derepr_value(cls, rep_value):
        """
        Return ``rep_value`` passed through ``convert_dict_key_to_int``.
        """
        return convert_dict_key_to_int(rep_value)

Loading

0 comments on commit 5b3890d

Please sign in to comment.