How to use the `tf.train` optimizer functions in @tensorflow/tfjs

To help you get started, we’ve selected a few @tensorflow/tfjs examples based on popular ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: tensorflow / tfjs-examples / mnist-transfer-cnn / index.js (View on GitHub)
if (trainingMode === 'freeze-feature-layers') {
      console.log('Freezing feature layers of the model.');
      for (let i = 0; i < 7; ++i) {
        this.model.layers[i].trainable = false;
      }
    } else if (trainingMode === 'reinitialize-weights') {
      // Make a model with the same topology as before, but with re-initialized
      // weight values.
      const returnString = false;
      this.model = await tf.models.modelFromJSON({
        modelTopology: this.model.toJSON(null, returnString)
      });
    }
    this.model.compile({
      loss: 'categoricalCrossentropy',
      optimizer: tf.train.adam(0.01),
      metrics: ['acc'],
    });

    // Print model summary again after compile(). You should see a number
    // of the model's weights have become non-trainable.
    this.model.summary();

    const batchSize = 128;
    const epochs = ui.getEpochs();

    const surfaceInfo = {name: trainingMode, tab: 'Transfer Learning'};
    console.log('Calling model.fit()');
    await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
      batchSize: batchSize,
      epochs: epochs,
      validationData: [this.gte5TestData.x, this.gte5TestData.y],
GitHub: tensorflow / tfjs-examples / simple-object-detection / train.js (View on GitHub)
console.log('Training using CPU.');
    require('@tensorflow/tfjs-node');
  }

  const modelSaveURL = 'file://./dist/object_detection_model';

  const tBegin = tf.util.now();
  console.log(`Generating ${args.numExamples} training examples...`);
  const synthDataCanvas = canvas.createCanvas(CANVAS_SIZE, CANVAS_SIZE);
  const synth =
      new synthesizer.ObjectDetectionImageSynthesizer(synthDataCanvas, tf);
  const {images, targets} =
      await synth.generateExampleBatch(args.numExamples, numCircles, numLines);

  const {model, fineTuningLayers} = await buildObjectDetectionModel();
  model.compile({loss: customLossFunction, optimizer: tf.train.rmsprop(5e-3)});
  model.summary();

  // Initial phase of transfer learning.
  console.log('Phase 1 of 2: initial transfer learning');
  await model.fit(images, targets, {
    epochs: args.initialTransferEpochs,
    batchSize: args.batchSize,
    validationSplit: args.validationSplit
  });

  // Fine-tuning phase of transfer learning.
  // Unfreeze layers for fine-tuning.
  for (const layer of fineTuningLayers) {
    layer.trainable = true;
  }
  model.compile({loss: customLossFunction, optimizer: tf.train.rmsprop(2e-3)});
GitHub: aayusharora / GeneticAlgorithms / part1 / src / nn.js (View on GitHub)
}))

    /* this is the second output layer with 6 inputs coming from the previous hidden layer
    activation is again sigmoid and output is given as 2 units 10 for not jump and 01 for jump
    */
    dino.model.add(tf.layers.dense({
      inputShape:[6],
      activation:'sigmoid',
      units:2
    }))

    /* compiling the model using meanSquaredError loss function and adam 
    optimizer with a learning rate of 0.1 */
    dino.model.compile({
      loss:'meanSquaredError',
      optimizer : tf.train.adam(0.1)
    })

    // object which will containn training data and appropriate labels
    dino.training = {
      inputs: [],
      labels: []
    };
    
  } else {
    // Train the model before restarting.
    // log into console that model will now be trained
    console.info('Training');
    // convert the inputs and labels to tensor2d format and  then training the model
    console.info(tf.tensor2d(dino.training.inputs))
    dino.model.fit(tf.tensor2d(dino.training.inputs), tf.tensor2d(dino.training.labels));
  }
GitHub: charliegerard / gestures-ml-js / arduino-mkr1000 / training-old-hp.js (View on GitHub)
const createModel = async(xTrain, yTrain, xTest, yTest) => {
  const params = {learningRate: 0.1, epochs: 40};
  // Define the topology of the model: two dense layers.
  const model = tf.sequential();
  model.add(tf.layers.dense(
      {units: 10, activation: 'sigmoid', inputShape: [xTrain.shape[1]]}));
  model.add(tf.layers.dense({units: numClasses, activation: 'softmax'}));
  model.summary();

  const optimizer = tf.train.adam(params.learningRate);
  model.compile({
    optimizer: optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  const trainLogs = [];

  await model.fit(xTrain, yTrain, {
    epochs: params.epochs,
    validationData: [xTest, yTest],
    callbacks: {
      onEpochEnd: async (epoch, logs) => {
        // Plot the loss and accuracy values at the end of every training epoch.
        trainLogs.push(logs);
      },
GitHub: jinglescode / demos / src / app / components / tfjs-timeseries-stocks / tfjs-timeseries-stocks-main.ts (View on GitHub)
let lstm_cells = [];
    for (let index = 0; index < n_layers; index++) {
         lstm_cells.push(tf.layers.lstmCell({units: rnn_output_neurons}));
    }

    model.add(tf.layers.rnn({
      cell: lstm_cells,
      inputShape: rnn_input_shape,
      returnSequences: false
    }));

    model.add(tf.layers.dense({units: output_layer_neurons, inputShape: [output_layer_shape]}));

    model.compile({
      optimizer: tf.train.adam(learning_rate),
      loss: 'meanSquaredError'
    });

    const hist = await model.fit(xs, ys,
      { batchSize: rnn_batch_size, epochs: n_epochs, callbacks: {
        onEpochEnd: async (epoch, log) => {
          callback(epoch, log, model_params);
        }
      }
    });

    // await model.save('localstorage://tfjs-stocks');
    // const model = await tf.loadLayersModel('localstorage://tfjs-stocks');
    // const hist = {};

    return { model: model, stats: hist };
GitHub: cstefanache / tfjs-model-view / app / iris / iris.js (View on GitHub)
// Define the topology of the model: two dense layers.
  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 10,
    activation: 'sigmoid',
    inputShape: [xTrain.shape[1]]
  }));

  model.add(tf.layers.dense({
    units: 3,
    activation: 'softmax'
  }));

  model.summary();

  const optimizer = tf.train.adam(0.02);
  model.compile({
    optimizer: optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });

  new ModelView(model, {
    printStats: true,
    radius: 25,
    renderLinks: true,
    layer: {
      'dense_Dense1_input': {
        domain: [0, 7]
      },
      'dense_Dense2/dense_Dense2': {
        nodePadding: 30
GitHub: googlecreativelab / teachablemachine-community / libraries / pose / src / teachable-posenet.ts (View on GitHub)
useBias: true
            }),
            // Layer 2 dropout
            tf.layers.dropout({rate: 0.5}),
            // Layer 3. The number of units of the last layer should correspond
            // to the number of classes we want to predict.
            tf.layers.dense({
                units: this.numClasses,
                kernelInitializer: varianceScaling, // 'varianceScaling'
                useBias: false,
                activation: 'softmax'
            })
            ]
        });
        // const optimizer = tf.train.adam(params.learningRate);
        const optimizer = tf.train.rmsprop(params.learningRate);
        this.model.compile({
            optimizer,
            loss: 'categoricalCrossentropy',
            metrics: ['accuracy']
        });

        if (!(params.batchSize > 0)) {
            throw new Error(
            `Batch size is 0 or NaN. Please choose a non-zero fraction`
            );
        }

        const trainData = this.trainDataset.batch(params.batchSize);
        const validationData = this.validationDataset.batch(params.batchSize);

        // For debugging: check for shuffle or result from trainDataset
GitHub: atanasster / crypto-grommet / tensorflow / config / optimizers / momentum.js (View on GitHub)
  /**
   * Builds a momentum optimizer from this config's current property values.
   *
   * Reads the 'lr', 'momentum' and 'nesterov' properties through
   * this.getPropValue and forwards them, in that order, to
   * tf.train.momentum(learningRate, momentum, useNesterov).
   * A new optimizer is constructed on every call.
   *
   * Note: naming the field `tf` does not shadow the outer `tf` inside this
   * body — class fields create no lexical binding, so `tf.train` below
   * resolves to the module-level binding (presumably the @tensorflow/tfjs
   * import; confirm at top of file).
   *
   * @returns {tf.MomentumOptimizer} the optimizer built from the current props
   */
  tf = () => {
    const learningRate = this.getPropValue('lr');
    const momentumValue = this.getPropValue('momentum');
    const useNesterov = this.getPropValue('nesterov');
    return tf.train.momentum(learningRate, momentumValue, useNesterov);
  };
}
GitHub: aarongoin / WebNN / lib / Model.js (View on GitHub)
this.model = tf.sequential();
		this.description = model;
		this.lr = learning_rate;
		this.input_shape = model.inputs;
		this.output_shape = model.outputs;

		for (let layer of model.layers) {
			let key = Object.keys(layer)[0];
			if (key !== 'output') {
				this.model.add(tf.layers[key](layer[key]));
			}
			else {
				let optimizer = Object.assign({},
					layer.output,
					{ optimizer: tf.train[layer.output.optimizer](learning_rate) }
				);
				this.model.compile(optimizer);
			}
		}
		this.getSize();

		this.mergeLayer = tf.layers.average();
		if (weights) this.load(weights);
		this.steps = steps;

		this.updateLearningRate = this.updateLearningRate.bind(this);
	}
GitHub: victordibia / anomagram / app / src / components / train / Train.jsx (View on GitHub)
//construct model
        switch (this.state.optimizer) {
            case "adam":
                this.optimizer = tf.train.adam(this.state.learningRate, this.state.adamBeta1)
                break
            case "adamax":
                this.optimizer = tf.train.adamax(this.state.learningRate, this.state.adamBeta1)
                break
            case "adadelta":
                this.optimizer = tf.train.adadelta(this.state.learningRate)
                break
            case "rmsprop":
                this.optimizer = tf.train.rmsprop(this.state.learningRate)
                break
            case "momentum":
                this.optimizer = tf.train.momentum(this.state.learningRate, this.momentum)
                break
            case "sgd":
                this.optimizer = tf.train.sgd(this.state.learningRate)
                break
            default:
                break;
        }



        let modelParams = {
            numFeatures: this.state.numFeatures,
            hiddenLayers: this.state.hiddenLayers,
            latentDim: this.state.latentDim,
            hiddenDim: this.state.hiddenDim,
            optimizer: this.optimizer,