Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
/**
 * Converts a float tensor (presumably an image / activation map) into an
 * int32 tensor suitable for display.
 *
 * Steps: standardize over all elements (zero mean, unit variance), shift by
 * 0.5, clip to [0, 1], scale to [0, 255], and cast to int32.
 *
 * @param x Input tensor. `tf.moments` is called without axes, so the
 *     statistics are taken over every element.
 * @returns An int32 tensor with the same shape as `x`. Intermediate tensors
 *     are released by `tf.tidy`.
 */
function deprocessImage(x) {
  return tf.tidy(() => {
    const {mean, variance} = tf.moments(x);
    // Add a small positive number (EPSILON) to the denominator to prevent
    // division-by-zero.
    const std = tf.sqrt(variance).add(tf.backend().epsilon());
    const standardized = x.sub(mean).div(std);
    // Center the values on 0.5, then clip to [0, 1].
    const unit = tf.clipByValue(standardized.add(0.5), 0, 1);
    // After clipping to [0, 1], scaling by 255 already yields [0, 255], so a
    // second clip would be a no-op.
    return unit.mul(255).asType('int32');
  });
}
playStep() {
this.epsilon = this.frameCount >= this.epsilonDecayFrames ?
this.epsilonFinal :
this.epsilonInit + this.epsilonIncrement_ * this.frameCount;
this.frameCount++;
// The epsilon-greedy algorithm.
let action;
const state = this.game.getState();
if (Math.random() < this.epsilon) {
// Pick an action at random.
action = getRandomAction();
} else {
// Greedily pick an action based on online DQN output.
tf.tidy(() => {
const stateTensor =
getStateTensor(state, this.game.height, this.game.width)
action = ALL_ACTIONS[
this.onlineNetwork.predict(stateTensor).argMax(-1).dataSync()[0]];
});
}
const {state: nextState, reward, done, fruitEaten} = this.game.step(action);
this.replayMemory.append([state, action, reward, done, nextState]);
this.cumulativeReward_ += reward;
if (fruitEaten) {
this.fruitsEaten_++;
}
const output = {
/**
 * Builds the decoder input features from the current button.
 *
 * The button index is rescaled as 2 * button / (NUM_BUTTONS - 1) - 1. This
 * maps 0 to -1 and NUM_BUTTONS - 1 to 1 (presumably `this.button` lies in
 * that range — confirm against callers).
 *
 * @returns A one-element float32 rank-1 tensor. Intermediates are released
 *     by `tf.tidy`.
 */
protected getRnnInputFeats(): tf.Tensor1D {
  return tf.tidy(() => {
    const button = tf.tensor1d([this.button], 'float32');
    const fraction = tf.div(button, NUM_BUTTONS - 1);
    const centered = tf.sub(tf.mul(2., fraction), 1);
    return centered.as1D();
  });
}
convertRawToTensors(dataRaw) {
const meta = Object.assign({}, this.meta);
const dataLength = dataRaw.length;
return tf.tidy(() => {
const inputArr = [];
const outputArr = [];
dataRaw.forEach(row => {
// get xs
const xs = Object.keys(meta.inputs)
.map(k => {
return row.xs[k];
})
.flat();
inputArr.push(xs);
// get ys
const ys = Object.keys(meta.outputs)
.map(k => {
/**
 * Evaluates the cubic y = a * x^3 + b * x^2 + c * x + d element-wise.
 *
 * `a`, `b`, `c` and `d` are coefficients defined outside this function
 * (presumably trainable tf.Variable scalars — confirm at their declaration).
 *
 * @param x Input tensor.
 * @returns A tensor of predictions. Intermediates are released by `tf.tidy`.
 */
function predict(x) {
  return tf.tidy(() => {
    const cubicTerm = a.mul(x.pow(tf.scalar(3, 'int32')));
    const squareTerm = b.mul(x.square());
    const linearTerm = c.mul(x);
    // Summed left to right, matching the original evaluation order.
    return cubicTerm.add(squareTerm).add(linearTerm).add(d);
  });
}
export const doSinglePrediction = async (model, img, options = {}) => {
// First get input tensor
const resized = tf.tidy(() => {
img = tf.browser.fromPixels(img)
if (NUM_CHANNELS === 1) {
// Bring it down to gray
const gray_mid = img.mean(2)
img = gray_mid.expandDims(2) // back to (width, height, 1)
}
// assure (img.shape[0] === IMAGE_WIDTH && img.shape[1] === IMAGE_WIDTH
const alignCorners = true
return tf.image.resizeBilinear(
img,
[IMAGE_WIDTH, IMAGE_HEIGHT],
alignCorners
)
})
const logits = tf.tidy(() => {
/**
 * Converts one game state into a {xs, ys} training example.
 *
 * xs: per-card-value counts of player 1's hand. Each card is one-hot encoded
 *     with depth `game.GAME_STATE.max_card_value`, then summed over the hand.
 * ys: a one-element tensor holding `gameState.player1Win`.
 *
 * @param gameState Object exposing `player1Hand` (array of card values,
 *     presumably 1-based since 1 is subtracted before one-hot encoding) and
 *     `player1Win`.
 * @returns `{xs, ys}`. Intermediates are released by `tf.tidy`.
 */
function gameToFeaturesAndLabel(gameState) {
  return tf.tidy(() => {
    // Shift card values down by one so they index one-hot slots from 0.
    const cardIndices = tf.tensor1d(gameState.player1Hand, 'int32')
        .sub(tf.scalar(1, 'int32'));
    const xs = tf.oneHot(cardIndices, game.GAME_STATE.max_card_value).sum(0);
    const ys = tf.tensor1d([gameState.player1Win]);
    return {xs, ys};
  });
}
/**
 * Standardizes `x` to zero mean and unit population standard deviation,
 * using statistics computed over all elements.
 *
 * Note: a constant input has zero standard deviation, so the result is NaN
 * (0 / 0). Callers that may pass constant tensors must handle that case.
 *
 * @param x Input tensor.
 * @returns A tensor with the same shape as `x`. Intermediates are released by
 *     `tf.tidy`.
 */
function normalize(x) {
  return tf.tidy(() => {
    const mean = tf.mean(x);
    // Compute the centered tensor once and reuse it for both the deviation and
    // the output. x - m is bit-identical to x + (-m) in IEEE arithmetic.
    const centered = tf.sub(x, mean);
    const std = tf.sqrt(tf.mean(tf.square(centered)));
    return tf.div(centered, std);
  });
}
/**
 * Turns a webcam frame into a model-ready batch of one image.
 *
 * The frame is read with `tf.fromPixels`, square-cropped by `squareCrop`, cast
 * to float, and resized to MODEL_INPUT_WIDTH x MODEL_INPUT_WIDTH. It is then
 * rescaled with (v - 127.5) / 127.5, which maps [0, 255] to [-1, 1], and given
 * a leading batch axis.
 *
 * @param webcam Pixel source accepted by `tf.fromPixels`.
 * @returns A float tensor with a leading batch dimension of size 1.
 */
function preprocess(webcam) {
  return tf.tidy(() => {
    const square = squareCrop(tf.fromPixels(webcam)).toFloat();
    const resized = tf.image.resizeBilinear(
        square, [MODEL_INPUT_WIDTH, MODEL_INPUT_WIDTH]);
    const halfRange = 255 / 2;
    return resized.sub(halfRange).div(halfRange).expandDims(0);
  });
}
async getPrediction(){
if(this.state.loadedData) {
let canvas = this.canvas;
let imageData = DataProvider.getScaledData(canvas,28, 28);
await tfjs.tidy(() => {
let img = tfjs.fromPixels(imageData, 1);
img = tfjs.reshape(img, [1, 28, 28, 1]);
img = tfjs.cast(img, 'float32');
img = img.div(tfjs.scalar(255));
const output = this.model.predict(img);
this.preds = Array.from(output.dataSync());
});
this.setState({
result : Converter.findMaxProp(this.preds),
charData: {
labels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
datasets: [
{