//-------------------------------------------------------------------------------------------------------------------------------------------------------------
//
// Copyright 2025 Apple Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//-------------------------------------------------------------------------------------------------------------------------------------------------------------

#import "Renderer.h"
#import "../HLSL/triangles.h"
#define IR_PRIVATE_IMPLEMENTATION
#import <metal_irconverter_runtime/metal_irconverter_runtime.h>
#import <simd/simd.h>

#define SIMPLE_BINDINGS 1

/// Number of elements in a true C array (not a pointer). The argument is parenthesized so that
/// compound expressions are indexed as a whole rather than binding to `[0]` first.
#define ARRAY_COUNT(array) (sizeof(array) / sizeof((array)[0]))

/// Per-frame constants written by the CPU into the uniform ring buffer and read by the shaders.
/// NOTE(review): presumably mirrors the cbuffer layout in triangles.hlsl — keep the two in sync.
typedef struct
{
    /*
     * For each transform,
     * - [0] and [1] (x,y) are scale factors
     * - [2] and [3] (z,w) are translation factors
     */
    
    float transform1[4];
    float transform2[4];
    float transform3[4];
} Uniforms;

/// Number of triangles drawn per frame; each triangle gets its own vertex-function variant.
static const NSUInteger kNumTriangles = 3;

/// Maximum number of frames the CPU may encode ahead of the GPU (one uniform slot per frame).
static const NSUInteger kMaxBuffersInFlight = 3;

/// Stride of one uniform slot in the ring buffer: sizeof(Uniforms) rounded up to 256 bytes.
/// Metal requires 256-byte aligned offsets for constant-address-space buffers on macOS, and
/// D3D12 CBVs (the model the HLSL shaders were written against) use the same placement
/// alignment. Without rounding, the per-frame slot offsets would be 0/48/96.
/// Written with literals because a `const` variable is not a C constant expression.
static const size_t kResourceBufferUniformsSize = (sizeof(Uniforms) + 255) & ~(size_t)255;

/// Display name for each LOGIC_OP value, used for pipeline labels and overlay button titles.
/// Entries use designated initializers, so they stay matched to their enum value regardless of
/// declaration order. Both the strings and the pointer table are immutable.
static const char *const LogicOpLabel[LOGIC_OP_COUNT] = {
    [LOGIC_OP_CLEAR]         = "CLEAR",
    [LOGIC_OP_SET]           = "SET",
    [LOGIC_OP_COPY]          = "COPY",
    [LOGIC_OP_COPY_INVERTED] = "COPY_INVERTED",
    [LOGIC_OP_NOOP]          = "NOOP",
    [LOGIC_OP_INVERT]        = "INVERT",
    [LOGIC_OP_AND]           = "AND",
    [LOGIC_OP_NAND]          = "NAND",
    [LOGIC_OP_OR]            = "OR",
    [LOGIC_OP_NOR]           = "NOR",
    [LOGIC_OP_XOR]           = "XOR",
    [LOGIC_OP_EQUIV]         = "EQUIV",
    [LOGIC_OP_AND_REVERSE]   = "AND_REVERSE",
    [LOGIC_OP_AND_INVERTED]  = "AND_INVERTED",
    [LOGIC_OP_OR_REVERSE]    = "OR_REVERSE",
    [LOGIC_OP_OR_INVERTED]   = "OR_INVERTED",
};

/// Buffers returned by -_encodeMSCBindingBuffers and bound before drawing with
/// metal-shaderconverter shaders.
typedef struct MSCBinding
{
    /// Set as a vertex buffer at kIRArgumentBufferBindPoint.
    id<MTLBuffer> TopLevelArgumentBuffer;
#if MSC_BINDING_EXPLICT_LAYOUT
    /// Descriptor table whose GPU address is stored in TopLevelArgumentBuffer (explicit layout only).
    id<MTLBuffer> DescriptorTableBuffer;
#endif
} MSCBinding;

@implementation Renderer
{
    /// Limits how many frames the CPU may encode ahead of the GPU (created with kMaxBuffersInFlight).
    dispatch_semaphore_t _inFlightSemaphore;
    id<MTLDevice> _device;
    /// Created last in -_loadMetalWithView:, so it is non-nil only when every pipeline state was built.
    id<MTLCommandQueue> _commandQueue;
    /// Pipeline variants indexed as [blend state][logic op][triangle].
    id<MTLRenderPipelineState> _pipelineStates[LogicBlendStateCount][LOGIC_OP_COUNT][kNumTriangles];

    /// Ring buffer holding kMaxBuffersInFlight slots of kResourceBufferUniformsSize bytes each.
    id<MTLBuffer> _dynamicUniformBuffer;
    uint64_t _uniformBufferOffset;   // Byte offset of the current slot.
    uint8_t _uniformBufferIndex;     // Index of the current slot.
    void* _uniformBufferAddress;     // CPU address of the current slot.
    
    LOGIC_OP _logicOp;               // Selected by the overlay radio buttons.
    LogicBlendState _blendState;     // Selected by the overlay "Enable Blending" switch.
    float _translationFactor;        // Animation phase passed to sinf() in -_updateState.
    
    NSView *_overlayView;            // Current overlay; rebuilt on every drawable-size change.
}

/// Creates a renderer that draws with `view`'s device and loads all Metal state for it.
/// Starts with the AND_REVERSE logic op and blending enabled.
- (nonnull instancetype)initWithMetalKitView:(nonnull MTKView *)view
{
    self = [super init];
    if(self)
    {
        _device = view.device;
        _logicOp = LOGIC_OP_AND_REVERSE;
        _blendState = LogicBlendStateEnabled;
        _translationFactor = 0.0f;
        _inFlightSemaphore = dispatch_semaphore_create(kMaxBuffersInFlight);
        
        [self _loadMetalWithView:view];
    }
    
    return self;
}

/// Configures `view`'s pixel formats and builds all Metal state: the uniform ring buffer, one
/// vertex-function variant per triangle, one fragment-function variant per
/// (blend state, logic op, triangle), the matching pipeline states, and finally the command queue.
/// On any failure the error is logged and the method returns early, leaving _commandQueue nil.
- (void)_loadMetalWithView:(nonnull MTKView *)view
{
    NSError *error = nil;
    
    /// Load Metal state objects and initialize renderer dependent view properties
    
    view.depthStencilPixelFormat = MTLPixelFormatDepth32Float_Stencil8;
    view.colorPixelFormat = MTLPixelFormatBGRA8Unorm_sRGB;
    view.sampleCount = 1;
    
    /// Create dynamic resource buffer
    
    _dynamicUniformBuffer = [_device newBufferWithLength:(kResourceBufferUniformsSize * kMaxBuffersInFlight) options:MTLResourceStorageModeShared];
    _dynamicUniformBuffer.label = @"UniformBuffer";
    
    /// Create pipeline states
    
    NSURL *libraryURL = [[NSBundle mainBundle] URLForResource:@"triangles" withExtension:@"metallib"];
    if (libraryURL == nil)
    {
        // -newLibraryWithURL:error: takes a nonnull URL; report the missing resource instead.
        NSLog(@"Failed to locate triangles.metallib in the main bundle");
        return;
    }
    id<MTLLibrary> library = [_device newLibraryWithURL:libraryURL error:&error];
    if (library == nil)
    {
        NSLog(@"Failed to create library: error %@", error);
        return;
    }
    
    /// Create vertex function variants for function constant 0 in triangles.hlsl
    /// (See MTL_FUNCTION_CONSTANT(uint, transformIndex, FC_IDX_VERTEX_TRANSFORM_INDEX))
    
    NSMutableArray<id<MTLFunction>> *vertexFunctionVariants = [[NSMutableArray alloc] initWithCapacity:kNumTriangles];
    for (int transformIndex = 0; transformIndex < kNumTriangles; ++transformIndex)
    {
        MTLFunctionConstantValues *functionConstantValues = [[MTLFunctionConstantValues alloc] init];
        IRRuntimeFunctionConstantValue transformIndexFC = { .i0 = transformIndex };
        IRRuntimeSetFunctionConstantValue(functionConstantValues, FC_IDX_VERTEX_TRANSFORM_INDEX, &transformIndexFC);
        id<MTLFunction> vertexFunction = [library newFunctionWithName:@"vertexShader" constantValues:functionConstantValues error:&error];
        if (!vertexFunction)
        {
            NSLog(@"Error: %@", error);
            return;
        }
        [vertexFunctionVariants addObject:vertexFunction];
    }
    
    /// Variants for each blending state. Each array is filled op-major, triangle-minor, so the
    /// variant for (op, tri) lives at index op * kNumTriangles + tri.
    /// Initialized in a loop so every blend state gets an array whatever LogicBlendStateCount is.
    NSMutableArray<id<MTLFunction>> *fragmentFunctionVariants[LogicBlendStateCount];
    for (int blendState = 0; blendState < LogicBlendStateCount; ++blendState)
    {
        fragmentFunctionVariants[blendState] = [[NSMutableArray alloc] initWithCapacity:LOGIC_OP_COUNT * kNumTriangles];
    }
    
    /// Each triangle will have a distinct color
    simd_float4 ColorPalette[] = {
        simd_make_float4(0.9f, 0.2f, 0.2f, 1.0f),
        simd_make_float4(0.5f, 0.9f, 0.5f, 1.0f),
        simd_make_float4(0.3f, 0.1f, 0.9f, 1.0f),
    };
    size_t MAX_PALETTE_COLORS = ARRAY_COUNT(ColorPalette);
    
    for (int op = 0; op < LOGIC_OP_COUNT; ++op)
    {
        // One constant-values object per op; the blend-state and color constants below are
        // overwritten in place before each function is specialized.
        MTLFunctionConstantValues *functionConstantValues = [[MTLFunctionConstantValues alloc] init];
        LOGIC_OP logicOp = (LOGIC_OP)op;
        IRRuntimeFunctionConstantValue logicOpFC = { .i0 = logicOp };
        IRRuntimeSetFunctionConstantValue(functionConstantValues, FC_IDX_FRAGMENT_LOGIC_OP_MODE, &logicOpFC);
        
        for (int blendState = 0; blendState < LogicBlendStateCount; ++blendState)
        {
            IRRuntimeFunctionConstantValue blendStateFC = { .i0 = blendState };
            IRRuntimeSetFunctionConstantValue(functionConstantValues, FC_IDX_FRAGMENT_BLEND_STATE, &blendStateFC);
            
            NSMutableArray<id<MTLFunction>> *fragVariants = fragmentFunctionVariants[blendState];
            for (int tri = 0; tri < kNumTriangles; ++tri)
            {
                int colorIdx = ((op * kNumTriangles) + tri) % MAX_PALETTE_COLORS;
                const simd_float4 *color = &ColorPalette[colorIdx];
                IRRuntimeFunctionConstantValue colorFC = { .f0 = color->x, .f1 = color->y, .f2 = color->z, .f3 = color->w };
                IRRuntimeSetFunctionConstantValue(functionConstantValues, FC_IDX_FRAGMENT_CONSTANT_COLOR, &colorFC);
                
                id<MTLFunction> fragmentFunction = [library newFunctionWithName:@"pixelShader" constantValues:functionConstantValues error:&error];
                if (!fragmentFunction)
                {
                    NSLog(@"Error: %@", error);
                    return;
                }
                [fragVariants addObject:fragmentFunction];
            }
        }
    }
    
    /// Create all pipeline variants: For each blend state, pair each triangle vertex function with a logic operation fragment function
    
    MTLRenderPipelineDescriptor *pipelineStateDescriptor = [[MTLRenderPipelineDescriptor alloc] init];
    pipelineStateDescriptor.rasterSampleCount = view.sampleCount;
    pipelineStateDescriptor.colorAttachments[0].pixelFormat = view.colorPixelFormat;
    pipelineStateDescriptor.colorAttachments[0].blendingEnabled = NO;
    pipelineStateDescriptor.depthAttachmentPixelFormat = view.depthStencilPixelFormat;
    pipelineStateDescriptor.stencilAttachmentPixelFormat = view.depthStencilPixelFormat;
    
    for (int logicOp = 0; logicOp < LOGIC_OP_COUNT; ++logicOp)
    {
        for (int blendState = 0; blendState < LogicBlendStateCount; ++blendState)
        {
            NSMutableArray<id<MTLFunction>> *fragVariants = fragmentFunctionVariants[blendState];
            for (int tri = 0; tri < kNumTriangles; ++tri)
            {
                // Same op-major, triangle-minor indexing used when the variants were appended.
                int idx = (logicOp * kNumTriangles) + tri;
                id<MTLFunction> fragmentFunction = fragVariants[idx];
                id<MTLFunction> vertexFunction = vertexFunctionVariants[tri];
                
                NSString *label = [NSString stringWithFormat:@"%s_transform%d", LogicOpLabel[logicOp], tri];
                pipelineStateDescriptor.label = label;
                pipelineStateDescriptor.vertexFunction = vertexFunction;
                pipelineStateDescriptor.fragmentFunction = fragmentFunction;
                
                id<MTLRenderPipelineState> pipelineState = [_device newRenderPipelineStateWithDescriptor:pipelineStateDescriptor error:&error];
                if (pipelineState == nil)
                {
                    NSLog(@"Pipeline state creation failed: %@", error);
                    return;
                }

                _pipelineStates[blendState][logicOp][tri] = pipelineState;
            }
        }
    }
    
    // Created last on purpose: a non-nil queue signals that every pipeline state above exists,
    // which -drawInMTKView: relies on.
    _commandQueue = [_device newCommandQueue];
}

/// Writes this frame's three triangle transforms into the current uniform slot and advances
/// the animation phase.
- (void)_updateState
{
    /// Update any game state before encoding rendering commands to our drawable

    Uniforms *uniforms = (Uniforms *)_uniformBufferAddress;
    {
        const float scale_size = 0.6f;
        const float translate_pos = sinf(_translationFactor) * 0.3f;
        
        // translate along x (starting left to right)
        uniforms->transform1[0] = scale_size;
        uniforms->transform1[1] = scale_size;
        uniforms->transform1[2] = translate_pos;
        uniforms->transform1[3] = 0.0f;
        
        // flip triangle
        uniforms->transform2[0] = scale_size;
        uniforms->transform2[1] = -scale_size;
        uniforms->transform2[2] = 0.0f;
        uniforms->transform2[3] = 0.0f;
        
        // translate along x (starting right to left)
        uniforms->transform3[0] = scale_size;
        uniforms->transform3[1] = scale_size;
        uniforms->transform3[2] = -translate_pos;
        uniforms->transform3[3] = 0.0f;
    }

    // Keep the phase in [0, 2*pi). sinf is periodic, so the motion is unchanged. An ever-growing
    // float accumulator would lose precision until the 0.01 step no longer changed it.
    _translationFactor = fmodf(_translationFactor + 0.01f, 2.0f * (float)M_PI);
}

/// Builds the metal-shaderconverter bindings for the current frame: a one-entry descriptor table
/// whose buffer entry points at the current uniform slot, exposed through the top-level
/// argument buffer (directly for automatic layout, by address for explicit layout).
/// NOTE(review): this allocates new MTLBuffers on every call, i.e. every frame. The slot offsets
/// are fixed, so these could be created once per slot in -_loadMetalWithView:.
- (MSCBinding)_encodeMSCBindingBuffers
{
    // Set metal-shaderconverter bindings
    MSCBinding binding;
    
    // Create descriptor table buffer containing the HLSL CBV buffer for the uniforms
    IRDescriptorTableEntry entry;
    IRDescriptorTableSetBuffer(&entry, _dynamicUniformBuffer.gpuAddress + _uniformBufferOffset, 0);
    id<MTLBuffer> descriptorTableBuffer = [_device newBufferWithBytes:&entry length:sizeof(entry) options:MTLResourceStorageModeShared];
    
#if MSC_BINDING_EXPLICT_LAYOUT
    // Set the top-level argument buffer to reference the descriptor table buffer
    binding.DescriptorTableBuffer = descriptorTableBuffer;
    binding.DescriptorTableBuffer.label = @"MSC_DescriptorTable";
    
    uint64_t descriptorTableBufferAddress = descriptorTableBuffer.gpuAddress;
    binding.TopLevelArgumentBuffer = [_device newBufferWithBytes:&descriptorTableBufferAddress length:sizeof(descriptorTableBufferAddress) options:MTLResourceStorageModeShared];
    binding.TopLevelArgumentBuffer.label = @"MSC_TopLevelArgumentBuffer";
#else
    assert(MSC_BINDING_AUTOMATIC_LAYOUT);
    // Set the top-level argument buffer with contents of the descriptor table
    binding.TopLevelArgumentBuffer = descriptorTableBuffer;
    binding.TopLevelArgumentBuffer.label = @"MSC_TopLevelArgumentBuffer";
#endif
    
    return binding;
}

/// Advances to the next uniform ring-buffer slot and recomputes its byte offset and CPU address.
- (void)_updateDynamicBufferState
{
    /// Update the state of our uniform buffers before rendering

    _uniformBufferIndex = (_uniformBufferIndex + 1) % kMaxBuffersInFlight;

    _uniformBufferOffset = kResourceBufferUniformsSize * _uniformBufferIndex;

    _uniformBufferAddress = ((uint8_t*)_dynamicUniformBuffer.contents) + _uniformBufferOffset;
}

/// Encodes and commits one frame: three triangles drawn with the pipeline variants for the
/// current blend state and logic op. Throttled by _inFlightSemaphore.
- (void)drawInMTKView:(nonnull MTKView *)view
{
    /// Per frame updates here
    dispatch_semaphore_wait(_inFlightSemaphore, DISPATCH_TIME_FOREVER);
    
    id <MTLCommandBuffer> commandBuffer = [_commandQueue commandBuffer];
    if (commandBuffer == nil)
    {
        // Setup failed (or no buffer could be vended), so no completion handler will ever release
        // this slot. Release it here, otherwise the wait above blocks forever after
        // kMaxBuffersInFlight frames.
        dispatch_semaphore_signal(_inFlightSemaphore);
        return;
    }
    
    dispatch_semaphore_t block_sema = _inFlightSemaphore;
    [commandBuffer addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
        dispatch_semaphore_signal(block_sema);
    }];
    
    // Advance the ring-buffer slot before labeling, so the label names the slot this frame uses.
    [self _updateDynamicBufferState];
    commandBuffer.label = [@"CommandBuffer" stringByAppendingFormat:@": UniformBuffer[%@]", @(_uniformBufferIndex)];
    
    [self _updateState];
    
    MSCBinding mscBinding = [self _encodeMSCBindingBuffers];
    
    // Both values index _pipelineStates, so they must be strictly below their counts.
    LogicBlendState blendState = _blendState;
    assert((blendState >= 0) && (blendState < LogicBlendStateCount));
    
    LOGIC_OP logicOp = _logicOp;
    assert((logicOp >= 0) && (logicOp < LOGIC_OP_COUNT));
    
    /// Delay getting the currentRenderPassDescriptor until we absolutely need it to avoid
    ///   holding onto the drawable and blocking the display pipeline any longer than necessary
    MTLRenderPassDescriptor* renderPassDescriptor = view.currentRenderPassDescriptor;
    if (renderPassDescriptor != nil)
    {
        renderPassDescriptor.colorAttachments[0].clearColor = MTLClearColorMake(0.15, 0.15, 0.8, 1.0);
        renderPassDescriptor.colorAttachments[0].storeAction = MTLStoreActionStore;
        
        id <MTLRenderCommandEncoder> renderEncoder = [commandBuffer renderCommandEncoderWithDescriptor:renderPassDescriptor];
        [renderEncoder setLabel:@"RenderEncoder"];
        [renderEncoder pushDebugGroup:@"Triangles"];
        [renderEncoder setFrontFacingWinding:MTLWindingCounterClockwise];
        [renderEncoder setCullMode:MTLCullModeNone];
        // The uniform buffer is only reached through a GPU address in the descriptor table,
        // so it must be made resident explicitly.
        [renderEncoder useResource:_dynamicUniformBuffer usage:MTLResourceUsageRead stages:MTLRenderStageVertex];
        [renderEncoder setVertexBuffer:mscBinding.TopLevelArgumentBuffer offset:0 atIndex:kIRArgumentBufferBindPoint];
#if MSC_BINDING_EXPLICT_LAYOUT
        [renderEncoder useResource:mscBinding.DescriptorTableBuffer usage:MTLResourceUsageRead stages:MTLRenderStageVertex];
#endif
        for (int tri = 0; tri < kNumTriangles; ++tri)
        {
            id<MTLRenderPipelineState> pipelineState = _pipelineStates[blendState][logicOp][tri];
            [renderEncoder setRenderPipelineState:pipelineState];
            IRRuntimeDrawPrimitives(renderEncoder, MTLPrimitiveTypeTriangle, 0, 3, 1, 0);
        }
        
        [renderEncoder popDebugGroup];
        [renderEncoder endEncoding];
        [commandBuffer presentDrawable:view.currentDrawable];
    }
    
    [commandBuffer commit];
}

/// Rebuilds the overlay controls on the main queue whenever the drawable size changes.
- (void)mtkView:(nonnull MTKView *)view drawableSizeWillChange:(CGSize)size
{
    dispatch_async(dispatch_get_main_queue(), ^{
        [self _updateOverlayWithView:view];
    });
}

/// Replaces the overlay above `mtkView` with a new one containing an "Enable Blending" switch
/// and one radio button per logic op, initialized from _blendState and _logicOp.
- (void)_updateOverlayWithView:(MTKView *)mtkView
{
    NSView *overlayView = [[NSView alloc] initWithFrame:mtkView.bounds];
    {
        overlayView.autoresizingMask = NSViewWidthSizable | NSViewHeightSizable;
        overlayView.wantsLayer = YES;
        overlayView.layer.backgroundColor = NSColor.clearColor.CGColor;
    }
    
    CGFloat x = 20;
    CGFloat y = mtkView.bounds.size.height - 40;
    CGFloat spacing = 20;
    
    // Create the toggle button
    NSButton *toggleButton = [[NSButton alloc] initWithFrame:NSMakeRect(x, y, 200, 24)];
    {
        toggleButton.title = @"Enable Blending";
        toggleButton.buttonType = NSButtonTypeSwitch;
        toggleButton.target = self;
        toggleButton.action = @selector(toggleButtonPressed:);
        toggleButton.state = (_blendState == LogicBlendStateEnabled) ? NSControlStateValueOn : NSControlStateValueOff;
        [overlayView addSubview:toggleButton];
    }
    
    // Create the logic operation buttons
    NSFont *font = [NSFont fontWithName:@"SF Mono Light" size:10];
    if (font == nil)
    {
        // Looking a font up by name can fail. A nil font would make the dictionary literal below
        // throw (the old assert is compiled out in release), so fall back to the system
        // monospaced face at the same size and weight.
        font = [NSFont monospacedSystemFontOfSize:10 weight:NSFontWeightLight];
    }
    NSDictionary *attributes = @{ NSFontAttributeName : font };
    for (NSInteger i = 0, count = ARRAY_COUNT(LogicOpLabel); i < count; ++i)
    {
        const char *label = LogicOpLabel[i];
        NSAttributedString *attributedTitle = [[NSAttributedString alloc] initWithString:@(label) attributes:attributes];
        NSButton *radioButton = [[NSButton alloc] initWithFrame:NSMakeRect(x, y - (i + 1) * spacing, 160, 24)];
        {
            radioButton.attributedTitle = attributedTitle;
            radioButton.buttonType = NSButtonTypeRadio;
            radioButton.target = self;
            radioButton.action = @selector(radioButtonSelected:);
            // The tag carries the LOGIC_OP value read back in -radioButtonSelected:.
            radioButton.tag = i;
            radioButton.state = (i == _logicOp) ? NSControlStateValueOn : NSControlStateValueOff;
            [overlayView addSubview:radioButton];
        }
    }
    
    [_overlayView removeFromSuperview];
    _overlayView = overlayView;
    [mtkView.superview addSubview:_overlayView positioned:NSWindowAbove relativeTo:mtkView];
}

/// Radio-button action: selects the logic op stored in the sender's tag.
- (void)radioButtonSelected:(NSButton *)sender
{
    LOGIC_OP logicOp = (LOGIC_OP)sender.tag;
    _logicOp = logicOp;
}

/// Switch action: enables blending when the switch is on, disables it otherwise.
- (void)toggleButtonPressed:(NSButton *)sender
{
    if (sender.state == NSControlStateValueOn)
    {
        _blendState = LogicBlendStateEnabled;
    }
    else
    {
        _blendState = LogicBlendStateDisabled;
    }
}

@end

