/*
 * Copyright 2019,2020 Sony Corporation
 */

#include "Compositor/DisplayProjector.h"

#include <ClearQuad.h>
#include <GlobalShader.h>
#include <ScreenRendering.h>
#include <CommonRenderResources.h>

/*
 * Homography transformation shader
 */

/**
 * Global pixel shader that warps a texture by a homography matrix.
 *
 * The HLSL source is /Plugin/SRDisplayPlugin/Private/HomographyTransformShader.usf with entry
 * point MainPS (see IMPLEMENT_SHADER_TYPE below). Three parameters are bound by name:
 *   InTexture / InTextureSampler - the texture to sample and its sampler (SetParameters uses point sampling),
 *   InHomographyMatrix           - the homography, uploaded as an FMatrix.
 * How MainPS applies the matrix (forward or inverse mapping) is defined in the .usf file, not here.
 *
 * NOTE(review): every version switch below tests only ENGINE_MINOR_VERSION. Under an engine whose
 * major version is not 4 (e.g. a 5.0 build, minor version 0), the legacy "<=24" branch would be
 * selected. Confirm that the plugin only targets UE4.
 */
class FSRDisplayHomographyTransformPS : public FGlobalShader
{
	DECLARE_SHADER_TYPE(FSRDisplayHomographyTransformPS, Global);

public:
	// Compile this shader only for platforms that support the SM5 feature level.
	static bool ShouldCompilePermutation(const FGlobalShaderPermutationParameters& Parameters)
	{
		return IsFeatureLevelSupported(Parameters.Platform, ERHIFeatureLevel::SM5);
	}

	FSRDisplayHomographyTransformPS() { }

	// Bind each C++ parameter to the identically named parameter in the compiled shader's parameter map.
	FSRDisplayHomographyTransformPS(const ShaderMetaType::CompiledShaderInitializerType& Initializer)
		: FGlobalShader(Initializer)
	{
		// Bind shader inputs.
		InTexture.Bind(Initializer.ParameterMap, TEXT("InTexture"));
		InTextureSampler.Bind(Initializer.ParameterMap, TEXT("InTextureSampler"));
		InHomographyMatrix.Bind(Initializer.ParameterMap, TEXT("InHomographyMatrix"));
	}

#if ENGINE_MINOR_VERSION<=24
	/**
	 * UE <= 4.24 variant: set shader parameters on this shader's own pixel shader RHI.
	 * @param RHICmdList        Command list that receives the parameter updates.
	 * @param Texture           Texture bound to InTexture, sampled with a point sampler.
	 * @param HomographyMatrix  Matrix uploaded to InHomographyMatrix.
	 */
	void SetParameters(FRHICommandList& RHICmdList, FRHITexture* Texture, FMatrix HomographyMatrix)
	{
		FRHIPixelShader* PixelShaderRHI = GetPixelShader();
		FRHISamplerState* SamplerStateRHI = TStaticSamplerState<SF_Point>::GetRHI();

		SetTextureParameter(RHICmdList, PixelShaderRHI, InTexture, InTextureSampler, SamplerStateRHI, Texture);
		SetShaderValue(RHICmdList, PixelShaderRHI, InHomographyMatrix, HomographyMatrix);
	}

	// UE <= 4.24: serialize the bound parameters after the base-class data.
	// The order of the << statements is part of the serialized layout and must stay fixed.
	// Returns the base class's "outdated parameters" flag unchanged.
	virtual bool Serialize(FArchive& Ar) override
	{
		bool bShaderHasOutdatedParameters = FGlobalShader::Serialize(Ar);

		// Serialize shader inputs.
		Ar << InTexture;
		Ar << InTextureSampler;
		Ar << InHomographyMatrix;

		return bShaderHasOutdatedParameters;
	}
#else // ENGINE_MINOR_VERSION<=24
	/**
	 * UE >= 4.25 variant: set shader parameters on the supplied shader RHI.
	 * @param RHICmdList        Command list that receives the parameter updates.
	 * @param ShaderRHI         Shader RHI to bind to (callers pass the pixel shader).
	 * @param Texture           Texture bound to InTexture, sampled with a point sampler.
	 * @param HomographyMatrix  Matrix uploaded to InHomographyMatrix.
	 */
	template<typename TShaderRHIParamRef>
	void SetParameters(FRHICommandList& RHICmdList, const TShaderRHIParamRef ShaderRHI, FRHITexture* Texture, FMatrix HomographyMatrix)
	{
		FRHISamplerState* SamplerStateRHI = TStaticSamplerState<SF_Point>::GetRHI();

		SetTextureParameter(RHICmdList, ShaderRHI, InTexture, InTextureSampler, SamplerStateRHI, Texture);
		SetShaderValue(RHICmdList, ShaderRHI, InHomographyMatrix, HomographyMatrix);
	}
#endif // ENGINE_MINOR_VERSION<=24

private:
	// Shader parameters.
	// From 4.25 they are declared with the engine's LAYOUT_FIELD macro, and no Serialize override is used.
	// Keep the declaration order stable.
#if ENGINE_MINOR_VERSION<=24
	FShaderResourceParameter InTexture;
	FShaderResourceParameter InTextureSampler;
	FShaderParameter InHomographyMatrix;
#else // ENGINE_MINOR_VERSION<=24
	LAYOUT_FIELD(FShaderResourceParameter, InTexture);
	LAYOUT_FIELD(FShaderResourceParameter, InTextureSampler);
	LAYOUT_FIELD(FShaderParameter, InHomographyMatrix);
#endif // ENGINE_MINOR_VERSION<=24
};

// Register the shader type: source file, entry point MainPS, pixel-shader frequency.
IMPLEMENT_SHADER_TYPE(, FSRDisplayHomographyTransformPS, TEXT("/Plugin/SRDisplayPlugin/Private/HomographyTransformShader.usf"), TEXT("MainPS"), SF_Pixel)

namespace sr_display
{
	/**
	 * Looks up the Renderer module and stores the four display corner positions.
	 * FModuleManager::GetModulePtr returns nullptr when the module is not loaded, so
	 * ExecuteDisplayProjection checks RendererModule before using it.
	 */
	FDisplayProjector::FDisplayProjector(FVector4& LeftBottomPosition, FVector4& LeftTopPosition, FVector4& RightBottomPosition, FVector4& RightTopPosition)
	{
		static const FName RendererModuleName("Renderer");
		RendererModule = FModuleManager::GetModulePtr<IRendererModule>(RendererModuleName);

		SetDisplayCornersPosition(LeftBottomPosition, LeftTopPosition, RightBottomPosition, RightTopPosition);
	}

	// RendererModule comes from FModuleManager and is not released here; the pointer is only cleared.
	FDisplayProjector::~FDisplayProjector()
	{
		RendererModule = nullptr;
	}

	/**
	 * Stores copies of the display's four corner positions.
	 * They are projected to screen space on every ExecuteDisplayProjection call.
	 */
	void FDisplayProjector::SetDisplayCornersPosition(FVector4& LeftBottomPosition, FVector4& LeftTopPosition, FVector4& RightBottomPosition, FVector4& RightTopPosition)
	{
		LeftBottomCornerPosition = LeftBottomPosition;
		LeftTopCornerPosition = LeftTopPosition;
		RightBottomCornerPosition = RightBottomPosition;
		RightTopCornerPosition = RightTopPosition;
	}

	/**
	 * Warps one eye's image from a side-by-side source texture into BackBuffer.
	 *
	 * The stored display corners are projected with ViewProjectionMatrix, and a homography is
	 * built from them. One half of SrcTexture is then drawn through
	 * FSRDisplayHomographyTransformPS into a WindowSize viewport of BackBuffer.
	 *
	 * @param RHICmdList           Immediate command list to record into.
	 * @param StereoPass           eSSP_LEFT_EYE selects the left half of SrcTexture; any other value selects the right half.
	 * @param BackBuffer           Render target to draw into; it is cleared to opaque black first.
	 * @param SrcTexture           Side-by-side source texture, SourceWidth x SourceHeight.
	 * @param WindowSize           Size of the output viewport in pixels.
	 * @param ViewProjectionMatrix Matrix used to project the display corners.
	 * @param SourceWidth          Full width of SrcTexture (both eyes).
	 * @param SourceHeight         Height of SrcTexture.
	 * @return false (nothing recorded) when the renderer module or a texture is missing, or when the
	 *         per-eye size would be zero; true otherwise.
	 */
	bool FDisplayProjector::ExecuteDisplayProjection(FRHICommandListImmediate& RHICmdList, EStereoscopicPass StereoPass, FRHITexture2D* BackBuffer, FRHITexture2D* SrcTexture, FVector2D WindowSize, FMatrix ViewProjectionMatrix, uint32 SourceWidth, uint32 SourceHeight) const
	{
		// Each eye occupies half of the side-by-side source.
		const uint32 EyeWidth = SourceWidth / 2;

		// Avoid null dereferences and zero-sized texture creation further down.
		if (RendererModule == nullptr || BackBuffer == nullptr || SrcTexture == nullptr || EyeWidth == 0 || SourceHeight == 0)
		{
			return false;
		}

		const FVector2D ViewportSize = WindowSize;
		const FIntPoint TargetSize(static_cast<int32>(ViewportSize.X), static_cast<int32>(ViewportSize.Y));

		// 4 corners of the display in normalized screen space [0, 1] (Y grows downward)
		FVector2D LeftBottomPositionInScreen;
		FVector2D LeftTopPositionInScreen;
		FVector2D RightBottomPositionInScreen;
		FVector2D RightTopPositionInScreen;
		ConvertWorldPositionToScreen(LeftBottomCornerPosition, ViewProjectionMatrix, LeftBottomPositionInScreen);
		ConvertWorldPositionToScreen(LeftTopCornerPosition, ViewProjectionMatrix, LeftTopPositionInScreen);
		ConvertWorldPositionToScreen(RightBottomCornerPosition, ViewProjectionMatrix, RightBottomPositionInScreen);
		ConvertWorldPositionToScreen(RightTopCornerPosition, ViewProjectionMatrix, RightTopPositionInScreen);

		FMatrix HomographyMatrix;
		CalculateHomographyMatrix(LeftBottomPositionInScreen, LeftTopPositionInScreen, RightBottomPositionInScreen, RightTopPositionInScreen, HomographyMatrix);

		// Copy this eye's half of SrcTexture into its own texture, so the shader samples it over the full [0, 1] UV range.
		// - The texture uses the source's pixel format because a resolve/copy does not convert formats.
		// - The copy is recorded before BeginRenderPass because copies are not allowed inside a render pass on some RHIs (e.g. Vulkan).
		// - Source and destination rects are both EyeWidth wide, so an odd SourceWidth cannot overrun the destination.
		// NOTE(review): a new texture is created on every call. Caching it in a member would avoid per-frame allocation.
		FRHIResourceCreateInfo CreateInfo;
		FTexture2DRHIRef EyeTexture = RHICreateTexture2D(EyeWidth, SourceHeight, SrcTexture->GetFormat(), 1, 1, TexCreate_ShaderResource | TexCreate_RenderTargetable, CreateInfo);
		const uint32 SrcMinX = (StereoPass == eSSP_LEFT_EYE) ? 0u : EyeWidth;
		RHICmdList.CopyToResolveTarget(SrcTexture, EyeTexture,
			FResolveParams(FResolveRect(SrcMinX, 0, SrcMinX + EyeWidth, SourceHeight), CubeFace_PosX, 0, 0, 0, FResolveRect(0, 0, EyeWidth, SourceHeight)));

		FRHIRenderPassInfo RPInfo(BackBuffer, ERenderTargetActions::Load_Store);
		RHICmdList.BeginRenderPass(RPInfo, TEXT("FSRDisplaySystem_ProcessHomographyTransformation"));
		{
			// Clear to opaque black before drawing the warped image.
			DrawClearQuad(RHICmdList, FLinearColor(0.f, 0.f, 0.f, 1.f));
			RHICmdList.SetViewport(0, 0, 0.f, ViewportSize.X, ViewportSize.Y, 1.f);

			FGraphicsPipelineStateInitializer GraphicsPSOInit;
			RHICmdList.ApplyCachedRenderTargets(GraphicsPSOInit);

			const auto FeatureLevel = GMaxRHIFeatureLevel;
#if ENGINE_MINOR_VERSION<=24
			TShaderMap<FGlobalShaderType>* ShaderMap = GetGlobalShaderMap(FeatureLevel);
#else // ENGINE_MINOR_VERSION<=24
			FGlobalShaderMap* ShaderMap = GetGlobalShaderMap(FeatureLevel);
#endif // ENGINE_MINOR_VERSION<=24
			TShaderMapRef<FScreenVS> VertexShader(ShaderMap);
			TShaderMapRef<FSRDisplayHomographyTransformPS> PixelShader(ShaderMap);

			// Opaque draw: default blend and rasterizer state, no depth test.
			GraphicsPSOInit.BlendState = TStaticBlendState<>::GetRHI();
			GraphicsPSOInit.RasterizerState = TStaticRasterizerState<>::GetRHI();
			GraphicsPSOInit.DepthStencilState = TStaticDepthStencilState<false, CF_Always>::GetRHI();
			GraphicsPSOInit.PrimitiveType = PT_TriangleList;
			GraphicsPSOInit.BoundShaderState.VertexDeclarationRHI = GFilterVertexDeclaration.VertexDeclarationRHI;
#if ENGINE_MINOR_VERSION<=24
			GraphicsPSOInit.BoundShaderState.VertexShaderRHI = GETSAFERHISHADER_VERTEX(*VertexShader);
			GraphicsPSOInit.BoundShaderState.PixelShaderRHI = GETSAFERHISHADER_PIXEL(*PixelShader);
#else // ENGINE_MINOR_VERSION<=24
			GraphicsPSOInit.BoundShaderState.VertexShaderRHI = VertexShader.GetVertexShader();
			GraphicsPSOInit.BoundShaderState.PixelShaderRHI = PixelShader.GetPixelShader();
#endif // ENGINE_MINOR_VERSION<=24
			SetGraphicsPipelineState(RHICmdList, GraphicsPSOInit);

			// Draw a viewport-sized rectangle with UVs 0..1; the pixel shader applies the homography.
#if ENGINE_MINOR_VERSION<=24
			PixelShader->SetParameters(RHICmdList, EyeTexture, HomographyMatrix);

			RendererModule->DrawRectangle(
				RHICmdList,
				0.f, 0.f,
				ViewportSize.X, ViewportSize.Y,
				0.f, 0.f,
				1.f, 1.f,
				TargetSize,
				FIntPoint(1, 1),
				*VertexShader,
				EDRF_Default);
#else // ENGINE_MINOR_VERSION<=24
			PixelShader->SetParameters(RHICmdList, PixelShader.GetPixelShader(), EyeTexture, HomographyMatrix);

			RendererModule->DrawRectangle(
				RHICmdList,
				0.f, 0.f,
				ViewportSize.X, ViewportSize.Y,
				0.f, 0.f,
				1.f, 1.f,
				TargetSize,
				FIntPoint(1, 1),
				VertexShader,
				EDRF_Default);
#endif // ENGINE_MINOR_VERSION<=24
		}
		RHICmdList.EndRenderPass();

		return true;
	}

	/**
	 * Projects a world-space point to normalized screen coordinates.
	 * @param WorldPosition        Point to project (homogeneous).
	 * @param ViewProjectionMatrix Matrix taking WorldPosition to clip space.
	 * @param OutScreenPosition    Receives X in [0, 1] left-to-right and Y in [0, 1] top-to-bottom for on-screen points.
	 */
	void FDisplayProjector::ConvertWorldPositionToScreen(const FVector4& WorldPosition, const FMatrix& ViewProjectionMatrix, FVector2D& OutScreenPosition) const
	{
		const FVector4 ClipPosition = ViewProjectionMatrix.TransformFVector4(WorldPosition);

		// Perspective divide -> normalized device coordinates in [-1, 1].
		// NOTE(review): W == 0 yields inf/NaN. Presumably the display corners are always in front of the camera; confirm.
		const float RHW = 1.f / ClipPosition.W;
		const float NdcX = ClipPosition.X * RHW;
		const float NdcY = ClipPosition.Y * RHW;

		// Remap to [0, 1], flipping Y so that 0 is the top edge.
		OutScreenPosition = FVector2D((NdcX + 1.f) * 0.5f, 1.f - ((NdcY + 1.f) * 0.5f));
	}

	/**
	 * Builds the projective square-to-quad mapping for the display corners.
	 *
	 * With (x, y, w) = OutMatrix * (u, v, 1) and the result taken as (x / w, y / w):
	 *   (0,0) -> LeftTop, (1,0) -> RightTop, (0,1) -> LeftBottom, (1,1) -> RightBottom.
	 * Only the upper-left 3x3 part is meaningful; the fourth row and column are zero.
	 *
	 * If the projected corners are degenerate (zero determinant), the projective terms are set to 0.
	 * This yields the affine map through LeftTop, RightTop and LeftBottom instead of a NaN/inf matrix.
	 */
	void FDisplayProjector::CalculateHomographyMatrix(FVector2D InLeftBottomPosition, FVector2D InLeftTopPosition, FVector2D InRightBottomPosition, FVector2D InRightTopPosition, FMatrix& OutMatrix) const
	{
		const FVector2D& LB = InLeftBottomPosition;
		const FVector2D& LT = InLeftTopPosition;
		const FVector2D& RB = InRightBottomPosition;
		const FVector2D& RT = InRightTopPosition;

		// Edge vectors at the RightBottom corner (A, D) and (B, E), and the "non-parallelogram" residual (C, F).
		const float A = RT.X - RB.X;
		const float B = LB.X - RB.X;
		const float C = LT.X - LB.X - RT.X + RB.X;
		const float D = RT.Y - RB.Y;
		const float E = LB.Y - RB.Y;
		const float F = LT.Y - LB.Y - RT.Y + RB.Y;

		// Projective terms (bottom row); both share this denominator.
		const float Denominator = A * E - B * D;
		float Out20 = 0.f;
		float Out21 = 0.f;
		if (!FMath::IsNearlyZero(Denominator, SMALL_NUMBER))
		{
			Out20 = (C * E - B * F) / Denominator;
			Out21 = (A * F - C * D) / Denominator;
		}

		const float Out00 = RT.X - LT.X + Out20 * RT.X;
		const float Out01 = LB.X - LT.X + Out21 * LB.X;
		const float Out02 = LT.X;
		const float Out10 = RT.Y - LT.Y + Out20 * RT.Y;
		const float Out11 = LB.Y - LT.Y + Out21 * LB.Y;
		const float Out12 = LT.Y;

		OutMatrix = FMatrix(
				FPlane(Out00, Out01, Out02, 0.f),
				FPlane(Out10, Out11, Out12, 0.f),
				FPlane(Out20, Out21, 1.f, 0.f),
				FPlane(0.f, 0.f, 0.f, 0.f)
			);
	}

} // namespace sr_display
