diff --git a/rinokeras/core/v1x/train/RinokerasGraph.py b/rinokeras/core/v1x/train/RinokerasGraph.py index b12a461..b8cb2a3 100644 --- a/rinokeras/core/v1x/train/RinokerasGraph.py +++ b/rinokeras/core/v1x/train/RinokerasGraph.py @@ -2,6 +2,7 @@ from typing import Sequence, Union, Any, Optional, Dict import pickle as pkl +import h5py import tensorflow as tf from tensorflow.python.client import timeline from tqdm import tqdm @@ -159,19 +160,36 @@ def run_epoch(self, data_len: Optional[int] = None, epoch_num: Optional[int] = None, summary_writer: Optional[tf.summary.FileWriter] = None, - save_outputs: Optional[str] = None) -> MetricsAccumulator: - all_outputs = [] + save_outputs: Optional[str] = None, + save_format: Optional[str] = 'pkl') -> MetricsAccumulator: + + if save_format == 'pkl': + all_outputs = [] + elif save_format == 'h5': + h5_outf = h5py.File(save_outputs, 'w') + i = 0 + else: + raise Exception('Unsupported save format: {}'.format(save_format)) with self.add_progress_bar(data_len, epoch_num).initialize(): assert self.epoch_metrics is not None while True: if save_outputs is not None: loss, outputs = self.run('default', return_outputs=True) - all_outputs.append(outputs) + if save_format == 'pkl': + all_outputs.append(outputs) + elif save_format == 'h5': + grp = h5_outf.create_group(str(i)) + outputs = outputs[0] # can we rely on this being a tuple of length 1? + for key in outputs.keys(): + grp.create_dataset(key, data=outputs[key]) + i += 1 else: self.run('default') - if save_outputs is not None: + if save_format == 'h5': + h5_outf.close() + if save_outputs is not None and save_format == 'pkl': with open(save_outputs, 'wb') as f: pkl.dump(all_outputs, f)