-
Notifications
You must be signed in to change notification settings - Fork 0
/
process_data.py
49 lines (35 loc) · 1.37 KB
/
process_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import random
from readers import read_dfs
import pandas as pd
import sys
import numpy as np
import matplotlib.pyplot as plt
import os
# Directory that receives the CSV exports written at the end of the script.
exp_dir = "./processed_data/"

# Load the three raw tables from the project reader. Only `atp` and `uap` are
# used in this part of the script.
# NOTE(review): presumably artist/tag pairs, user/artist plays and user
# friendships -- confirm against `readers.read_dfs`.
atp, uap, uf = read_dfs()

# Count the non-null `artistID` entries recorded for every tag and order the
# tags from least to most used.
tags_counts = atp.groupby("tagID").count()[["artistID"]].sort_values("artistID")

# The last 1000 rows of that ascending ordering are the most-used tags; rows
# of `atp` carrying any other tag are dropped.
admissable_tag_ids = tags_counts.tail(1000).index
_atp = atp.loc[atp["tagID"].isin(admissable_tag_ids)]

# One indicator column per remaining tag value, summed per artist name, so
# each row says how many times each tag was attached to that artist.
artist_vectors = pd.get_dummies(
    _atp[["tagValue", "name"]], columns=["tagValue"]
).groupby("name").sum(numeric_only=True)
# Users x artists matrix of summed play weights. User/artist pairs that never
# occur in `uap` come out of the pivot as NaN and are replaced by 0.
user_artist_counts = uap.pivot_table(
    index="userID", columns="name", values="weight", aggfunc="sum"
).fillna(0)

# Every row divided by that user's total weight. `keepdims=True` keeps the
# totals as an (n_users, 1) column so the division broadcasts across each row.
# A user whose total weight is 0 would produce NaN here.
normalized_plays = user_artist_counts / user_artist_counts.to_numpy().sum(
    axis=1, keepdims=True
)

# Total weight per artist, expressed in thousands, sorted from least to most.
popular_artists = (
    uap.groupby("name").sum(numeric_only=True)["weight"].sort_values(ascending=True)
    / 1000
)

# 1 where the user has a positive summed weight for the artist, 0 otherwise.
# Derived from the pivot above instead of building the same table a second
# time with the builtin `sum`, which recent pandas versions warn about when it
# is passed as `aggfunc`.
user_play_pair = (user_artist_counts > 0).astype(int)
# Clean up string encodings: each label is re-encoded as Latin-1 and decoded
# as UTF-8, which restores UTF-8 text that was read as Latin-1.
# NOTE(review): presumably the reader loads the raw files as Latin-1 -- confirm.
# Both steps use strict error handling, so a label that cannot make this
# round trip raises a UnicodeError instead of being passed through.
artist_vectors.index = artist_vectors.index.str.encode("latin-1").str.decode(
    "utf-8"
)
user_play_pair.columns = user_play_pair.columns.str.encode("latin-1").str.decode(
    "utf-8"
)

# `to_csv` does not create missing parent directories, so make sure the export
# directory exists before writing (a no-op when it is already there).
os.makedirs(exp_dir, exist_ok=True)
artist_vectors.to_csv(os.path.join(exp_dir, "artist_vectors.csv"))
user_play_pair.to_csv(os.path.join(exp_dir, "user_play_pair.csv"))