How to use the `tf.oneHot` function in @tensorflow/tfjs

To help you get started, we’ve selected a few @tensorflow/tfjs examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github charliegerard / gestures-ml-js / arduino-mkr1000 / training-old-hp.js View on Github external
for (let i = 0; i < numExamples; ++i) {
    shuffledData.push(data[indices[i]]);
    shuffledTargets.push(targets[indices[i]]);
  }

  // Split the data into a training set and a tet set, based on `testSplit`.
  const numTestExamples = Math.round(numExamples * testSplit);
  const numTrainExamples = numExamples - numTestExamples;

  const xDims = shuffledData[0].length;

  const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);

  // Create a 1D `tf.Tensor` to hold the labels, and convert the number label
  // from the set {0, 1, 2} into one-hot encoding (.e.g., 0 --> [1, 0, 0]).
  const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), numClasses);

  const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
  const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
  const yTrain = ys.slice([0, 0], [numTrainExamples, numClasses]);
  const yTest = ys.slice([0, 0], [numTestExamples, numClasses]);
  return [xTrain, yTrain, xTest, yTest];
}
github tensorflow / tfjs-examples / iris-fitDataset / data.js View on Github external
/**
 * Returns the one-hot encoding of a single class index as a plain JS array,
 * e.g. flatOneHot(1) -> [0, 1, 0].
 *
 * @param idx Class index to encode.
 * @param numClasses Depth of the encoding; defaults to 3 (the original
 *     hard-coded value).
 * @returns A plain number array of length `numClasses`.
 */
export function flatOneHot(idx, numClasses = 3) {
  // TODO(bileschi): Remove 'Array.from' from here once tf.data supports typed
  // arrays https://github.com/tensorflow/tfjs/issues/1041
  // TODO(bileschi): Remove '.dataSync()' from here once tf.data supports
  // datasets built from tensors.
  // https://github.com/tensorflow/tfjs/issues/1046
  //
  // tf.tidy disposes the temporary one-hot tensor once its values have been
  // copied into a plain array; without it every call leaked a tensor.
  return tf.tidy(() => Array.from(tf.oneHot([idx], numClasses).dataSync()));
}
github PAIR-code / federated-learning / experiments / mnist / experiment_mnist_transfer_learning.ts View on Github external
optimizer.minimize(() => loss(img, tf.oneHot(label, 10).toFloat()));
    sync.numExamples += batchSize;
    i++;
    j++;

    if (j % syncEvery) {
      continue;
    }
    // Back off before syncing. `res` itself must be handed to setTimeout:
    // `setTimeout(res(), wait)` resolved the promise immediately, so the
    // (exponentially growing) `wait` delay never actually happened.
    await new Promise<void>(resolve => setTimeout(resolve, wait));

    try {
      await sync.uploadVars();

      wait = 100 + 50 * Math.random();
      // tf.tidy releases the one-hot, loss and mean tensors that are created
      // only to produce this log value.
      const batchLoss = tf.tidy(
          () => loss(img, tf.oneHot(label, 10).toFloat()).mean().dataSync()[0]);
      log('up sync', i, 'batch loss', batchLoss);
    } catch (exn) {
      wait = wait * 2.0;  // exp backoff
      j--;                // try again next iter
      log('timeout', exn);
    }
  }
  // Process any pending updates by pausing briefly. `resolve` itself is passed
  // to setTimeout; `setTimeout(res(), 50)` resolved immediately and skipped
  // the intended 50 ms pause.
  await new Promise<void>(resolve => setTimeout(resolve, 50));
  log('done, evaluating final loss');
  done = true;
  const evalRes = evaluate();
  log('final loss', evalRes[0], 'init loss:', preEvalRes[0]);
  // sync.dispose();
  return;
}
github zxch3n / PomodoroLogger / src / main / learner / learner.ts View on Github external
{
        batchSize = 32,
        epochs = 20,
        callback = undefined
    }: { batchSize?: number; epochs?: number; callback?: (epoch: number, log?: any) => void } = {}
) {
    const [embeddings, projects] = await embeddingTitles(pairs);
    const [projectEncoding, invertEncode] = oneHotEncode(projects);
    const outputSize = Object.values(projectEncoding).length;
    console.log('emb', embeddings.shape);
    console.log('projects', projects.length);
    console.log('Encode', projectEncoding);
    console.log('outputSize', outputSize);
    const model = await createModel(embeddings.shape[1], outputSize, 1, 100);
    // Build the int32 index tensor inside tf.tidy so it is disposed as soon as
    // the one-hot labels exist; previously it was never released.
    // NOTE(review): tf.oneHot rejects a depth below 2, so this presumably
    // requires at least two distinct projects — confirm callers guarantee it.
    const oneHotTensor = tf.tidy(() => {
        const projectIndices = tf.tensor1d(projects.map(p => projectEncoding[p]), 'int32');
        return tf.oneHot(projectIndices, outputSize);
    });
    // NOTE(review): `oneHotTensor` is still never disposed; it can likely be
    // freed once trainModel resolves — confirm trainModel does not retain it.
    await trainModel({
        model,
        batchSize,
        epochs,
        callback,
        input: embeddings,
        labels: oneHotTensor,
        shuffle: true
    });
    return {
        model,
        projectEncoding,
        invertEncode
    };
}
github thekevinscott / ml-classifier / src / prepareData.ts View on Github external
/**
 * One-hot encodes a single label index.
 * The result has shape [1, classLength]; intermediates are released by tf.tidy.
 */
const oneHot = (labelIndex: number, classLength: number) =>
  tf.tidy(() => {
    const indexTensor = tf.tensor1d([labelIndex]).toInt();
    return tf.oneHot(indexTensor, classLength);
  });
github tensorflow / tfjs-examples / snake-dqn / agent.js View on Github external
const lossFunction = () => tf.tidy(() => {
      // Q-values of the online network, masked down to the action each
      // example actually took.
      const states = getStateTensor(
          batch.map(example => example[0]), this.game.height, this.game.width);
      const actions = tf.tensor1d(batch.map(example => example[1]), 'int32');
      const onlineQs = this.onlineNetwork.apply(states, {training: true});
      const actionMask = tf.oneHot(actions, NUM_ACTIONS);
      const qs = onlineQs.mul(actionMask).sum(-1);

      // Target: reward + gamma * (max target-network Q of the next state),
      // with the bootstrap term zeroed when example[3] (the done flag) is set.
      const rewards = tf.tensor1d(batch.map(example => example[2]));
      const nextStates = getStateTensor(
          batch.map(example => example[4]), this.game.height, this.game.width);
      const nextMaxQs = this.targetNetwork.predict(nextStates).max(-1);
      const doneFlags =
          tf.tensor1d(batch.map(example => example[3])).asType('float32');
      const notDone = tf.scalar(1).sub(doneFlags);
      const bootstrap = nextMaxQs.mul(notDone).mul(gamma);
      const targetQs = rewards.add(bootstrap);

      return tf.losses.meanSquaredError(targetQs, qs);
    });
github tensorflow / tfjs-models / speech-commands / training / soft-fft / utils / dataset.ts View on Github external
        tf.tidy(() => tf.oneHot(tf.tensor1d([label]).toInt(), this.numClasses));
github tensorflow / tfjs / demos / pacman / controller_dataset.ts View on Github external
    // tf.oneHot requires int32 indices, but tf.tensor1d defaults to float32
    // (which oneHot rejects), so request the int32 dtype explicitly.
    const y = tf.tidy(() => tf.oneHot(tf.tensor1d([label], 'int32'), this.numClasses));
github tensorflow / tfjs-examples / mnist-node / data.js View on Github external
const imagesShape = [size, IMAGE_HEIGHT, IMAGE_WIDTH, 1];
    const images = new Float32Array(tf.util.sizeFromShape(imagesShape));
    const labels = new Int32Array(tf.util.sizeFromShape([size, 1]));

    let imageOffset = 0;
    let labelOffset = 0;
    for (let i = 0; i < size; ++i) {
      images.set(this.dataset[imagesIndex][i], imageOffset);
      labels.set(this.dataset[labelsIndex][i], labelOffset);
      imageOffset += IMAGE_FLAT_SIZE;
      labelOffset += 1;
    }

    return {
      images: tf.tensor4d(images, imagesShape),
      labels: tf.oneHot(tf.tensor1d(labels, 'int32'), LABEL_FLAT_SIZE).toFloat()
    };
  }
}
github piximi / application / src / network.ts View on Github external
// One-hot encode `ys`; tidy releases any intermediates created by oneHot.
// NOTE(review): depth is `categories.length - 1`, one less than the number of
// categories — presumably one category is excluded from training; confirm.
let y = tensorflow.tidy(() => tensorflow.oneHot(ys, categories.length - 1));