<template>
  <div class="remove-container">
    <removeClothBackground modelName="backGroundRemover" :modelFilepath="modelFilepath" :imageSize="imageSize"
      :warmupModel="warmupModel" :preProcess="preProcess" :postProcess="postProcess" :currentFileList="currentFileList"
      :resizeSize="resizeSize" />
  </div>
</template>

<script setup>
import ndarray from "ndarray";
import ops from "ndarray-ops";
import removeClothBackground from "@/components/models/removeClothBackground/index.vue";
import { runModelUtils } from "@/utils/index";
import { Tensor } from "onnxruntime-web";
import { ref } from "vue";
import dataUtils from '@/utils/util.js'
import { useMeta } from 'vue-meta'
// Page-level <head> metadata registered through vue-meta: an empty title and
// the html attributes lang: 'en' and amp: true.
const pageMeta = {
  title: '',
  htmlAttrs: {
    lang: 'en',
    amp: true,
  },
};
useMeta(pageMeta);

// Base URL for model files: '/api' in development (presumably proxied by the
// dev server), otherwise the VUE_APP_API environment value.
const proxyUrl = process.env.NODE_ENV === 'development'
  ? '/api'
  : process.env.VUE_APP_API;

// Location of the ONNX cloth-segmentation model.
const MODEL_FILEPATH = `${proxyUrl}/models/cloth_segm_u2net.onnx`;

// Reactive model path bound to the removeClothBackground component in the template.
const modelFilepath = ref(MODEL_FILEPATH);

// Square model input size in pixels; passed to the component and used by warmupModel.
const MODEL_INPUT_SIDE = 412;
const imageSize = {
  width: MODEL_INPUT_SIDE,
  height: MODEL_INPUT_SIDE,
};

/**
 * Forwards `session` to `runModelUtils.warmupModel` together with the model
 * input shape derived from `imageSize`.
 *
 * The shape is NCHW — [batch, channels, height, width] — the same layout
 * `preProcess` uses for the tensor it builds, so height must precede width.
 *
 * @param {*} session - inference session handed through unchanged.
 * @returns {*} whatever `runModelUtils.warmupModel` returns.
 */
const warmupModel = (session) => {
  return runModelUtils.warmupModel(session, [
    1,
    3,
    imageSize.height,
    imageSize.width,
  ]);
};

/**
 * Model-prediction preprocessing: builds the model input tensor from the image
 * currently drawn on `ctx`.
 *
 * The `originSize` region is read from `ctx`, resized to the size suggested by
 * `dataUtils.getImageScaleSize`, and converted from interleaved RGBA ([H, W, 4])
 * to a planar RGB float32 tensor of shape [1, 3, H, W]. Every channel value v
 * becomes (v - 127.5) / 127.5 (assuming 0..255 byte input — TODO confirm what
 * resizeImageData returns — this maps to [-1, 1]). The alpha channel is dropped.
 *
 * @param {*} ctx - object exposing getImageData (presumably a canvas 2D context); may be nullish.
 * @param {{ width: number, height: number }} originSize - size of the region read from `ctx`.
 * @returns {Promise<{ tensor: Tensor, resizeSize: object, resizeData: object }>}
 *   the input tensor, the size it was built at, and the resized image data.
 */
const preProcess = async (ctx, originSize) => {
  // Input size suggested for model prediction.
  const resizeSize = dataUtils.getImageScaleSize(originSize.width, originSize.height);
  const { width: currentWidth, height: currentHeight } = resizeSize;

  // ImageData of the image at its original size.
  // NOTE(review): a nullish ctx falls back to `{}`, which is passed to
  // resizeImageData as-is — confirm whether that path is reachable or should throw.
  const imageData = ctx?.getImageData(0, 0, originSize.width, originSize.height) || {};
  // Resize the ImageData to the model input size.
  const resizeData = await dataUtils.resizeImageData(imageData, currentWidth, currentHeight);
  const { data } = resizeData;

  // Interleaved RGBA view of the resized pixels: [H, W, 4].
  const dataTensor = ndarray(new Float32Array(data), [currentHeight, currentWidth, 4]);

  // Planar RGB destination: [1, 3, H, W]. Its contiguous backing buffer is
  // handed directly to the onnxruntime Tensor below.
  const dataProcessedTensor = ndarray(
    new Float32Array(currentWidth * currentHeight * 3),
    [1, 3, currentHeight, currentWidth]
  );

  // Copy the R, G and B channels (0..2) into their planes; alpha (3) is dropped.
  for (let channel = 0; channel < 3; channel += 1) {
    ops.assign(
      dataProcessedTensor.pick(0, channel, null, null),
      dataTensor.pick(null, null, channel)
    );
  }

  // Normalize: (v - 127.5) / 127.5. The same scalar applies to every channel,
  // so operating on the whole tensor is equivalent to per-channel subtraction.
  ops.subseq(dataProcessedTensor, 127.5);
  ops.divseq(dataProcessedTensor, 127.5);

  // Wrap the existing float32 buffer; no second allocation or copy is needed.
  const tensor = new Tensor(
    "float32",
    dataProcessedTensor.data,
    [1, 3, currentHeight, currentWidth]
  );

  return {
    tensor,
    resizeSize,
    resizeData,
  };
};

/**
 * Writes a per-pixel alpha mask derived from the segmentation output into `data`.
 *
 * For each pixel, the channel with the highest score wins; on ties the lowest
 * channel index wins. Channel 0 makes the pixel fully transparent (alpha 0).
 * Any other channel makes it opaque (alpha 255). Only the alpha byte (offset 3
 * of each RGBA quadruple) of `data.data` is written; RGB bytes are untouched.
 *
 * @param {{ dims: number[], data: ArrayLike<number> }} tensor - output laid out as
 *   [N, C, H, W]: channel c of pixel i is read at i + c * H * W, so only the
 *   first batch entry is used. Any channel count C is supported.
 * @param {{ data: Uint8ClampedArray | number[] }} data - RGBA pixel buffer, updated in place.
 * @param {{ width: number, height: number }} resizeSize - currently unused; kept
 *   for the caller-facing signature.
 * @returns {Promise<object|undefined>} `data` after masking, or `undefined` if
 *   processing threw (the error is logged and the user is alerted).
 */
const postProcess = async (tensor, data, resizeSize) => {
  // TODO: scale the tensor to the original size (data is at the original size).
  try {
    const [, channelCount, height, width] = tensor.dims;
    const planeSize = width * height;
    const scores = tensor.data;
    const pixels = data.data;

    for (let i = 0; i < planeSize; i += 1) {
      // Argmax over channels for pixel i. Strict `>` keeps the first maximum
      // on ties, matching the previous indexOf-based selection.
      let bestChannel = 0;
      let bestScore = scores[i];
      for (let c = 1; c < channelCount; c += 1) {
        const score = scores[i + c * planeSize];
        if (score > bestScore) {
          bestScore = score;
          bestChannel = c;
        }
      }
      pixels[i * 4 + 3] = bestChannel > 0 ? 255 : 0;
    }

    return data;
  } catch (e) {
    console.error(e);
    alert("Model is not valid!");
  }
};

</script>

<style lang="less">
.remove-container {
  width: 100%;
  height: 100%;
}
</style>
