How to use the tensorflowonspark.TFCluster module in tensorflowonspark

To help you get started, we've selected a few tensorflowonspark examples based on popular ways it is used in public projects.
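The TFCluster module drives a simple lifecycle: TFCluster.run launches a map function on every Spark executor, cluster.train or cluster.inference feeds RDD data to those executors when InputMode.SPARK is used, and cluster.shutdown tears everything down. Below is a minimal sketch of that lifecycle, assuming a live SparkContext and two executors; the map_fun body and the 'greeting' argument are placeholders, not part of the examples that follow.

from pyspark import SparkContext
from tensorflowonspark import TFCluster

def map_fun(args, ctx):
  # ctx describes this node's role in the cluster (ctx.job_name, ctx.task_index)
  print('node {}:{} got {}'.format(ctx.job_name, ctx.task_index, args['greeting']))

sc = SparkContext()
cluster = TFCluster.run(sc, map_fun, tf_args={'greeting': 'hello'},
                        num_executors=2, num_ps=0,
                        input_mode=TFCluster.InputMode.TENSORFLOW)
cluster.shutdown()  # blocks until every executor's map_fun returns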


From yahoo/TensorFlowOnSpark: test/test_TFCluster.py
# excerpted from the project's test suite; the test class supplies a
# SparkContext (self.sc) and executor count (self.num_workers), and
# TFCluster is imported at module level: from tensorflowonspark import TFCluster
def test_basic_tf(self):
    """Single-node TF graph (w/ args) running independently on multiple executors."""
    def _map_fun(args, ctx):
      import tensorflow as tf
      x = tf.constant(args['x'])
      y = tf.constant(args['y'])
      total = tf.math.add(x, y)
      assert total.numpy() == 3  # executes eagerly on each executor (TF 2.x)

    args = {'x': 1, 'y': 2}
    cluster = TFCluster.run(self.sc, _map_fun, tf_args=args, num_executors=self.num_workers, num_ps=0)
    cluster.shutdown()
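With num_ps=0 and the default input mode (InputMode.TENSORFLOW), each executor simply runs _map_fun as an independent single-node TensorFlow program; cluster.shutdown() then blocks until every executor has finished.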

From yahoo/TensorFlowOnSpark: examples/mnist/keras/mnist_mlp_estimator.py
  # `args` is the example's argparse namespace; `sc` is the driver's SparkContext
  if args.input_mode == 'tf':
    # for TENSORFLOW mode, each node will load/train/infer entire dataset in memory per original example
    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW, log_dir=args.model_dir, master_node='master')
    cluster.shutdown()
  else:  # 'spark'
    # for SPARK mode, just use CSV format as an example
    images = sc.textFile(args.images).map(lambda ln: [float(x) for x in ln.split(',')])
    labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
    dataRDD = images.zip(labels)
    if args.mode == 'train':
      cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir, master_node='master')
      cluster.train(dataRDD, args.epochs)
      cluster.shutdown()
    else:
      # Note: using "parallel" inferencing, not "cluster"
      # each node loads the model and runs independently of others
      cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, 0, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir)
      resultRDD = cluster.inference(dataRDD)
      resultRDD.saveAsTextFile(args.output)
      cluster.shutdown()
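In InputMode.SPARK, the executors pull partitions of dataRDD through a queue exposed by TFNode.DataFeed inside main_fun (whose body is not shown above). A hedged sketch of that consuming end, assuming an args.batch_size parameter and leaving the actual model logic as comments:

def main_fun(args, ctx):
  import numpy as np
  from tensorflowonspark import TFNode

  tf_feed = TFNode.DataFeed(ctx.mgr, train_mode=(args.mode == 'train'))
  while not tf_feed.should_stop():
    batch = tf_feed.next_batch(args.batch_size)  # rows of dataRDD: (image, label) pairs
    if len(batch) == 0:
      break
    images = np.array([row[0] for row in batch])
    labels = np.array([row[1] for row in batch])
    # train on (images, labels) here; for inference, emit one result per
    # input row so cluster.inference(dataRDD) can collect them:
    # tf_feed.batch_results(predictions)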