Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
// Persist the token indices and sequence-length limits as JSON so that
// inference code can reconstruct identical vocabularies later.
const metadata = {
'input_token_index': inputTokenIndex,
'target_token_index': targetTokenIndex,
'max_encoder_seq_length': maxEncoderSeqLength,
'max_decoder_seq_length': maxDecoderSeqLength,
};
fs.writeFileSync(metadataJsonPath, JSON.stringify(metadata));
console.log('Saved metadata at: ', metadataJsonPath);
// One-hot training buffers, shape [numExamples, maxSeqLen, vocabSize].
const encoderInputDataBuf = tf.buffer([
inputTexts.length,
maxEncoderSeqLength,
numEncoderTokens,
]);
const decoderInputDataBuf = tf.buffer([
inputTexts.length,
maxDecoderSeqLength,
numDecoderTokens,
]);
const decoderTargetDataBuf = tf.buffer([
inputTexts.length,
maxDecoderSeqLength,
numDecoderTokens,
]);
// One-hot encode each (input, target) text pair into the buffers above.
for (
const [i, [inputText, targetText]]
of (zip(inputTexts, targetTexts).entries() as IterableIterator<[number, [string, string]]>)
) {
for (const [t, char] of inputText.split('').entries()) {
// encoder_input_data[i, t, input_token_index[char]] = 1.
// NOTE(review): snippet is truncated here — the loop body that performs
// the assignment described above is not visible in this chunk.
/**
 * Decodes a single input sequence (batch size 1) into an output string by
 * stepping the decoder one predicted token at a time.
 *
 * @param inputSeq One-hot encoded input sequence fed to the encoder.
 * @param encoderModel Maps `inputSeq` to the decoder's initial states.
 * @param decoderModel Predicts the next token plus updated states.
 * @param numDecoderTokens Width of the one-hot target vocabulary.
 * @param targetBeginIndex Token index of the start-of-sequence character.
 * @param reverseTargetCharIndex Token-index -> character lookup table.
 * @param maxDecoderSeqLength Upper bound on the decoded sequence length.
 */
async function decodeSequence (
inputSeq: tf.Tensor,
encoderModel: tf.LayersModel,
decoderModel: tf.LayersModel,
numDecoderTokens: number,
targetBeginIndex: number,
reverseTargetCharIndex: {[indice: number]: string},
maxDecoderSeqLength: number,
) {
// Encode the input as state vectors.
let statesValue = encoderModel.predict(inputSeq) as tf.Tensor[];
// Generate empty target sequence of length 1.
let targetSeq = tf.buffer([
1,
1,
numDecoderTokens,
]);
// Populate the first character of target sequence with the start character.
targetSeq.set(1, 0, 0, targetBeginIndex);
// Sampling loop for a batch of sequences
// (to simplify, here we assume a batch of size 1).
let stopCondition = false;
let decodedSentence = '';
while (!stopCondition) {
const [outputTokens, h, c] = decoderModel.predict(
[targetSeq.toTensor(), ...statesValue]
) as [
// NOTE(review): snippet truncated mid-expression — the tuple type and the
// remainder of the decode loop are not visible in this chunk.
// Instance-method variant of decodeSequence: the models and token indices
// come from `this` instead of being passed as parameters.
decodeSequence(inputSeq) {
// Encode the inputs state vectors.
let statesValue = this.encoderModel.predict(inputSeq);
// Generate empty target sequence of length 1.
let targetSeq = tf.buffer([1, 1, this.numDecoderTokens]);
// Populate the first character of the target sequence with the start
// character.
targetSeq.set(1, 0, 0, this.targetTokenIndex['\t']);
// Sample loop for a batch of sequences.
// (to simplify, here we assume that a batch of size 1).
let stopCondition = false;
let decodedSentence = '';
while (!stopCondition) {
// The decoder's predict output is unpacked as [tokenProbs, stateH, stateC].
const predictOutputs =
this.decoderModel.predict([targetSeq.toTensor()].concat(statesValue));
const outputTokens = predictOutputs[0];
const h = predictOutputs[1];
const c = predictOutputs[2];
// Sample a token.
// NOTE(review): the line below looks spliced in from a different snippet —
// `this.zeroState` has no relation to this decode loop; verify the source.
this.state = this.zeroState;
}
// Char-RNN generation fragment: encodes the user seed, then advances the
// recurrent state one character per loop iteration.
const results = [];
const userInput = Array.from(seed);
// Map each seed character to its integer id in the vocabulary.
const encodedInput = [];
userInput.forEach((char) => {
encodedInput.push(this.vocab[char]);
});
let input = encodedInput[0];
let probabilitiesNormalized = []; // will contain final probabilities (normalized)
for (let i = 0; i < userInput.length + length + -1; i += 1) {
// One-hot encode the current input id over the vocabulary.
const onehotBuffer = tf.buffer([1, this.vocabSize]);
onehotBuffer.set(1.0, 0, input);
const onehot = onehotBuffer.toTensor();
let output;
if (this.model.embedding) {
// Project the one-hot vector through the embedding matrix first.
const embedded = tf.matMul(onehot, this.model.embedding);
output = tf.multiRNNCell(this.cells, embedded, this.state.c, this.state.h);
} else {
output = tf.multiRNNCell(this.cells, onehot, this.state.c, this.state.h);
}
// Presumably output is the updated [c, h] state lists — verify against
// the tf.multiRNNCell documentation.
this.state.c = output[0];
this.state.h = output[1];
const outputH = this.state.h[1];
// Logits over the vocabulary for the next character.
const weightedResult = tf.matMul(outputH, this.model.fullyConnectedWeights);
const logits = tf.add(weightedResult, this.model.fullyConnectedBiases);
// NOTE(review): snippet truncated — sampling from `logits` and the update
// of `input`/`results` are not visible in this chunk.
encodeString(str) {
const strLen = str.length;
const encoded =
tf.buffer([1, this.maxEncoderSeqLength, this.numEncoderTokens]);
for (let i = 0; i < strLen; ++i) {
if (i >= this.maxEncoderSeqLength) {
console.error(
'Input sentence exceeds maximum encoder sequence length: ' +
this.maxEncoderSeqLength);
}
const tokenIndex = this.inputTokenIndex[str[i]];
if (tokenIndex == null) {
console.error(
'Character not found in input token index: "' + tokenIndex + '"');
}
encoded.set(1, 0, i, tokenIndex);
}
return encoded.toTensor();
}
// Create the output directory for the metadata JSON if it does not exist.
if (!fs.existsSync(path.dirname(metadataJsonPath))) {
mkdirp.sync(path.dirname(metadataJsonPath));
}
// Persist the token indices and sequence-length limits for inference use.
const metadata = {
'input_token_index': inputTokenIndex,
'target_token_index': targetTokenIndex,
'max_encoder_seq_length': maxEncoderSeqLength,
'max_decoder_seq_length': maxDecoderSeqLength,
};
fs.writeFileSync(metadataJsonPath, JSON.stringify(metadata));
console.log('Saved metadata at: ', metadataJsonPath);
// One-hot training buffers, shape [numExamples, maxSeqLen, vocabSize].
const encoderInputDataBuf = tf.buffer([
inputTexts.length,
maxEncoderSeqLength,
numEncoderTokens,
]);
const decoderInputDataBuf = tf.buffer([
inputTexts.length,
maxDecoderSeqLength,
numDecoderTokens,
]);
const decoderTargetDataBuf = tf.buffer([
inputTexts.length,
maxDecoderSeqLength,
numDecoderTokens,
]);
// NOTE(review): snippet truncated — this `for (` begins a loop whose header
// and body are not visible in this chunk.
for (
/**
 * Computes discounted cumulative rewards for one episode.
 *
 * Element i becomes rewards[i] + discountRate * rewards[i + 1] + ... ,
 * i.e. a right-to-left running sum scaled by the discount rate.
 *
 * @param {number[]} rewards Per-step rewards of the episode.
 * @param {number} discountRate Discount factor (typically in [0, 1]).
 * @returns {tf.Tensor} 1-D tensor of discounted rewards, same length.
 */
function discountRewards(rewards, discountRate) {
  const buffer = tf.buffer([rewards.length]);
  let running = 0;
  // Walk backward so each step sees the already-discounted future sum.
  for (let step = rewards.length - 1; step >= 0; --step) {
    running = rewards[step] + discountRate * running;
    buffer.set(running, step);
  }
  return buffer.toTensor();
}
/**
 * One-hot-index encodes a batch of date strings for the model input.
 *
 * Produces a float32 buffer of shape [numStrings, INPUT_LENGTH] where
 * entry (i, j) is the INPUT_VOCAB index of character j of string i.
 * Strings shorter than INPUT_LENGTH leave trailing entries at 0; characters
 * past INPUT_LENGTH are ignored.
 *
 * @param {string[]} dateStrings Batch of date strings to encode.
 * @returns {tf.Tensor} Tensor of shape [dateStrings.length, INPUT_LENGTH].
 * @throws {Error} If a character is not found in INPUT_VOCAB.
 */
export function encodeInputDateStrings(dateStrings) {
  const buffer = tf.buffer([dateStrings.length, INPUT_LENGTH], 'float32');
  dateStrings.forEach((str, row) => {
    const limit = Math.min(INPUT_LENGTH, str.length);
    for (let col = 0; col < limit; ++col) {
      const char = str[col];
      const vocabIndex = INPUT_VOCAB.indexOf(char);
      if (vocabIndex === -1) {
        throw new Error(`Unknown char: ${char}`);
      }
      buffer.set(vocabIndex, row, col);
    }
  });
  return buffer.toTensor();
}
// Primes the recurrent state by feeding each character of `inputSeed`
// through the network one step at a time (no sampling is done here).
async feed(inputSeed, callback) {
await this.ready;
const seed = Array.from(inputSeed);
// Map each character to its integer id in the vocabulary.
const encodedInput = [];
seed.forEach((char) => {
encodedInput.push(this.vocab[char]);
});
let input = encodedInput[0];
for (let i = 0; i < seed.length; i += 1) {
// One-hot encode the current input id over the vocabulary.
const onehotBuffer = tf.buffer([1, this.vocabSize]);
onehotBuffer.set(1.0, 0, input);
const onehot = onehotBuffer.toTensor();
let output;
if (this.model.embedding) {
// Project the one-hot vector through the embedding matrix first.
const embedded = tf.matMul(onehot, this.model.embedding);
output = tf.multiRNNCell(this.cells, embedded, this.state.c, this.state.h);
} else {
output = tf.multiRNNCell(this.cells, onehot, this.state.c, this.state.h);
}
this.state.c = output[0];
this.state.h = output[1];
// NOTE(review): suspected off-by-one — with this assignment, iteration 1
// re-feeds encodedInput[0] and encodedInput[seed.length - 1] is never fed;
// should this be encodedInput[i + 1]? Verify against upstream intent.
input = encodedInput[i];
}
if (callback) {
callback();
}