// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {expect} from 'chai';
import {env} from 'onnxruntime-common';

import {Backend, InferenceHandler, resolveBackend, SessionHandler} from '../../../../lib/onnxjs/backend';
import {WebGLInferenceHandler} from '../../../../lib/onnxjs/backends/webgl/inference-handler';
import {Profiler} from '../../../../lib/onnxjs/instrument';
import {Tensor} from '../../../../lib/onnxjs/tensor';

import {createAscendingArray} from './test-utils';

/**
 * One packed-reshape test case.
 */
interface TestData {
  /** Number of elements in the tensor; equal to the product of `inputShape` (and of `outputShape`). */
  elementCount: number;
  /** Shape of the tensor handed to `reshapePacked`. */
  inputShape: number[];
  /** Target shape requested from `reshapePacked`. */
  outputShape: number[];
}

/** Returns the product of all dimensions in `shape` (1 for the scalar shape `[]`). */
function shapeElementCount(shape: readonly number[]): number {
  return shape.reduce((product, dim) => product * dim, 1);
}

/**
 * Builds a test case, deriving `elementCount` from `inputShape` so the count can never drift
 * out of sync with the shapes when a case is edited.
 *
 * @param inputShape shape of the input tensor.
 * @param outputShape requested reshape target.
 * @returns the test case, with keys in the order `elementCount`, `inputShape`, `outputShape`
 *   (the order matters because the suite uses `JSON.stringify` of the case as a test title).
 * @throws Error if the two shapes describe different element counts, since such a case is not a
 *   valid reshape and would only produce a confusing kernel failure.
 */
function reshapeCase(inputShape: number[], outputShape: number[]): TestData {
  const elementCount = shapeElementCount(inputShape);
  const outputCount = shapeElementCount(outputShape);
  if (outputCount !== elementCount) {
    throw new Error(`invalid reshape test case: ${JSON.stringify(inputShape)} (${elementCount} elements) -> ${
        JSON.stringify(outputShape)} (${outputCount} elements)`);
  }
  return {elementCount, inputShape, outputShape};
}

/**
 * Returns the packed-reshape test cases, grouped by input rank.
 */
function getTestData(): TestData[] {
  return [
    // 2D input
    reshapeCase([4, 4], [2, 8]),
    reshapeCase([4, 4], [1, 16]),
    reshapeCase([2, 4], [4, 2]),
    reshapeCase([2, 4], [1, 8]),
    reshapeCase([2, 3], [1, 6]),
    reshapeCase([2, 3], [3, 2]),

    // 3D input
    reshapeCase([2, 2, 4], [4, 2, 2]),
    reshapeCase([2, 2, 4], [2, 4, 2]),
    reshapeCase([2, 2, 4], [1, 1, 2, 8]),  // rank-increasing (3D -> 4D)

    // 4D input
    reshapeCase([2, 2, 2, 4], [4, 2, 2, 2]),
    reshapeCase([2, 2, 2, 4], [2, 4, 2, 2]),
    reshapeCase([2, 2, 2, 4], [2, 2, 4, 2]),
    reshapeCase([2, 2, 2, 4], [2, 1, 4, 4]),

    // larger tensors: dropping / adding trailing unit dimensions
    reshapeCase([512, 36, 1, 1], [512, 36]),
    reshapeCase([512, 36], [512, 36, 1, 1]),
  ];
}

// Handlers shared by every test in the suite; created once in `before`.
let backend: Backend|undefined;
let sessionhandler: SessionHandler|undefined;
let inferenceHandler: InferenceHandler|undefined;

describe('#UnitTest# - reshape - packed', () => {
  before('Initialize Context', async () => {
    const profiler = Profiler.create();
    backend = await resolveBackend('webgl');
    sessionhandler = backend.createSessionHandler({profiler});
    inferenceHandler = sessionhandler.createInferenceHandler();
  });

  // Dispose the handlers created in `before` so their state does not outlive this suite.
  // The backend itself is left alone.
  // NOTE(review): resolveBackend may hand back a shared backend instance — confirm before disposing it.
  after('Dispose Context', () => {
    inferenceHandler?.dispose();
    sessionhandler?.dispose();
    inferenceHandler = undefined;
    sessionhandler = undefined;
  });

  for (const testData of getTestData()) {
    // The `it` must be registered inside this callback; otherwise the describe is empty.
    describe(`Test reshape ${JSON.stringify(testData)}`, () => {
      it(`Test packed reshape kernel ${JSON.stringify(testData.outputShape)}`, () => {
        if (!env.webgl.pack) {
          console.log('Skipping in unpacked texture mode.');
          return;
        }

        const webglInferenceHandler = inferenceHandler as WebGLInferenceHandler;
        const {elementCount, inputShape, outputShape} = testData;

        // create input data and tensor.
        const inputData = createAscendingArray(elementCount);
        const inputTensor = new Tensor(inputShape, 'float32', undefined, undefined, inputData);

        // run kernel and get output
        const resultTensor = webglInferenceHandler.reshapePacked(inputTensor, outputShape);
        const result = resultTensor.data;

        webglInferenceHandler.session.textureManager.glContext.checkError();

        // verify result: a reshape must report the requested shape and keep the data unchanged.
        expect(resultTensor.dims).to.deep.equal(outputShape);
        expect(result).to.not.equal(null);
        expect(result).to.have.lengthOf(elementCount);
        expect(result).to.deep.equal(inputData);
      });
    });
  }
});
