import { Color, shaderLibrary, GLShader, Registry } from '@zeainc/zea-engine'
import './GLSLCADConstants.js'
import './GLSLMath.js'
import './GLSLCADSurfaceDrawing.js'

/**
 * Vertex shader for drawing CAD surface normals.
 *
 * Each vertex evaluates the surface position and normal at the parametric
 * coordinate derived from `positions.xy` (shifted by +0.5), transforms both
 * into world space, then pushes the position along the world-space normal by
 * `positions.z * normalLength`. NOTE(review): presumably `positions.z` is 0 at
 * the base and 1 at the tip of each normal segment — confirm against the
 * geometry drawn with this shader.
 */
const GLDrawCADSurfaceNormalsShader_VERTEX_SHADER = `
precision highp float;

attribute vec3 positions;
instancedattribute vec4 drawCoords;  // (cadBodyId, drawItemIndexInBody, surfaceId, trimSetId)
// instancedattribute vec2 drawItemTexAddr;  // Address of the data in the draw item texture. (mat4)

uniform mat4 viewMatrix;
uniform mat4 projectionMatrix;
uniform ivec2 quadDetail;
uniform vec3 assetCentroid;
uniform float normalLength;

<%include file="GLSLUtils.glsl"/>
<%include file="GLSLCADConstants.glsl"/>
<%include file="stack-gl/transpose.glsl"/>
<%include file="stack-gl/inverse.glsl"/>

<%include file="GLSLCADSurfaceDrawing.vertexShader.glsl"/>

varying vec4 v_drawCoords;
varying vec3 v_viewPos;
varying vec3 v_worldPos;
varying vec3 v_viewNormal;
varying vec2 v_textureCoord;

void main(void) {
    int cadBodyId = ftoi(drawCoords.r);
    int drawItemIndexInBody = ftoi(drawCoords.g);
    int surfaceId = ftoi(drawCoords.b);
    int trimSetId = ftoi(drawCoords.a);

    vec2 texCoords = positions.xy + 0.5;

    v_drawCoords = drawCoords;

    vec4 cadBodyPixel0 = getCADBodyPixel(cadBodyId, 0);

    // int bodyDescId = ftoi(cadBodyPixel0.r);
    int cadBodyFlags = ftoi(cadBodyPixel0.g);

    //////////////////////////////////////////////
    // Visibility
    if(testFlag(cadBodyFlags, BODY_FLAG_INVISIBLE)) {
        // Place the vertex outside the clip volume so nothing is rasterized.
        gl_Position = vec4(-3.0, -3.0, -3.0, 1.0);
        return;
    }

    //////////////////////////////////////////////
    // Transforms
#ifdef DEBUG_SURFACES
    mat4 modelMatrix = mat4(1.0);
    // if(v_surfaceType == SURFACE_TYPE_NURBS_SURFACE) {
    //     // int drawItemIndexInBody = int(metadata.b+0.5);
    //     int sideLen = int(ceil(sqrt(float(numSurfacesInLibrary))));
    //     int x = drawItemIndexInBody % sideLen;
    //     int y = drawItemIndexInBody / sideLen;
    //     modelMatrix = mat4(1.0, 0.0, 0.0, 0.0, 
    //                     0.0, 1.0, 0.0, 0.0, 
    //                     0.0, 0.0, 1.0, 0.0,  
    //                     float(x), float(y), 0.0, 1.0);
    // }
#else

#ifdef CALC_GLOBAL_XFO_DURING_DRAW
    mat4 bodyMat = getCADBodyMatrix(cadBodyId);
    ivec2 bodyDescAddr = ftoi(cadBodyPixel0.ba);
    mat4 surfaceMat = getDrawItemMatrix(bodyDescAddr, drawItemIndexInBody);
    mat4 modelMatrix = bodyMat * surfaceMat;
#else
    mat4 modelMatrix = getModelMatrix();
    // Note: on mobile GPUs, we get only FP16 math in the
    // fragment shader, causing inaccuracies in modelMatrix
    // calculation. By offsetting the data to the origin
    // we calculate a modelMatrix in the asset space, and
    //  then add it back on during final drawing.
    // modelMatrix[3][0] += assetCentroid.x;
    // modelMatrix[3][1] += assetCentroid.y;
    // modelMatrix[3][2] += assetCentroid.z;
#endif
#endif
    mat4 viewProjectionMatrix = projectionMatrix * viewMatrix;

    //////////////////////////////////////////////
    // Vertex Attributes

    GLSLBinReader surfaceLayoutDataReader;
    GLSLBinReader_init(surfaceLayoutDataReader, surfaceAtlasLayoutTextureSize, 16);
    vec4 surfaceDataAddr = GLSLBinReader_readVec4(surfaceLayoutDataReader, surfaceAtlasLayoutTexture, surfaceId * 8);
    int surfaceFlags = GLSLBinReader_readInt(surfaceLayoutDataReader, surfaceAtlasLayoutTexture, surfaceId * 8 + 6);

    // quadDetail is already an ivec2, so no cast is needed here.
    bool isFan = quadDetail.y == 0;
    vec2 vertexCoords = texCoords * (isFan ? vec2(quadDetail) + vec2(1.0, 1.0) : vec2(quadDetail));

    vec3 normal = getSurfaceNormal(surfaceDataAddr.xy, vertexCoords);
    vec4 pos = vec4(getSurfaceVertex(surfaceDataAddr.xy, vertexCoords).rgb, 1.0);

    bool flippedNormal = testFlag(surfaceFlags, SURFACE_FLAG_FLIPPED_NORMAL);
    if(flippedNormal){
        normal = -normal;
    }

    vec4 worldPos = modelMatrix * pos;
    vec3 worldNormal = normalize(mat3(modelMatrix) * normal);

    // Offset along the normal; positions.z scales the offset by normalLength.
    worldPos += vec4(worldNormal * positions.z * normalLength, 0.0);

    gl_Position = viewProjectionMatrix * worldPos;

    // Write every declared varying. The fragment stage reads v_worldPos for
    // its cutaway test, so leaving it unwritten yields undefined results.
    v_worldPos = worldPos.xyz;
    v_viewPos = (viewMatrix * worldPos).xyz;
    v_viewNormal = normalize(mat3(viewMatrix) * worldNormal);

    v_textureCoord = texCoords;
    if(testFlag(surfaceFlags, SURFACE_FLAG_FLIPPED_UV))
        v_textureCoord = vec2(v_textureCoord.y, v_textureCoord.x);

    // v_textureCoord.y = 1.0 - v_textureCoord.y; // Flip y
}`

/**
 * Fragment shader for the surface-normals display.
 *
 * Discards fragments that are removed by a body cutaway plane or by the
 * surface trim set, then writes a constant gamma-corrected color. On non-ES3
 * targets the local `fragColor` is copied to `gl_FragColor`, because WebGL1
 * has no `out` variables.
 */
const FRAGMENT_SHADER = `
precision highp float;

<%include file="GLSLCADConstants.glsl"/>
<%include file="GLSLUtils.glsl"/>
<%include file="stack-gl/gamma.glsl"/>
<%include file="materialparams.glsl"/>
<%include file="GLSLBinReader.glsl"/>

// NOTE(review): BaseColor is declared (and exposed as a shader parameter) but
// the output below uses a hard-coded color. Confirm which one is intended.
uniform color BaseColor;

uniform mat4 cameraMatrix;

varying vec4 v_drawCoords;
varying vec3 v_viewPos;
varying vec3 v_worldPos;
varying vec3 v_viewNormal;
varying vec2 v_textureCoord;

<%include file="GLSLCADSurfaceDrawing.fragmentShader.glsl"/>

#ifdef ENABLE_ES3
out vec4 fragColor;
#endif

void main(void) {

#ifndef ENABLE_ES3
    vec4 fragColor;
#endif

    int cadBodyId = int(floor(v_drawCoords.r + 0.5));
    int drawItemIndexInBody = int(floor(v_drawCoords.g + 0.5));
    int surfaceId = int(floor(v_drawCoords.b + 0.5));
    int trimSetId = int(floor(v_drawCoords.a + 0.5));

    // TODO: pass as a varying from the vertex shader.
    vec4 cadBodyPixel0 = getCADBodyPixel(cadBodyId, 0);
    int flags = int(floor(cadBodyPixel0.g + 0.5));

    //////////////////////////////////////////////
    // Cutaways
    if (testFlag(flags, BODY_FLAG_CUTAWAY)) {
        vec4 cadBodyPixel6 = getCADBodyPixel(cadBodyId, 6);
        vec3 cutNormal = cadBodyPixel6.xyz;
        float cutPlaneDist = cadBodyPixel6.w;
        if (cutaway(v_worldPos, cutNormal, cutPlaneDist)) {
            discard;
        }
    }

    //////////////////////////////////////////////
    // Trimming
    vec4 trimPatchQuad;
    vec3 trimCoords;
    if(trimSetId >= 0) {
        GLSLBinReader trimsetLayoutDataReader;
        GLSLBinReader_init(trimsetLayoutDataReader, trimSetsAtlasLayoutTextureSize, 16);
        trimPatchQuad = GLSLBinReader_readVec4(trimsetLayoutDataReader, trimSetsAtlasLayoutTexture, trimSetId*4);

        if(applyTrim(trimPatchQuad, trimCoords, flags)){
            discard;
        }
    }

    vec4 baseColor = vec4(1.0,0.0,0.0,1.0);

//#ifdef ENABLE_INLINE_GAMMACORRECTION
    fragColor.rgb = toGamma(baseColor.rgb);
//#endif
    // Alpha must be written explicitly; otherwise the output alpha is undefined.
    fragColor.a = baseColor.a;

#ifndef ENABLE_ES3
    // WebGL1: the local fragColor is not an output, so copy it to gl_FragColor.
    gl_FragColor = fragColor;
#endif
}
`

import { GLCADShader } from './GLCADShader.js'

/**
 * Class representing a GL draw CAD surface normals shader.
 *
 * Installs this module's vertex and fragment shader sources on a GLCADShader
 * and marks the shader as non-selectable.
 * @extends GLCADShader
 * @ignore
 */
class GLDrawCADSurfaceNormalsShader extends GLCADShader {
  /**
   * Create a GL draw CAD surface normals shader.
   * @param {any} gl - The gl context, forwarded unchanged to GLCADShader.
   */
  constructor(gl) {
    super(gl)

    this.setShaderStage('VERTEX_SHADER', GLDrawCADSurfaceNormalsShader_VERTEX_SHADER)
    this.setShaderStage('FRAGMENT_SHADER', FRAGMENT_SHADER)
    // NOTE(review): presumably excludes this shader's geometry from selection passes.
    this.nonSelectable = true
  }

  /**
   * Returns the parameter declarations for this shader: every declaration
   * from the base class followed by `BaseColor`.
   *
   * A new array is built so the array returned by the base class is never
   * mutated. `Color` must be imported from '@zeainc/zea-engine'.
   * @return {Array<object>} - The parameter declarations.
   */
  static getParamDeclarations() {
    return [
      ...super.getParamDeclarations(),
      {
        name: 'BaseColor',
        defaultValue: new Color(1.0, 1.0, 0.5),
      },
    ]
  }
}

// Register the class with the zea-engine Registry under the string key
// 'GLDrawCADSurfaceNormalsShader' (presumably so the engine can resolve the
// shader by that name — confirm against the Registry documentation).
Registry.register('GLDrawCADSurfaceNormalsShader', GLDrawCADSurfaceNormalsShader)

// The vertex shader source is exported alongside the class
// (NOTE(review): presumably for reuse by other shaders — verify callers).
export { GLDrawCADSurfaceNormalsShader_VERTEX_SHADER, GLDrawCADSurfaceNormalsShader }
