#!/usr/bin/env python
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Trainer for T2T models.

This binary performs training, evaluation, and inference using
the Estimator API with tf.learn Experiment objects.

To train your model, for example:
  t2t-trainer \
      --data_dir ~/data \
      --problems=algorithmic_identity_binary40 \
      --model=transformer \
      --hparams_set=transformer_base
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_utils
from tensor2tensor.utils import usr_dir

import tensorflow as tf

# Short aliases for TensorFlow's flag-definition module and the object through
# which flag values are read (FLAGS.<name>).
flags = tf.flags
FLAGS = flags.FLAGS

# See trainer_utils.py for additional command-line flags.
# NOTE(review): main() in this file also reads FLAGS.data_dir, FLAGS.problems,
# FLAGS.model and FLAGS.train_steps, none of which are defined here;
# presumably they are among the flags declared in trainer_utils.py -- confirm.

# Flags controlling user-code import and optional data generation.
flags.DEFINE_string("t2t_usr_dir", "",
                    "Path to a Python module that will be imported. The "
                    "__init__.py file should include the necessary imports. "
                    "The imported files should contain registrations, "
                    "e.g. @registry.register_model calls, that will then be "
                    "available to the t2t-trainer.")
flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
                    "Temporary storage directory.")
flags.DEFINE_bool("generate_data", False, "Generate data before training?")

# Flags forwarded to (or gating) the trainer run itself.
flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.")
flags.DEFINE_string("output_dir", "", "Base output directory for run.")
# NOTE(review): "master" is not read anywhere in this file; presumably it is
# consumed by trainer_utils -- confirm.
flags.DEFINE_string("master", "", "Address of TensorFlow master.")
flags.DEFINE_string("schedule", "train_and_evaluate",
                    "Method of tf.contrib.learn.Experiment to run.")
flags.DEFINE_bool("profile", False, "Profile performance?")


def main(_):
  """Optionally generates data, then runs the T2T trainer.

  Steps, in order: import the user module named by --t2t_usr_dir, log the
  registry, validate flags, create the output directory, generate data for
  each problem in --problems when --generate_data is set, and finally call
  trainer_utils.run (wrapped in a tfprof ProfileContext when --profile is set).

  Args:
    _: Ignored positional argument.

  Raises:
    ValueError: If --data_dir is empty.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  trainer_utils.log_registry()
  trainer_utils.validate_flags()

  # Expand "~" in the user-supplied paths before touching the filesystem.
  run_dir = os.path.expanduser(FLAGS.output_dir)
  scratch_dir = os.path.expanduser(FLAGS.tmp_dir)
  if not FLAGS.data_dir:
    raise ValueError("You must specify a --data_dir")
  dataset_dir = os.path.expanduser(FLAGS.data_dir)
  tf.gfile.MakeDirs(run_dir)

  if FLAGS.generate_data:
    # Both directories are only needed (and only created) for data generation.
    for needed_dir in (dataset_dir, scratch_dir):
      tf.gfile.MakeDirs(needed_dir)
    # --problems is split on "-"; each piece is looked up in the registry.
    for name in FLAGS.problems.split("-"):
      tf.logging.info("Generating data for %s", name)
      registry.problem(name).generate_data(dataset_dir, scratch_dir)

  def launch_trainer():
    """Hands the resolved directories and run flags to trainer_utils.run."""
    trainer_utils.run(
        data_dir=dataset_dir,
        model=FLAGS.model,
        output_dir=run_dir,
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        schedule=FLAGS.schedule)

  if not FLAGS.profile:
    launch_trainer()
    return

  # Profiling requested: trace, dump and auto-profile steps 0..99.
  # NOTE(review): "t2tprof" is presumably the profile output directory --
  # confirm against the ProfileContext documentation.
  profiled_steps = 100
  with tf.contrib.tfprof.ProfileContext(
      "t2tprof",
      trace_steps=range(profiled_steps),
      dump_steps=range(profiled_steps)) as profile_ctx:
    time_and_memory = tf.profiler.ProfileOptionBuilder.time_and_memory()
    profile_ctx.add_auto_profiling(
        "op", time_and_memory, range(profiled_steps))
    launch_trainer()


if __name__ == "__main__":
  # Script entry point: let tf.app.run drive main() defined above.
  tf.app.run(main=main)
