How to use the @tensorflow/tfjs-layers.models function in @tensorflow/tfjs-layers

To help you get started, we’ve selected a few @tensorflow/tfjs-layers examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: GitHub, tensorflow/tfjs, integration_tests/benchmarks/index.js (view on GitHub)
async function runBenchmark(artifactsDir, modelName, config) {
  const modelPath = artifactsDir + modelName + '/';
  console.log('Loading model "' + modelName + '" and benchmark data...');
  // Note: currently we load only the topology. The weight values don't matter
  // for the benchmarks and are initialized according to the initializer.
  const modelJSON = await (await fetch(modelPath + 'model.json')).json();
  const model = await tfl.models.modelFromJSON(modelJSON['modelTopology']);
  console.log('Done loading model "' + modelName + '" and benchmark data.');

  const benchmarkData = await (await fetch(modelPath + 'data.json')).json();

  const lossMap = {
    mean_squared_error: 'meanSquaredError',
    categorical_crossentropy: 'categoricalCrossentropy',
  };
  // TODO(cais): Maybe TF.js Layers should tolerate these Python-style names
  // for losses.

  const [xs, ys] = getRandomInputsAndOutputs(model, benchmarkData.batch_size);

  if (benchmarkData.train_epochs > 0) {
    const optimizer =
        optimizerMap[benchmarkData.optimizer] || benchmarkData.optimizer;
Source: GitHub, tensorflow/tfjs, tfjs/integration_tests/models/common.ts (view on GitHub)
/**
 * Loads the JSON description of a saved model and builds a Layers model from
 * its topology.
 *
 * Under Node.js the file `./data/${modelName}/model.json` is read from disk;
 * otherwise `${DATA_SERVER_ROOT}/${modelName}/model.json` is fetched. Only the
 * `modelTopology` entry of the parsed JSON is handed to
 * `tfl.models.modelFromJSON`.
 *
 * @param modelName Name of the directory that holds the model's `model.json`.
 * @returns The model constructed by `tfl.models.modelFromJSON`.
 * @throws Error if the browser request for `model.json` returns a non-OK
 *     HTTP status.
 */
export async function loadLayersModel(modelName: string):
    Promise<tfl.LayersModel> {
  // The JSON schema is not declared in this file, so it is left untyped.
  // tslint:disable-next-line:no-any
  let modelJSON: any;
  if (inNodeJS()) {
    // In Node.js.
    const modelJSONPath = `./data/${modelName}/model.json`;
    // `fs` is required lazily so that it is only touched on the Node.js path.
    // tslint:disable-next-line:no-require-imports
    const fs = require('fs');
    modelJSON = JSON.parse(fs.readFileSync(modelJSONPath, 'utf-8'));
  } else {
    // In browser.
    const modelJSONPath = `${DATA_SERVER_ROOT}/${modelName}/model.json`;
    const response = await fetch(modelJSONPath);
    // Fail with the HTTP status instead of a confusing JSON parse error when
    // the server answers with an error page.
    if (!response.ok) {
      throw new Error(
          `Failed to fetch ${modelJSONPath}: HTTP ${response.status}`);
    }
    modelJSON = await response.json();
  }
  return tfl.models.modelFromJSON(modelJSON['modelTopology']);
}