伍佰目录 短网址
  当前位置:海洋目录网 » 站长资讯 » 站长资讯 » 文章详细 订阅RssFeed

ValueError:GraphDef cannot be larger than 2GB.解决办法

来源:本站原创 浏览:139次 时间:2021-09-03

在使用TensorFlow 1.X版本的estimator的时候经常会碰到类似于ValueError:GraphDef cannot be larger than 2GB的报错信息,可能的原因是数据太大无法写入graph。

一般来说,常见的数据构建方法如下:

def input_fn():  features, labels = (np.random.sample((100,2)), np.random.sample((100,1)))  dataset = tf.data.Dataset.from_tensor_slices((features,labels))  dataset = dataset.shuffle(100000).repeat().batch(batch_size)  return dataset...estimator.train(input_fn)

TensorFlow在读取数据的时候会将数据也写入Graph,所以当数据量很大的时候会碰到这种情况,之前做实验在多GPU的时候也会遇到这种情况,即使我把batch size调到很低。所以解决办法有两种思路,一直不保存graph,而是使用feed_dict的方式来构建input pipeline。

不写入graph

我的代码环境是TensorFlow1.14,所以我以这个版本为例进行介绍。

首先总结一下estimator的运行原理(假设在单卡情况下),以estimator.train为例(eval和predict类似),其调用顺序如下:

  1. estimator.train->_train_model

  2. _train_model->_train_model_default

  3. _train_model_default->_train_with_estimator_spec

  4. _train_with_estimator_spec->MonitoredTrainingSession

class Estimator():...def train():...loss = self._train_model(input_fn, hooks, saving_listeners)...def _train_model(self, input_fn, hooks, saving_listeners):if self._train_distribution:return self._train_model_distributed(input_fn, hooks, saving_listeners)else:return self._train_model_default(input_fn, hooks, saving_listeners)  def _train_model_default(self, input_fn, hooks, saving_listeners):...return self._train_with_estimator_spec(estimator_spec, worker_hooks, hooks, global_step_tensor, saving_listeners) def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks, global_step_tensor, saving_listeners):....with training.MonitoredTrainingSession(master=self._config.master,is_chief=self._config.is_chief,checkpoint_dir=self._model_dir,scaffold=estimator_spec.scaffold,hooks=worker_hooks,chief_only_hooks=(tuple(chief_hooks) +  tuple(estimator_spec.training_chief_hooks)),save_checkpoint_secs=0,  # Saving is handled by a hook.save_summaries_steps=save_summary_steps,config=self._session_config,max_wait_secs=self._config.session_creation_timeout_secs,log_step_count_steps=log_step_count_steps) as mon_sess:

单步调试后发现,estimator写入event文件发生在调用MonitoredTrainingSession的时刻,而真正写入event是在执行hook的时候,例如在我的实验中我设置了log_step_count_steps这个值,这个值会每隔指定次数steps就会打印出计算速度和当前的loss值。而实现这一功能的是StepCounterHook,它定义在tensorflow/tensorflow/python/training/basic_session_run_hooks.py中,部分定义如下:

class StepCounterHook(session_run_hook.SessionRunHook):  """Hook that counts steps per second."""  def __init__(...):  ...    self._summary_writer = summary_writer  def begin(self):    if self._summary_writer is None and self._output_dir:      self._summary_writer = SummaryWriterCache.get(self._output_dir)    self._summary_tag = training_util.get_global_step().op.name + "/sec"  def before_run(self, run_context):  # pylint: disable=unused-argument    return SessionRunArgs(self._global_step_tensor)  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):    steps_per_sec = elapsed_steps / elapsed_time    if self._summary_writer is not None:      summary = Summary(value=[          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)      ])      self._summary_writer.add_summary(summary, global_step)    logging.info("%s: %g", self._summary_tag, steps_per_sec)

所以我们只需要将出现类似于self._summary_writer.add_summary的地方注释掉,这样estimator在运行过程中就不会再生成event文件,也就不会有2GB的问题了。

feed_dict

为了在大数据量时使用 dataset,我们可以用 placeholder 创建 dataset。这时数据就不会直接写到 graph 中,graph 中只有一个 placeholder 占位符。但是,用了 placeholder 就需要我们在一开始对它进行初始化填数据,需要调用 sess.run(iter.initializer, feed_dict={ x: data })

但是estimator并没有显示的session可以调用,那应该怎么办呢?其实我们可以使用SessionRunHook来解决这个问题。tf.train.SessionRunHook()类定义在tensorflow/python/trainin����Ա,����Աg/session_run_hook.py,该类的具体介绍可参见【转】tf.SessionRunHook使用方法。

仔细看一下 estimator 的 train 和 evaluate 函数定义可以发现它们都接收 hooks 参数,这个参数的定义是:List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the training loop. 也就是说我们可以自己定义一个SessionRunHook作为参数传递到hook就可以了。

train(    input_fn,    hooks=None,    steps=None,    max_steps=None,    saving_listeners=None)

我们现在想要在训练之前初始化 dataset 的 placeholder,那么我们就应该具体实现 SessionRunHook 的after_create_session 成员函数:

class IteratorInitializerHook(tf.train.SessionRunHook):   def __init__(self):       super(IteratorInitializerHook, self).__init__()       self.iterator_initializer_fn = None   def after_create_session(self, session, coord):       del coord       self.iterator_initializer_fn(session)def make_input_fn():   iterator_initializer_hook = IteratorInitializerHook()   def input_fn():       x = tf.placeholder(tf.float32, shape=[None,2])       dataset = tf.data.Dataset.from_tensor_slices(x)       dataset = dataset.shuffle(100000).repeat().batch(batch_size)       iter = dataset.make_initializable_iterator()       data = np.random.sample((100,2))       iterator_initializer_hook.iterator_initializer_fn = (           lambda sess: sess.run(iter.initializer, feed_dict={x: data})       )       return iter.get_next()   return input_fn, iterator_initializer_hook...input_fn, iterator_initializer_hook = make_input_fn()estimator.train(input_fn, hooks=[iterator_initializer_hook])
参考
  • tf.train.SessionRunHook 让 estimator 训练过程可以个性化定制
  • Hook? tf.train.SessionRunHook()介绍【精】


MARSGGBO♥原创





2019-10-21 11:04:22



  推荐站点

  • At-lib分类目录At-lib分类目录

    At-lib网站分类目录汇集全国所有高质量网站,是中国权威的中文网站分类目录,给站长提供免费网址目录提交收录和推荐最新最全的优秀网站大全是名站导航之家

    www.at-lib.cn
  • 中国链接目录中国链接目录

    中国链接目录简称链接目录,是收录优秀网站和淘宝网店的网站分类目录,为您提供优质的网址导航服务,也是网店进行收录推广,站长免费推广网站、加快百度收录、增加友情链接和网站外链的平台。

    www.cnlink.org
  • 35目录网35目录网

    35目录免费收录各类优秀网站,全力打造互动式网站目录,提供网站分类目录检索,关键字搜索功能。欢迎您向35目录推荐、提交优秀网站。

    www.35mulu.com
  • 就要爱网站目录就要爱网站目录

    就要爱网站目录,按主题和类别列出网站。所有提交的网站都经过人工审查,确保质量和无垃圾邮件的结果。

    www.912219.com
  • 伍佰目录伍佰目录

    伍佰网站目录免费收录各类优秀网站,全力打造互动式网站目录,提供网站分类目录检索,关键字搜索功能。欢迎您向伍佰目录推荐、提交优秀网站。

    www.wbwb.net