import argparse
import os

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoTokenizer

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel


def _convert_state_dict(ckpt):
    """Rename original OmniGen checkpoint keys to the diffusers ``OmniGenTransformer2DModel`` layout.

    Args:
        ckpt: State dict loaded from the original ``model.safetensors``.

    Returns:
        A new dict mapping diffusers parameter names to the original tensors. Fused ``qkv``
        tensors are split into three equal chunks along dim 0 (q, k, v); every other tensor is
        passed through unchanged.
    """
    # Keys whose diffusers names cannot be derived mechanically from the original names.
    mapping_dict = {
        "pos_embed": "patch_embedding.pos_embed",
        "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
        "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
        "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
        "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
        "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
        "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
        "final_layer.linear.weight": "proj_out.weight",
        "final_layer.linear.bias": "proj_out.bias",
        "time_token.mlp.0.weight": "time_token.linear_1.weight",
        "time_token.mlp.0.bias": "time_token.linear_1.bias",
        "time_token.mlp.2.weight": "time_token.linear_2.weight",
        "time_token.mlp.2.bias": "time_token.linear_2.bias",
        "t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
        "t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
        "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
        "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
        "llm.embed_tokens.weight": "embed_tokens.weight",
    }

    converted_state_dict = {}
    for k, v in ckpt.items():
        if k in mapping_dict:
            converted_state_dict[mapping_dict[k]] = v
        elif "qkv" in k:
            # The third dot-separated component is used as the layer index
            # (assumes keys shaped like "llm.layers.<i>.…" — TODO confirm against checkpoint).
            layer = k.split(".")[2]
            to_q, to_k, to_v = v.chunk(3)
            # Only ".weight" targets are written: this assumes the fused projection has no bias.
            converted_state_dict[f"layers.{layer}.self_attn.to_q.weight"] = to_q
            converted_state_dict[f"layers.{layer}.self_attn.to_k.weight"] = to_k
            converted_state_dict[f"layers.{layer}.self_attn.to_v.weight"] = to_v
        elif "o_proj" in k:
            layer = k.split(".")[2]
            converted_state_dict[f"layers.{layer}.self_attn.to_out.0.weight"] = v
        else:
            # Remaining keys are expected to carry the "llm." prefix; drop those 4 characters.
            converted_state_dict[k[len("llm.") :]] = v
    return converted_state_dict


def main(args):
    """Convert the original OmniGen-v1 checkpoint into a diffusers ``OmniGenPipeline``.

    Checkpoint source: https://huggingface.co/Shitao/OmniGen-v1

    Args:
        args: Parsed CLI namespace with ``origin_ckpt_path`` (local directory or Hub repo id)
            and ``dump_path`` (output directory). If ``origin_ckpt_path`` does not exist locally
            it is treated as a repo id, downloaded with ``snapshot_download`` (into
            ``HF_HUB_CACHE`` when set), and ``args.origin_ckpt_path`` is overwritten with the
            local snapshot directory.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` if the converted keys do not match
            the transformer's parameters exactly.
    """
    if not os.path.exists(args.origin_ckpt_path):
        print("Model not found, downloading...")
        cache_folder = os.getenv("HF_HUB_CACHE")
        args.origin_ckpt_path = snapshot_download(
            repo_id=args.origin_ckpt_path,
            cache_dir=cache_folder,
            # Skip weights in other frameworks/formats; only the safetensors file is converted.
            ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
        )
        print(f"Downloaded model to {args.origin_ckpt_path}")

    ckpt_file = os.path.join(args.origin_ckpt_path, "model.safetensors")
    converted_state_dict = _convert_state_dict(load_file(ckpt_file, device="cpu"))

    transformer = OmniGenTransformer2DModel(
        rope_scaling={
            "long_factor": [
                1.0299999713897705,
                1.0499999523162842,
                1.0499999523162842,
                1.0799999237060547,
                1.2299998998641968,
                1.2299998998641968,
                1.2999999523162842,
                1.4499999284744263,
                1.5999999046325684,
                1.6499998569488525,
                1.8999998569488525,
                2.859999895095825,
                3.68999981880188,
                5.419999599456787,
                5.489999771118164,
                5.489999771118164,
                9.09000015258789,
                11.579999923706055,
                15.65999984741211,
                15.769999504089355,
                15.789999961853027,
                18.360000610351562,
                21.989999771118164,
                23.079999923706055,
                30.009998321533203,
                32.35000228881836,
                32.590003967285156,
                35.56000518798828,
                39.95000457763672,
                53.840003967285156,
                56.20000457763672,
                57.95000457763672,
                59.29000473022461,
                59.77000427246094,
                59.920005798339844,
                61.190006256103516,
                61.96000671386719,
                62.50000762939453,
                63.3700065612793,
                63.48000717163086,
                63.48000717163086,
                63.66000747680664,
                63.850006103515625,
                64.08000946044922,
                64.760009765625,
                64.80001068115234,
                64.81001281738281,
                64.81001281738281,
            ],
            "short_factor": [
                1.05,
                1.05,
                1.05,
                1.1,
                1.1,
                1.1,
                1.2500000000000002,
                1.2500000000000002,
                1.4000000000000004,
                1.4500000000000004,
                1.5500000000000005,
                1.8500000000000008,
                1.9000000000000008,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.1000000000000005,
                2.1000000000000005,
                2.2,
                2.3499999999999996,
                2.3499999999999996,
                2.3499999999999996,
                2.3499999999999996,
                2.3999999999999995,
                2.3999999999999995,
                2.6499999999999986,
                2.6999999999999984,
                2.8999999999999977,
                2.9499999999999975,
                3.049999999999997,
                3.049999999999997,
                3.049999999999997,
            ],
            "type": "su",
        },
        patch_size=2,
        in_channels=4,
        pos_embed_max_size=192,
    )
    # strict=True: any unmapped or missing key aborts the conversion instead of silently dropping weights.
    transformer.load_state_dict(converted_state_dict, strict=True)
    transformer.to(torch.bfloat16)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)

    vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)

    pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler)
    pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
    # CLI entry point: parse arguments and run the conversion.
    parser = argparse.ArgumentParser(
        description="Convert the original OmniGen-v1 checkpoint into a diffusers OmniGenPipeline."
    )

    # Accepts either a local directory or a Hub repo id; main() downloads the latter.
    parser.add_argument(
        "--origin_ckpt_path",
        default="Shitao/OmniGen-v1",
        type=str,
        help="Local directory or Hugging Face Hub repo id of the original OmniGen checkpoint to convert.",
    )

    parser.add_argument(
        "--dump_path",
        default="OmniGen-v1-diffusers",
        type=str,
        help="Directory where the converted diffusers pipeline is saved.",
    )

    args = parser.parse_args()
    main(args)