How to use the tensorflowonspark.TFCluster.run function in tensorflowonspark

To help you get started, we've selected a few tensorflowonspark examples based on popular ways it is used in public projects. All of the snippets below come from the yahoo/TensorFlowOnSpark repository.
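
All of these drivers follow the same pattern: TFCluster.run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard, input_mode, ...) reserves one Spark executor per TensorFlow node, starts map_fun(tf_args, ctx) on each of them, and returns a cluster handle whose shutdown() waits for the TF nodes to finish. Before the project-specific examples, here is a minimal self-contained sketch (the app name and the map_fun body are placeholders, not taken from any of the projects below):

from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from tensorflowonspark import TFCluster


def map_fun(args, ctx):
    # runs once on every executor; ctx identifies this node's role in the cluster
    print("running {0}:{1}".format(ctx.job_name, ctx.task_index))


sc = SparkContext(conf=SparkConf().setAppName("tfos_sketch"))
num_executors = int(sc._conf.get("spark.executor.instances", "1"))

# one executor per TF node: num_executors total, none of them parameter servers
cluster = TFCluster.run(sc, map_fun, None, num_executors, num_ps=0,
                        tensorboard=False, input_mode=TFCluster.InputMode.TENSORFLOW)
cluster.shutdown()  # waits for the TF nodes to finish, then releases the executors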


From yahoo/TensorFlowOnSpark, examples/mnist/tf/mnist_spark.py:

parser.add_argument("--format", help="example format: (csv2|tfr)", choices=["csv2", "tfr"], default="tfr")
parser.add_argument("--images_labels", help="HDFS path to MNIST image_label files in parallelized format")
parser.add_argument("--mode", help="train|inference", default="train")
parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("--num_ps", help="number of ps nodes", default=1)
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("--rdma", help="use rdma connection", default=False)
parser.add_argument("--readers", help="number of reader/enqueue threads per worker", type=int, default=10)
parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000)
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
args = parser.parse_args()
print("args:", args)

print("{0} ===== Start".format(datetime.now().isoformat()))
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, args.num_ps, args.tensorboard,
                        TFCluster.InputMode.TENSORFLOW, driver_ps_nodes=args.driver_ps_nodes)
cluster.shutdown()
print("{0} ===== Stop".format(datetime.now().isoformat()))

From yahoo/TensorFlowOnSpark, examples/slim/train_image_classifier.py:

if __name__ == '__main__':
  import argparse
  import sys

  from pyspark.conf import SparkConf
  from pyspark.context import SparkContext
  from tensorflowonspark import TFCluster

  sc = SparkContext(conf=SparkConf().setAppName("train_image_classifier"))
  executors = sc._conf.get("spark.executor.instances")
  num_executors = int(executors) if executors is not None else 1

  parser = argparse.ArgumentParser()
  parser.add_argument("--num_ps_tasks", help="number of PS nodes", type=int, default=0)
  parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
  parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
  # parse only the wrapper's own flags; the rest of argv is left for the TF program
  (args, rem) = parser.parse_known_args()

  # need at least one worker beyond the PS tasks
  assert num_executors > args.num_ps_tasks
  # main_fun is the slim training function defined earlier in this file
  cluster = TFCluster.run(sc, main_fun, sys.argv, args.cluster_size, args.num_ps_tasks, args.tensorboard, TFCluster.InputMode.TENSORFLOW)
  cluster.shutdown()
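
Note the parse_known_args call: the wrapper consumes only its own Spark-side flags, then hands the complete sys.argv to TFCluster.run so the slim training code can parse the remaining TensorFlow flags itself inside main_fun.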

From yahoo/TensorFlowOnSpark, examples/segmentation/segmentation_spark.py:

if __name__ == '__main__':
  import argparse

  from pyspark.conf import SparkConf
  from pyspark.context import SparkContext
  from tensorflowonspark import TFCluster

  sc = SparkContext(conf=SparkConf().setAppName("segmentation_spark"))
  executors = sc._conf.get("spark.executor.instances")
  num_executors = int(executors) if executors is not None else 1

  parser = argparse.ArgumentParser()
  parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64)
  parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=1000)
  parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
  parser.add_argument("--epochs", help="number of epochs", type=int, default=3)
  parser.add_argument("--model_dir", help="path to save model/checkpoint", default="segmentation_model")
  parser.add_argument("--export_dir", help="path to export saved_model", default="segmentation_export")
  parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")

  args = parser.parse_args()
  print("args:", args)

  # main_fun is the TF2 training function defined earlier in this file
  cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief')
  cluster.shutdown(grace_secs=30)
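
Unlike the PS-based examples, this TF2-style driver runs with num_ps=0. The master_node='chief' argument tells TensorFlowOnSpark which node to label as the chief in the TF_CONFIG it generates for tf.distribute's multi-worker strategies, and shutdown(grace_secs=30) gives the executors extra time, for example to finish exporting the saved_model, before the cluster is torn down.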

From yahoo/TensorFlowOnSpark, examples/mnist/tf/mnist_spark_dataset.py:

parser.add_argument("--format", help="example format: (csv2|tfr)", choices=["csv2", "tfr"], default="tfr")
parser.add_argument("--images_labels", help="HDFS path to MNIST image_label files in parallelized format")
parser.add_argument("--mode", help="train|inference", default="train")
parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("--num_ps", help="number of ps nodes", default=1)
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("--rdma", help="use rdma connection", default=False)
parser.add_argument("--readers", help="number of reader/enqueue threads per worker", type=int, default=10)
parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000)
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
args = parser.parse_args()
print("args:", args)

print("{0} ===== Start".format(datetime.now().isoformat()))
cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, args.num_ps, args.tensorboard,
                        TFCluster.InputMode.TENSORFLOW, driver_ps_nodes=args.driver_ps_nodes)
cluster.shutdown()
print("{0} ===== Stop".format(datetime.now().isoformat()))

From yahoo/TensorFlowOnSpark, examples/imagenet/inception/imagenet_eval.py:

import sys

from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from tensorflowonspark import TFCluster, TFNode
import tensorflow as tf


def main_fun(argv, ctx):
  # FLAGS, dataset, and inception_eval are set up earlier in the file (truncated here)
  assert dataset.data_files()
  if tf.gfile.Exists(FLAGS.eval_dir):
    tf.gfile.DeleteRecursively(FLAGS.eval_dir)
  tf.gfile.MakeDirs(FLAGS.eval_dir)

  # start this node's TF server based on the cluster spec assembled by TFCluster.run
  cluster_spec, server = TFNode.start_cluster_server(ctx)

  inception_eval.evaluate(dataset)


if __name__ == '__main__':
  sc = SparkContext(conf=SparkConf().setAppName("grid_imagenet_eval"))
  num_executors = int(sc._conf.get("spark.executor.instances"))
  num_ps = 0

  # evaluation only: every node is a worker, no parameter servers, no tensorboard
  cluster = TFCluster.run(sc, main_fun, sys.argv, num_executors, num_ps, False, TFCluster.InputMode.TENSORFLOW)
  cluster.shutdown()