How to use the `tf.train` optimizer API from @tensorflow/tfjs-node-gpu

To help you get started, we’ve selected a few @tensorflow/tfjs-node-gpu examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: tensorflow/tfjs-examples — mnist-acgan/gan.js (view on GitHub)
tf = require('@tensorflow/tfjs-node-gpu');
  } else {
    console.log('Using CPU');
    tf = require('@tensorflow/tfjs-node');
  }

  if (!fs.existsSync(path.dirname(args.generatorSavePath))) {
    fs.mkdirSync(path.dirname(args.generatorSavePath));
  }
  const saveURL = `file://${args.generatorSavePath}`;
  const metadataPath = path.join(args.generatorSavePath, 'acgan-metadata.json');

  // Build the discriminator.
  const discriminator = buildDiscriminator();
  discriminator.compile({
    optimizer: tf.train.adam(args.learningRate, args.adamBeta1),
    loss: ['binaryCrossentropy', 'sparseCategoricalCrossentropy']
  });
  discriminator.summary();

  // Build the generator.
  const generator = buildGenerator(args.latentSize);
  generator.summary();

  const optimizer = tf.train.adam(args.learningRate, args.adamBeta1);
  const combined = buildCombinedModel(
      args.latentSize, generator, discriminator, optimizer);

  await data.loadData();
  let {images: xTrain, labels: yTrain} = data.getTrainData();
  yTrain = tf.expandDims(yTrain.argMax(-1), -1);
Source: tensorflow/tfjs-examples — mnist-acgan/gan.js (view on GitHub; this excerpt overlaps the previous one)
const saveURL = `file://${args.generatorSavePath}`;
  const metadataPath = path.join(args.generatorSavePath, 'acgan-metadata.json');

  // Build the discriminator.
  const discriminator = buildDiscriminator();
  discriminator.compile({
    optimizer: tf.train.adam(args.learningRate, args.adamBeta1),
    loss: ['binaryCrossentropy', 'sparseCategoricalCrossentropy']
  });
  discriminator.summary();

  // Build the generator.
  const generator = buildGenerator(args.latentSize);
  generator.summary();

  const optimizer = tf.train.adam(args.learningRate, args.adamBeta1);
  const combined = buildCombinedModel(
      args.latentSize, generator, discriminator, optimizer);

  await data.loadData();
  let {images: xTrain, labels: yTrain} = data.getTrainData();
  yTrain = tf.expandDims(yTrain.argMax(-1), -1);

  // Save the generator model once before starting the training.
  await generator.save(saveURL);

  let numTensors;
  let logWriter;
  if (args.logDir) {
    console.log(`Logging to tensorboard at logdir: ${args.logDir}`);
    logWriter = tf.node.summaryFileWriter(args.logDir);
  }
Source: bobiblazeski/js-gym — dist/agents.gpu.js (view on GitHub)
minBufferSize=MIN_BUFFER_SIZE, updateEvery=UPDATE_EVERY,
      bufferSize=BUFFER_SIZE, batchSize=BATCH_SIZE} = {},
      buffer) {
    this.epsilon = epsilon;
    this.epsilonDecay = epsilonDecay;
    this.minEpsilon = minEpsilon;
    this.minBufferSize = minBufferSize;
    this.updateEvery = updateEvery;
    this.noise = new OUNoise(actionSize);
    this.buffer = buffer || new ReplayBuffer(bufferSize, batchSize);

    this.actor = makeActor();
    this.actorTarget = makeActor();
    this.critic =makeCritic();
    this.criticTarget =makeCritic();
    this.actorOptimizer = tf.train.adam(lrActor);
    this.criticOptimizer = tf.train.adam(lrCritic);

    hardUpdate(this.actor, this.actorTarget);
    hardUpdate(this.critic, this.criticTarget);      
  }
Source: bobiblazeski/js-gym — dist/agents.gpu.js (view on GitHub; this excerpt overlaps the previous one)
bufferSize=BUFFER_SIZE, batchSize=BATCH_SIZE} = {},
      buffer) {
    this.epsilon = epsilon;
    this.epsilonDecay = epsilonDecay;
    this.minEpsilon = minEpsilon;
    this.minBufferSize = minBufferSize;
    this.updateEvery = updateEvery;
    this.noise = new OUNoise(actionSize);
    this.buffer = buffer || new ReplayBuffer(bufferSize, batchSize);

    this.actor = makeActor();
    this.actorTarget = makeActor();
    this.critic =makeCritic();
    this.criticTarget =makeCritic();
    this.actorOptimizer = tf.train.adam(lrActor);
    this.criticOptimizer = tf.train.adam(lrCritic);

    hardUpdate(this.actor, this.actorTarget);
    hardUpdate(this.critic, this.criticTarget);      
  }