import dataclasses
import typing
from collections.abc import Iterable

import inflection

from misc.codegen.loaders import schemaloader
from . import qlgen


@dataclasses.dataclass()
class Param:
    """A parameter description made of a name and a type string.

    NOTE(review): no use of `Param` is visible in this module; `first`
    presumably marks the leading parameter of a list for template
    rendering — confirm against the templates/callers.
    """

    name: str  # parameter name
    type: str  # parameter type, kept as plain text
    first: bool = dataclasses.field(default=False)  # off unless set explicitly


@dataclasses.dataclass()
class Function:
    """Describes the Rust function that wraps a generated doc test.

    `name` is the function identifier (`generate` uses ``test_<snake_case_class>``)
    and `signature` is a signature string such as ``() -> ()``, the default
    `generate` falls back to.
    """

    name: str  # Rust function identifier
    signature: str  # signature text, e.g. "() -> ()"


@dataclasses.dataclass()
class TestCode:
    """Render payload for one generated Rust doc-test file.

    `code` is the already-joined test source. `function`, when set, is the
    wrapping Rust function; `None` means no wrapper.
    """

    # Class-level constant, not a dataclass field; presumably the template
    # name the renderer looks up for this payload.
    template: typing.ClassVar[str] = "rust_test_code"

    code: str  # test body, lines joined with "\n"
    function: Function | None = dataclasses.field(default=None)


def _get_code(doc: list[str]) -> list[str]:
    adding_code = False
    has_code = False
    code = []
    for line in doc:
        match line, adding_code:
            case ("```", _) | ("```rust", _):
                adding_code = not adding_code
                has_code = True
            case _, False:
                code.append(f"// {line}")
            case _, True:
                code.append(line)
    assert not adding_code, "Unterminated code block in docstring:\n  " + "\n  ".join(doc)
    if has_code:
        return code
    return []


def _collect_code(schema, cls) -> list[str]:
    """Collect the doc-test lines for `cls` from its docstring and its properties.

    The converted class docstring comes first. It is followed by each property
    whose description yields code and which lacks the `rust_skip_doc_test`
    pragma. Each property's code is preceded by a ``// # <property name>``
    marker comment. Returns an empty list when nothing contributed code.
    """
    code = _get_code(cls.doc)
    for prop in schema.iter_properties(cls.name):
        if "rust_skip_doc_test" in prop.pragmas:
            continue
        property_code = _get_code(prop.description)
        if property_code:
            code.append(f"// # {prop.name}")
            code += property_code
    return code


def generate(opts, renderer):
    """Generate one Rust doc-test file per schema class from code in its docs.

    Classes are skipped in these cases:
    - they are imported;
    - `qlgen.should_skip_qltest` rejects them;
    - they carry the `rust_skip_doc_test` pragma;
    - they yield no code.

    Every other class has its code rendered via `TestCode` into
    ``<ql_test_output>/<group>/<class>/gen_<snake_case_name>.rs``. The group and
    class directory come from the class named by the `qltest_test_with` pragma
    when present, otherwise from the class itself.

    :param opts: options providing `schema` (file loaded via
        `schemaloader.load_file`), `ql_test_output` (output root, must be set)
        and `force` (forwarded to `renderer.manage`).
    :param renderer: object whose `manage(...)` context is given the existing
        ``gen_*.rs`` files and the ``.generated_tests.list`` registry path.
    """
    assert opts.ql_test_output
    schema = schemaloader.load_file(opts.schema)
    with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"),
                         stubs=(),
                         registry=opts.ql_test_output / ".generated_tests.list",
                         force=opts.force) as renderer:
        for cls in schema.classes.values():
            if cls.imported:
                continue
            if (qlgen.should_skip_qltest(cls, schema.classes) or
                    "rust_skip_doc_test" in cls.pragmas):
                continue
            code = _collect_code(schema, cls)
            if not code:
                continue
            test_name = inflection.underscore(cls.name)
            signature = cls.pragmas.get("rust_doc_test_signature", "() -> ()")
            # A falsy signature pragma means "no wrapping function". Normalize it
            # to None so TestCode.function always matches `Function | None`.
            fn = Function(f"test_{test_name}", signature) if signature else None
            if fn is not None:
                # The code goes inside the function body, so indent it one level.
                indent = 4 * " "
                code = [indent + line for line in code]
            test_with_name = typing.cast(str | None, cls.pragmas.get("qltest_test_with"))
            test_with = schema.classes[test_with_name] if test_with_name else cls
            test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
            renderer.render(TestCode(code="\n".join(code), function=fn), test)