How to use the tensorflowonspark.reservation.Server class in tensorflowonspark

To help you get started, we've selected a few tensorflowonspark examples based on popular ways the library is used in public projects.

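The snippets below are excerpts, so they omit their imports. A minimal set that matches the calls they make (an assumption based on the code shown, not the test module's exact import list):

import os
import threading
from unittest import mock

from tensorflowonspark import util
from tensorflowonspark.reservation import Client, Server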

yahoo/TensorFlowOnSpark · test/test_reservation.py
def test_reservation_enviroment_not_exists_get_server_ip_return_actual_host_ip(self):
    tfso_server = Server(5)
    assert tfso_server.get_server_ip() == util.get_ip_address()

yahoo/TensorFlowOnSpark · test/test_reservation.py
def test_reservation_enviroment_not_exists_start_listening_socket_return_socket(self):
    tfso_server = Server(1)
    # bind once and reuse the result; calling start_listening_socket() twice
    # would open a second listening socket and leak the first
    port = tfso_server.start_listening_socket().getsockname()[1]
    print(port)
    assert type(port) == int

yahoo/TensorFlowOnSpark · test/test_reservation.py
def test_reservation_enviroment_exists_get_server_ip_return_environment_value(self):
    tfso_server = Server(5)
    with mock.patch.dict(os.environ, {'TFOS_SERVER_HOST': 'my_host_ip'}):
      assert tfso_server.get_server_ip() == "my_host_ip"

yahoo/TensorFlowOnSpark · test/test_reservation.py
def test_reservation_enviroment_exists_start_listening_socket_return_socket_listening_to_environment_port_value(self):
    tfso_server = Server(1)
    with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}):
      assert tfso_server.start_listening_socket().getsockname()[1] == 9999
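
The two environment-variable tests above show that TFOS_SERVER_HOST overrides the IP address the server reports and TFOS_SERVER_PORT overrides the port it listens on. A minimal sketch of using those overrides outside the tests, with hypothetical values and assuming the variables are read when the socket is created:

import os

os.environ['TFOS_SERVER_HOST'] = '10.0.0.5'  # hypothetical host override
os.environ['TFOS_SERVER_PORT'] = '9999'      # hypothetical port override

server = Server(4)     # expect 4 reservations
addr = server.start()  # (host, port) should reflect the overrides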

yahoo/TensorFlowOnSpark · test/test_reservation.py
def test_reservation_server(self):
    """Test reservation server, expecting 1 reservation"""
    s = Server(1)
    addr = s.start()

    # add first reservation
    c = Client(addr)
    resp = c.register({'node': 1})
    self.assertEqual(resp, 'OK')

    # get list of reservations
    reservations = c.get_reservations()
    self.assertEqual(len(reservations), 1)

    # should return immediately with list of reservations
    reservations = c.await_reservations()
    self.assertEqual(len(reservations), 1)

    # request server stop
    c.request_stop()
    c.close()

yahoo/TensorFlowOnSpark · test/test_reservation.py
def test_reservation_server_multi(self):
    """Test reservation server, expecting multiple reservations"""
    num_clients = 4
    s = Server(num_clients)
    addr = s.start()

    def reserve(num):
      c = Client(addr)
      # time.sleep(random.randint(0,5))     # simulate varying start times
      resp = c.register({'node': num})
      self.assertEqual(resp, 'OK')
      c.await_reservations()
      c.close()

    # start/register clients
    threads = [None] * num_clients
    for i in range(num_clients):
      threads[i] = threading.Thread(target=reserve, args=(i,))
      threads[i].start()

    # wait for all clients to finish registering
    for i in range(num_clients):
      threads[i].join()

yahoo/TensorFlowOnSpark · tensorflowonspark/TFCluster.py
  if num_workers > 0:
    cluster_template['worker'] = executors[:num_workers]

  logger.info("cluster_template: {}".format(cluster_template))

  # get default filesystem from spark
  defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
  # strip trailing "root" slash from "file:///" to be consistent w/ "hdfs://..."
  if defaultFS.startswith("file://") and len(defaultFS) > 7 and defaultFS.endswith("/"):
    defaultFS = defaultFS[:-1]

  # get current working dir of spark launch
  working_dir = os.getcwd()

  # start a server to listen for reservations and broadcast cluster_spec
  server = reservation.Server(num_executors)
  server_addr = server.start()

  # start TF nodes on all executors
  logger.info("Starting TensorFlow on executors")
  cluster_meta = {
    'id': random.getrandbits(64),
    'cluster_template': cluster_template,
    'num_executors': num_executors,
    'default_fs': defaultFS,
    'working_dir': working_dir,
    'server_addr': server_addr
  }
  if driver_ps_nodes:
    nodeRDD = sc.parallelize(range(num_ps, num_executors), num_executors - num_ps)
  else:
    nodeRDD = sc.parallelize(range(num_executors), num_executors)
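
The driver-side code above starts a reservation.Server sized to num_executors and ships server_addr to every executor inside cluster_meta. The executor-side half of the handshake is the same Client API exercised in the tests; a minimal sketch of that step, with a hypothetical register_node helper and simplified node metadata:

from tensorflowonspark.reservation import Client

def register_node(server_addr, executor_id):
  # connect back to the driver's reservation server
  client = Client(server_addr)
  # announce this node; the real payload carries richer metadata
  client.register({'executor_id': executor_id})
  # block until all expected nodes have registered, then fetch the full list
  cluster_info = client.await_reservations()
  client.close()
  return cluster_info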