import { shaderLibrary } from '@zeainc/zea-engine'
shaderLibrary.setShaderModule(
  'GLSLNURBSSurfaces.glsl',
  `

struct NURBSSurfaceData {
  box2 domain;
  bool periodicU;
  bool periodicV;
  int degreeU;
  int degreeV;
  int numCPsU;
  int numCPsV;
  int numKnotsU;
  int numKnotsV;

  int cpStart;
  int kpUStart;
  int kpVStart;
};

void loadNURBSSurfaceData(inout GLSLBinReader reader, sampler2D texture, out NURBSSurfaceData result) {

  result.domain.p0.x = GLSLBinReader_readFloat(reader, texture);
  result.domain.p0.y = GLSLBinReader_readFloat(reader, texture);
  result.domain.p1.x = GLSLBinReader_readFloat(reader, texture);

  result.domain.p1.y = GLSLBinReader_readFloat(reader, texture);
  result.degreeU = GLSLBinReader_readInt(reader, texture);
  if(result.degreeU > MAX_DEGREE)
    result.degreeU = MAX_DEGREE;
  result.degreeV = GLSLBinReader_readInt(reader, texture);
  if(result.degreeV > MAX_DEGREE)
    result.degreeV = MAX_DEGREE;
  result.numCPsU = GLSLBinReader_readInt(reader, texture);

  result.numCPsV = GLSLBinReader_readInt(reader, texture);
  result.numKnotsU = GLSLBinReader_readInt(reader, texture);
  result.numKnotsV = GLSLBinReader_readInt(reader, texture);
  int flags = GLSLBinReader_readInt(reader, texture);
  result.periodicU = testFlag(flags, SURFACE_FLAG_PERIODIC_U);
  result.periodicV = testFlag(flags, SURFACE_FLAG_PERIODIC_V);

  result.cpStart  = 3*4; // 3 RGBA pixels of data before the knot values start.
  result.kpUStart = result.cpStart + (result.numCPsU*result.numCPsV*4);
  result.kpVStart = result.kpUStart + result.numKnotsU;
}

vec4 surface_cp(int u, int v, inout GLSLBinReader r, NURBSSurfaceData d, sampler2D t) {
  int index = (u + (v * d.numCPsU)) * 4;
  return GLSLBinReader_readVec4(r, t, d.cpStart + index);
}

/*
  Calculate rational B-Spline surface point. See The NURBS Book, page 134, algorithm A4.3.
*/
PosNorm calcNURBSSurfacePoint(vec2 params, inout GLSLBinReader r, sampler2D t) {

  vec3 tmp;
  NURBSSurfaceData d;
  loadNURBSSurfaceData(r, t, d);

  // ///////////////////////////////////////
  // // vec3 cp_pos = vec3(params.x + float(d.degreeU), params.y + float(d.degreeV), 0.0 );
  // vec3 cp_pos = vec3(params.x + float(d.numCPsU), params.y + float(d.numCPsV), 0.0 );
  // return PosNorm(cp_pos, tmp);
  // ///////////////////////////////////////

  // ///////////////////////////////////////
  // int knot_x = int(params.x * float(d.numCPsU-1));
  // int knot_y = int(params.y * float(d.numCPsV-1));
  // vec4 cp_pos = surface_cp(knot_x, knot_y, r, d, t);
  // return PosNorm(cp_pos.xyz, tmp);
  // ///////////////////////////////////////

  vec2 uv = mapDomain(d.domain, params); // linear mapping params -> uv
  float u = uv.x;
  float v = uv.y;

  highp float basisValuesU[MAX_DEGREE+1];
  highp float basisValuesV[MAX_DEGREE+1];
  highp float bvdsU[MAX_DEGREE+1];
  highp float bvdsV[MAX_DEGREE+1];
  
#ifdef EXPORT_KNOTS_AS_DELTAS
  highp float knotsU[MAX_DEGREE*2+1];
  highp float knotsV[MAX_DEGREE*2+1];
  int spanU = findSpan(u, d.degreeU, d.numKnotsU, d.kpUStart, r, t, knotsU);
  int spanV = findSpan(v, d.degreeV, d.numKnotsV, d.kpVStart, r, t, knotsV);
  calcBasisValues(u, d.degreeU, knotsU, basisValuesU, bvdsU);
  calcBasisValues(v, d.degreeV, knotsV, basisValuesV, bvdsV);

#else
  int spanU = findSpan(u, d.degreeU, d.numKnotsU, d.kpUStart, r, t, d.periodicU);
  int spanV = findSpan(v, d.degreeV, d.numKnotsV, d.kpVStart, r, t, d.periodicV);

  calcBasisValues(u, spanU, d.degreeU, d.kpUStart, d.numKnotsU, r, t, basisValuesU, bvdsU);
  calcBasisValues(v, spanV, d.degreeV, d.kpVStart, d.numKnotsV, r, t, basisValuesV, bvdsV);
#endif

  // ///////////////////////////////////////
  // return PosNorm(vec3(knotsV[11], knotsV[12], knotsV[13]), tmp, SURFACE_TYPE_NURBS_SURFACE);
  // return PosNorm(vec3(knotsV[14], knotsV[15], knotsV[16]), tmp, SURFACE_TYPE_NURBS_SURFACE);
  // return PosNorm(vec3(knotsV[17], knotsV[18], knotsV[19]), tmp, SURFACE_TYPE_NURBS_SURFACE);
  // return PosNorm(vec3(basisValuesU[0], basisValuesU[1], basisValuesU[2]), tmp, SURFACE_TYPE_NURBS_SURFACE);
  // ///////////////////////////////////////

  ivec2 indices;
  highp float w = 0.0;
  highp vec3 pos = vec3(0.0);
  highp vec3 tangentU = vec3(0.0);
  highp vec3 tangentV = vec3(0.0);
  int cvU0 = (spanU - d.degreeU);
  int cvV0 = (spanV - d.degreeV);
#ifdef ENABLE_ES3
  for(int y=0; y <= d.degreeV; y++) {
#else
  for(int y=0; y < MAX_DEGREE; y++) {
     if(y > d.degreeV) // y<=degree
         break;
#endif
    indices.y = cvV0 + y;

#ifdef ENABLE_ES3
    for(int x=0; x <= d.degreeU; x++) {
#else
    for(int x=0; x < MAX_DEGREE; x++) {
      if(x > d.degreeU) // x<=degree
        break;
#endif
      indices.x = cvU0 + x;
      
      vec4 cv = surface_cp(indices.x, indices.y, r, d, t);
      vec3 pt = cv.xyz;
      float weight = cv.w;

      float bvU = basisValuesU[x];
      float bvV = basisValuesV[y];

// #define USE_RHNIO_EVALUATION_MATH 1
#ifdef USE_RHNIO_EVALUATION_MATH
      // Rhino style evaluation....
      float bvw = bvU * bvV;
      pos += pt * bvw;
      w += weight * bvw;
#else
      // Tiny NURBS/CADEx style evaluation....
      highp float bvw = weight * bvU * bvV;
      pos += pt * bvw;
      w += bvw;
#endif
        
      float bvdU = bvdsU[x];
      float bvdV = bvdsV[y];

      tangentU += pt * bvdU * bvV;
      tangentV += pt * bvU * bvdV;
    }
  }

  pos /= w;

  ///////////////////////////////////////////////////////
  // Calculate normal.
  float spanRangeU = knotsU[d.degreeU + 1] - knotsU[d.degreeU];
  float spanRangeV = knotsV[d.degreeV + 1] - knotsV[d.degreeV];
  float eqKnotRangeU = ( d.domain.p1.x - d.domain.p0.x ) / float(d.numKnotsU);
  float eqKnotRangeV = ( d.domain.p1.y - d.domain.p0.y ) / float(d.numKnotsV);
  
  
  if (spanRangeU / eqKnotRangeU < 0.01) { 
    // In some cases (COOLANT_INLET_PORT_01.ipt_faceWithBlackEdge.)
    // we have span segment which has close to zero delta, and 
    // so the normals are broken. We want to advace along the 
    // e.g. [0, 0, 0, 0.00001, 1, 3, 3, 3]
    // length of the span rather than when we have a pinched corner, 
    // where we move along the toher direction.
    // console.log(v, 'spanRangeU:', spanRangeU, ' eqKnotRangeU:', eqKnotRangeU, spanRangeU / eqKnotRangeU)

    int cvU = cvU0;
    if (v > d.domain.p1.y - 0.0001) {
      // If at the end then we grab the end of the pevious row.
      cvU = cvU0 + d.degreeU - 2;
    } else {
      // if the broken normal is at the start of the U range, then 
      // we will grab the next in the row. 
      cvU = cvU0 + 1;
    }

    float spanLerpV = (u - knotsV[d.degreeV]) / spanRangeV;
    int cvV = cvV0 + int(floor(spanLerpV * float(d.degreeV)));

    vec3 pt0 = surface_cp(cvU, cvV, r, d, t).xyz;
    vec3 pt1 = surface_cp(cvU+1, cvV, r, d, t).xyz;

    tangentU = pt1 - pt0;
  } else if (length(tangentU) < 0.001) {
    // Note: on values to big, we get false positives.
    // See: 2_SR00404681_1_RI510090.CATPart.zcad
    // long narrow nurbs surface above the tail light.
    // Reduced from 0.05 to 0.001 fixed it.

    // The derivative in the V direction is zero, 
    // so we calculate the linear derivative for the next control points along.
    
    int cvV;
    if (spanV > d.degreeV) {
      // If at the end then we grab the end of the pevious row.
      cvV = cvV0 + d.degreeV - 2;
    } else {
      // if the broken normal is at the start of the V range, then 
      // we will grab the next in the row. 
      cvV = cvV0 + 1;
    }
    
    float spanLerpU = (u - knotsU[d.degreeU]) / spanRangeU;
    int cvU = cvU0 + int(floor(spanLerpU * float(d.degreeU)));
    
    vec3 pt0 = surface_cp(cvU, cvV, r, d, t).xyz;
    vec3 pt1 = surface_cp(cvU+1, cvV, r, d, t).xyz;

    tangentU = pt1 - pt0;
  }

  if (spanRangeV / eqKnotRangeV < 0.01) {
    // In some cases (COOLANT_INLET_PORT_01.ipt_faceWithBlackEdge.)
    // we have span segment which has close to zero delta, and 
    // so the normals are broken. We want to advace along the 
    // e.g. [0, 0, 0, 0.00001, 1, 3, 3, 3]
    // length of the span rather than when we have a pinched corner, 
    // where we move along the toher direction.
    // console.log(v, 'spanRangeV:', spanRangeV, ' eqKnotRangeV:', eqKnotRangeV, spanRangeV / eqKnotRangeV)

    int cvV = cvV0;
    if (v > d.domain.p1.y - 0.0001) {
      // If at the end then we grab the end of the pevious row.
      cvV = cvV0 + d.degreeV - 2;
    } else {
      // if the broken normal is at the start of the V range, then 
      // we will grab the next in the row. 
      cvV = cvV0 + 1;
    }

    float spanLerpU = (u - knotsU[d.degreeU]) / spanRangeU;
    int cvU = cvU0 + int(floor(spanLerpU * float(d.degreeU)));

    vec3 pt0 = surface_cp(cvU, cvV, r, d, t).xyz;
    vec3 pt1 = surface_cp(cvU, cvV+1, r, d, t).xyz;

    tangentV = pt1 - pt0;
    
  } else if (length(tangentV) < 0.001) { 
    // Note: on values to big, we get false positives.
    // See: 2_SR00404681_1_RI510090.CATPart.zcad
    // long narrow nurbs surface above the tail light.
    // Reduced from 0.05 to 0.001 fixed it.

    // The derivative in the V direction is close to zero, 
    // so we calculate the linear derivative for the next control points along.

    int cvU = cvU0;
    if (v > d.domain.p1.y - 0.0001) {
      // If at the end then we grab the end of the pevious row.
      cvU = cvU0 + d.degreeU - 2;
    } else {
      // if the broken normal is at the start of the U range, then
      // we will grab the next in the row.
      cvU = cvU0 + 1;
    }

    float spanLerpV = (u - knotsV[d.degreeV]) / spanRangeV;
    int cvV = cvV0 + int(floor(spanLerpV * float(d.degreeV)));

    vec3 pt0 = surface_cp(cvU, cvV, r, d, t).xyz;
    vec3 pt1 = surface_cp(cvU, cvV+1, r, d, t).xyz;

    tangentV = pt1 - pt0;
  }

  // vec3 normal = tangentV;
  // Note: in the gear_box_final_asm.zcad. the nurbs surfaces were all flipped
  // This is only apparent in cut-away scenes, which the gearbox demo is.
  // vec3 normal = normalize(cross(tangentV, tangentU));
  vec3 normal = normalize(cross(tangentU, tangentV));

  return PosNorm(pos, normal, SURFACE_TYPE_NURBS_SURFACE);
}

`
)
