/*
--------------------------------------------------------------------------------

	Shaders Shaders

	Copyright (C) 2024 None
	Apache License 2.0

	Pass: Contrast adaptive sharpening and final output

--------------------------------------------------------------------------------
*/

//======// Utility //=============================================================================//

#include "/lib/Utility.glsl"

//======// Config //==============================================================================//

#include "/config.glsl"

//======// Output //==============================================================================//

out vec3 finalOut;

//======// Uniform //=============================================================================//

uniform sampler2D cloudMapTex;
#include "/lib/universal/Uniform.glsl"

//======// SSBO //================================================================================//

#include "/lib/universal/SSBO.glsl"

//======// Function //============================================================================//

#include "/lib/universal/Random.glsl"

// Contrast Adaptive Sharpening (CAS)
// Reference: Lou Kramer, FidelityFX CAS, AMD Developer Day 2019,
// https://gpuopen.com/wp-content/uploads/2019/07/FidelityFX-CAS.pptx
// https://github.com/GPUOpen-Effects/FidelityFX-CAS
//
// Sharpens the colortex0 texel at `texel` using its 3x3 neighbourhood.
//   texel     - integer texel coordinate in colortex0; the neighbours are fetched with
//               texelFetchOffset relative to it.
//   sharpness - moves the filter denominator from 8.0 (sharpness = 0) to 5.0
//               (sharpness = 1). The reference implementation expects [0, 1]; values
//               approaching 8/3 drive the denominator to zero.
// Returns the sharpened colour clamped to [0, 1]. When CAS_ENABLED is not defined, the
// centre texel is returned unmodified.
vec3 FFXCasFilter(in ivec2 texel, in float sharpness) {
	#define CasLoad(offset) texelFetchOffset(colortex0, texel, 0, offset).rgb

	#ifndef CAS_ENABLED
		// Sharpening disabled: pass the centre texel through unchanged.
		vec3 result = CasLoad(ivec2(0, 0));
	#else
		// a b c
		// d e f
		// g h i
		vec3 a = CasLoad(ivec2(-1, -1));
		vec3 b = CasLoad(ivec2( 0, -1));
		vec3 c = CasLoad(ivec2( 1, -1));
		vec3 d = CasLoad(ivec2(-1,  0));
		vec3 e = CasLoad(ivec2( 0,  0));
		vec3 f = CasLoad(ivec2( 1,  0));
		vec3 g = CasLoad(ivec2(-1,  1));
		vec3 h = CasLoad(ivec2( 0,  1));
		vec3 i = CasLoad(ivec2( 1,  1));

		// Soft min and max.
		//  a b c             b
		//  d e f * 0.5  +  d e f * 0.5
		//  g h i             h
		// These are 2.0x bigger (factored out the extra multiply).
		vec3 minCol = min(min(min(d, e), min(f, b)), h);
		minCol += min(min(min(a, c), min(g, i)), minCol);
		vec3 maxCol = max(max(max(d, e), max(f, b)), h);
		maxCol += max(max(max(a, c), max(g, i)), maxCol);

		// Smooth distance to the signal limit, divided by the smooth max.
		// The divisor is kept away from zero. For an all-black neighbourhood maxCol == 0,
		// and the unguarded 0/0 would produce NaN. That NaN may survive saturate() and
		// reach the output pixel. With the guard such a channel gets amp == 0, meaning
		// no sharpening.
		vec3 amp = approxSqrt(saturate(min(minCol, 2.0 - maxCol) / max(maxCol, 1e-6)));

		// Filter shape.
		//  0 w 0
		//  w 1 w
		//  0 w 0
		vec3 w = amp * -rcp(mix(8.0, 5.0, sharpness));
		vec3 result = saturate(((b + d + f + h) * w + e) / (1.0 + 4.0 * w));
	#endif

	// The sampling macro is only meaningful inside this function.
	#undef CasLoad
	return result;
}

//================================================================================================//

// Five-tap Catmull-Rom resampling.
// Approximation from the SMAA presentation (Jimenez, "Filmic SMAA", SIGGRAPH 2016).
// The 4x4 bicubic footprint is reduced to five bilinear fetches. The two middle taps of
// each axis are merged into a single bilinear fetch, and the four corner combinations
// are skipped entirely.
//   tex       - texture to resample.
//   position  - sample position in texel units of `tex` (texel centres at half-integers).
//   sharpness - kernel sharpness; 0.5 gives the standard Catmull-Rom weights.
// The result is normalised by the sum of the five weights actually used.
// NOTE(review): texel-to-UV conversion uses viewPixelSize, which assumes it equals
// 1 / the texel size of `tex`. Confirm for scaled render resolutions.
vec3 textureCatmullRomFast(in sampler2D tex, in vec2 position, in const float sharpness) {
	// Nearest texel centre at or below the sample point, and the offset from it.
	vec2 texelCenter = floor(position - 0.5) + 0.5;
	vec2 t  = position - texelCenter;
	vec2 t2 = t * t;
	vec2 t3 = t * t2;

	// Per-axis weights for the taps at offsets -1, 0, +1 and +2.
	vec2 weightM1 = -sharpness        * t3 + 2.0 * sharpness         * t2 - sharpness * t;
	vec2 weight0  = (2.0 - sharpness) * t3 - (3.0 - sharpness)       * t2 + 1.0;
	vec2 weightP1 = (sharpness - 2.0) * t3 + (3.0 - 2.0 * sharpness) * t2 + sharpness * t;
	vec2 weightP2 = sharpness         * t3 - sharpness               * t2;

	// The 0 and +1 taps share one bilinear fetch placed between them.
	vec2 weightMid = weight0 + weightP1;

	vec2 uvM1  = viewPixelSize * (texelCenter - 1.0);
	vec2 uvP2  = viewPixelSize * (texelCenter + 2.0);
	vec2 uvMid = viewPixelSize * (texelCenter + weightP1 / weightMid);

	// Combined 2D weight of each of the five fetches (a "plus" shaped pattern).
	float wNegY   = weightMid.x * weightM1.y;
	float wNegX   = weightM1.x  * weightMid.y;
	float wCentre = weightMid.x * weightMid.y;
	float wPosX   = weightP2.x  * weightMid.y;
	float wPosY   = weightMid.x * weightP2.y;

	vec3 accum = texture(tex, vec2(uvMid.x, uvM1.y )).rgb * wNegY;
	accum     += texture(tex, vec2(uvM1.x,  uvMid.y)).rgb * wNegX;
	accum     += texture(tex, vec2(uvMid.x, uvMid.y)).rgb * wCentre;
	accum     += texture(tex, vec2(uvP2.x,  uvMid.y)).rgb * wPosX;
	accum     += texture(tex, vec2(uvMid.x, uvP2.y )).rgb * wPosY;

	float weightSum = wNegY + wNegX + wCentre + wPosX + wPosY;
	return accum / weightSum;
}

#include "/lib/universal/TextRenderer.glsl"

// Debug overlay: draws the exposure histogram as a bar chart.
// The chart covers HISTOGRAM_BIN_COUNT * 2 by 256 texels, starting at texel (0, 0).
// Each bin is 2 texels wide. A texel is painted white while its row index is below the
// bin's count and black otherwise. Texels outside the chart keep `color` unchanged.
void HistogramDisplay(inout vec3 color, in ivec2 texel) {
	const int binWidth = 2;
	const ivec2 chartSize = ivec2(HISTOGRAM_BIN_COUNT * binWidth, 256);

	// Outside the chart on either axis: leave the pixel alone.
	if (any(greaterThanEqual(texel, chartSize))) {
		return;
	}

	uint binCount = global.exposure.histogram[texel.x / binWidth];

	// 1.0 when binCount >= texel.y + 1, i.e. this row lies inside the bar.
	color = vec3(step(texel.y + 1, binCount));
}

//======// Main //================================================================================//
// Final pass. It:
//   1. resolves colortex0 to the screen: CAS sharpening when MC_RENDER_QUALITY is 1,
//      Catmull-Rom resampling otherwise;
//   2. converts the result to sRGB;
//   3. draws the enabled debug overlays;
//   4. applies Bayer dithering;
//   5. stores worldTime into the global SSBO.
void main() {
	ivec2 screenTexel = ivec2(gl_FragCoord.xy);

	#ifdef DEBUG_BLOOM_TILES
		finalOut = texelFetch(colortex4, screenTexel, 0).rgb;
	#else
		if (abs(MC_RENDER_QUALITY - 1.0) < 1e-2) {
			// Native resolution: sharpen in place (FFXCasFilter already saturates).
			finalOut = FFXCasFilter(screenTexel, CAS_STRENGTH);
		} else {
			// Scaled render resolution: resample colortex0 to the screen.
			// The Catmull-Rom kernel has negative lobes, so the result can leave [0, 1]
			// near edges. Clamp it like the CAS path before the sRGB conversion.
			// NOTE(review): a negative input may become NaN in linearToSRGBApprox if it
			// uses sqrt/pow. Confirm.
			finalOut = saturate(textureCatmullRomFast(colortex0, gl_FragCoord.xy * MC_RENDER_QUALITY, 0.5));
		}
	#endif

	// Apply gamma correction
	finalOut = linearToSRGBApprox(finalOut);

	// Text display
	#if 0
		finalOut += renderText(ivec2(100), 3, vec3(0.5));
		finalOut = saturate(finalOut);
	#endif

	// Time display
	#if 0
		const ivec2 size = ivec2(30, 200);
		const int strokewidth = 3;
		const ivec2 start = ivec2(60, 200);
		const ivec2 end = start + size;
		const int center = (start.y + end.y) >> 1;

		if (clamp(screenTexel, start - strokewidth, end + strokewidth) == screenTexel) {
			finalOut = vec3(0.0);
			if (clamp(screenTexel, start, end) == screenTexel && clamp(screenTexel.y, center - 1, center + 1) != screenTexel.y) {
				float t = 1.0 - sunAngle * 2.0 + step(0.5, sunAngle);
				if (screenTexel.y > start.y + t * size.y) {
					finalOut = sunAngle < 0.5 ? vec3(0.2, 0.7, 1.0) : vec3(0.08, 0.24, 0.4);
				} else {
					finalOut = vec3(1.0);
				}
			}
		}
	#endif

	#ifdef DEBUG_CLOUD_SHADOWS
		if (all(lessThan(screenTexel, textureSize(cloudShadowTex, 0)))) {
			finalOut = vec3(texelFetch(cloudShadowTex, screenTexel, 0).x);
		}
	#endif

	#ifdef DEBUG_CLOUD_MAP
		// Two side-by-side panels: the .x channel at the origin and the .y channel
		// directly to its right.
		ivec2 cloudMapSize = textureSize(cloudMapTex, 0);
		ivec2 tempTexel = screenTexel;
		if (all(lessThan(tempTexel, cloudMapSize))) {
			finalOut = vec3(texelFetch(cloudMapTex, tempTexel, 0).x);
		}
		tempTexel.x -= cloudMapSize.x;
		// GLSL '&&' only accepts scalar bools, so each bvec2 is reduced with all() first.
		if (all(greaterThanEqual(tempTexel, ivec2(0))) && all(lessThan(tempTexel, cloudMapSize))) {
			finalOut = vec3(texelFetch(cloudMapTex, tempTexel, 0).y);
		}
	#endif

	#ifdef DEBUG_SKY_COLOR
		if (all(lessThan(gl_FragCoord.xy * viewPixelSize, vec2(0.4)))) finalOut = skyColor;
	#endif

	#if 0
		HistogramDisplay(finalOut, screenTexel);
	#endif

	// Apply bayer dithering to reduce banding artifacts
	finalOut += (bayer16(gl_FragCoord.xy) - 0.5) * r255;

	// Update SSBO
	global.prevWorldTime = worldTime;
}