Skip to content

Commit 912c363

Browse files
committed
Fully working version.
1 parent 67bf8e0 commit 912c363

File tree

3 files changed

+58
-25
lines changed

3 files changed

+58
-25
lines changed

donkeycar/management/kivy_ui.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from kivy.core.image import Image as CoreImage
1313
from kivy.properties import NumericProperty, ObjectProperty, StringProperty, \
1414
ListProperty, BooleanProperty
15+
from kivy.uix.label import Label
1516
from kivy.uix.popup import Popup
1617
from kivy.lang.builder import Builder
1718
from kivy.core.window import Window
@@ -109,7 +110,7 @@ def load_action(self):
109110
try:
110111
self.config = load_config(os.path.join(self.file_path, 'config.py'))
111112
rc_handler.data['car_dir'] = self.file_path
112-
train_screen().config = self.config
113+
# train_screen().config = self.config
113114
except FileNotFoundError:
114115
print(f'Directory {self.file_path} has no config.py')
115116
except Exception as e:
@@ -604,6 +605,11 @@ def on_keyboard(self, instance, keycode, scancode, key, modifiers):
604605
class ScrollableLabel(ScrollView):
605606
pass
606607

608+
609+
class DataFrameLabel(Label):
610+
pass
611+
612+
607613
class TrainScreen(Screen):
608614
config = ObjectProperty(force_dispatch=True, allownone=True)
609615

@@ -640,17 +646,29 @@ def value_list(self):
640646
return ['select']
641647

642648
def on_config(self, obj, config):
643-
if self.ids:
649+
if self.config and self.ids:
644650
self.ids.cfg_spinner.values = self.value_list()
645-
self.ids.scroll.text = self.get_database_text()
651+
if self.ids.check.state == 'down':
652+
text_df, text_tub = self.toggle_tub_df()
653+
self.ids.scroll_pilots.text = text_df
654+
self.ids.scroll_tubs.text = text_tub
655+
else:
656+
self.ids.scroll_pilots.text = self.get_database_text()
657+
self.ids.scroll_tubs.text = ''
646658

647659
def get_database_text(self):
648660
if self.config:
649661
database = PilotDatabase(self.config)
650662
df = database.to_df()
651-
return str(df)
652-
else:
653-
return ''
663+
df.drop(columns='History', inplace=True)
664+
return df.to_string()
665+
666+
def toggle_tub_df(self):
667+
if self.config:
668+
database = PilotDatabase(self.config)
669+
df, df_tub = database.to_df_tubgrouped()
670+
df.drop(columns='History', inplace=True)
671+
return df.to_string(), df_tub.to_string()
654672

655673

656674
class DonkeyApp(App):

donkeycar/management/ui.kv

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@
214214
Button:
215215
text: 'Reload Tub'
216216
on_press:
217-
app.tub_screen.ids.tub_loader.update_tub(reload=True)
217+
app.tub_screen.ids.tub_loader.update_tub()
218218

219219
<TubFilter>:
220220
id: tub_filter
@@ -420,6 +420,13 @@
420420
StatusLabel:
421421
id: status
422422

423+
<DataframeLabel>:
424+
font_name: 'data/fonts/RobotoMono-Regular.ttf'
425+
font_size: 24
426+
size_hint_y: None
427+
height: self.texture_size[1]
428+
text_size: self.width, None
429+
423430

424431
<TrainScreen>:
425432
config: app.tub_screen.ids.config_manager.config
@@ -466,26 +473,32 @@
466473
text: 'Set model type'
467474
MySpinner:
468475
id: train_spinner
476+
size_hint_x: 0.5
469477
text: 'linear'
470478
values: ['linear', 'categorical', 'tflite_linear']
471479
TextInput:
472480
id: comment
473481
multiline: False
474482
text: 'Comment'
475-
Button:
476-
id: train_button
477-
text: 'Train'
478-
on_press: root.train(train_spinner.text)
479-
ScrollableLabel:
483+
Button:
484+
id: train_button
485+
text: 'Train'
486+
on_press: root.train(train_spinner.text)
487+
BoxLayout:
488+
size_hint_y: None
489+
height: common_height
480490
Label:
481-
id: scroll
482-
size_hint_y: None
483-
height: self.texture_size[1]
484-
text_size: self.width, None
485-
Label:
486-
text: 'More stuff here'
487-
size_hint_y: 5
488-
491+
text: 'Group multiple tubs'
492+
ToggleButton:
493+
id: check
494+
on_press: root.on_config(None, None)
495+
text: 'On' if self.state == 'down' else 'Off'
496+
ScrollableLabel:
497+
DataFrameLabel:
498+
id: scroll_pilots
499+
ScrollableLabel:
500+
DataFrameLabel:
501+
id: scroll_tubs
489502
StatusLabel:
490503
id: status
491504

donkeycar/pipeline/training.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,9 @@ def create_tf_data(self) -> tf.data.Dataset:
7878
return dataset.repeat().batch(self.batch_size)
7979

8080

81-
def get_model_train_details(cfg: Config, database: PilotDatabase, model: str,
82-
model_type: str) -> Tuple[str, int, str, bool]:
81+
def get_model_train_details(cfg: Config, database: PilotDatabase,
82+
model: str = None, model_type: str = None) \
83+
-> Tuple[str, int, str, bool]:
8384
if not model_type:
8485
model_type = cfg.DEFAULT_MODEL_TYPE
8586
train_type = model_type
@@ -97,8 +98,9 @@ def get_model_train_details(cfg: Config, database: PilotDatabase, model: str,
9798
return model_name, model_num, train_type, is_tflite
9899

99100

100-
def train(cfg: Config, tub_paths: str, model: str, model_type: str,
101-
transfer: str, comment: str = None) -> tf.keras.callbacks.History:
101+
def train(cfg: Config, tub_paths: str, model: str = None,
102+
model_type: str = None, transfer: str = None, comment: str = None) \
103+
-> tf.keras.callbacks.History:
102104
"""
103105
Train the model
104106
"""
@@ -155,7 +157,7 @@ def train(cfg: Config, tub_paths: str, model: str, model_type: str,
155157
'Tubs': tub_paths,
156158
'Time': time(),
157159
'History': history.history,
158-
'Transfer': transfer,
160+
'Transfer': os.path.basename(transfer) if transfer else None,
159161
'Comment': comment,
160162
}
161163
database.add_entry(database_entry)

0 commit comments

Comments
 (0)