Skip to content

Commit

Permalink
change trial_data_messages_to_dict to use trials_to_dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
taochao committed Jun 27, 2021
1 parent 9ac11cb commit 1ad6f87
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions mnl_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from absl import logging
from absl import flags

from banditpylib import trials_to_dataframe
from banditpylib.bandits import Bandit, MNLBandit, CvarReward, \
MeanReward
from banditpylib.data_pb2 import Trial
from banditpylib.protocols import Protocol, trial_data_messages_to_dict
from banditpylib.protocols import Protocol
from banditpylib.learners.mnl_bandit_learner import Learner, UCB, \
ThompsonSampling

Expand Down Expand Up @@ -294,7 +295,7 @@ def generate_data_with_cvar(params_filename,


def make_figure_using_cvar(data_filename, figure_filename):
data_df = trial_data_messages_to_dict(data_filename)
data_df = trials_to_dataframe(data_filename)
ax = sns.lineplot(x='total_actions', y='other', hue='learner', data=data_df)
ax.xaxis.get_offset_text().set_fontsize(FONT_SIZE)
plt.xlabel(r'$t$', fontweight='bold', fontsize=FONT_SIZE)
Expand All @@ -306,7 +307,7 @@ def make_figure_using_cvar(data_filename, figure_filename):


def make_figure(data_filename, figure_filename):
data_df = trial_data_messages_to_dict(data_filename)
data_df = trials_to_dataframe(data_filename)
sns.lineplot(x='total_actions', y='regret', hue='learner', data=data_df)
plt.savefig(figure_filename, format='pdf')

Expand All @@ -316,7 +317,7 @@ def make_figure_with_worst_regret():
for filename in os.listdir(os.path.join(os.getcwd(), 'arxiv')):
# read all data files
if 'data' in filename:
trials = trial_data_messages_to_dict(os.path.join('arxiv', filename))
trials = trials_to_dataframe(os.path.join('arxiv', filename))
data_df = data_df.append(pd.DataFrame.from_dict(trials)[[
'learner', 'total_actions', 'regret'
]].groupby(['learner', 'total_actions']).mean().reset_index(),
Expand Down

0 comments on commit 1ad6f87

Please sign in to comment.