Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: user-specified validation rule skips #340

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ariadne_codegen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
get_graphql_schema_from_path,
get_graphql_schema_from_url,
)
from .settings import Strategy
from .settings import Strategy, get_validation_rule


@click.command() # type: ignore
Expand Down Expand Up @@ -64,7 +64,7 @@ def client(config_dict):
fragments = []
queries = []
if settings.queries_path:
definitions = get_graphql_queries(settings.queries_path, schema)
definitions = get_graphql_queries(settings.queries_path, schema, [get_validation_rule(e) for e in settings.skip_validation_rules])
queries = filter_operations_definitions(definitions)
fragments = filter_fragments_definitions(definitions)

Expand Down
6 changes: 4 additions & 2 deletions ariadne_codegen/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple, cast
from typing_extensions import Any, Sequence

import httpx
from graphql import (
Expand All @@ -14,6 +15,7 @@
IntrospectionQuery,
NoUnusedFragmentsRule,
OperationDefinitionNode,
UniqueFragmentNamesRule,
build_ast_schema,
build_client_schema,
get_introspection_query,
Expand Down Expand Up @@ -45,15 +47,15 @@ def filter_fragments_definitions(


def get_graphql_queries(
queries_path: str, schema: GraphQLSchema
queries_path: str, schema: GraphQLSchema, skip_rules: Sequence[Any] = (NoUnusedFragmentsRule,)
) -> Tuple[DefinitionNode, ...]:
"""Get graphql queries definitions build from provided path."""
queries_str = load_graphql_files_from_path(Path(queries_path))
queries_ast = parse(queries_str)
validation_errors = validate(
schema=schema,
document_ast=queries_ast,
rules=[r for r in specified_rules if r is not NoUnusedFragmentsRule],
rules=[r for r in specified_rules if r not in skip_rules],
)
if validation_errors:
raise InvalidOperationForSchema(
Expand Down
15 changes: 15 additions & 0 deletions ariadne_codegen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from textwrap import dedent
from typing import Dict, List

from graphql.validation import UniqueFragmentNamesRule, NoUnusedFragmentsRule

from .client_generators.constants import (
DEFAULT_ASYNC_BASE_CLIENT_NAME,
DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_NAME,
Expand All @@ -25,6 +27,18 @@ class CommentsStrategy(str, enum.Enum):
STABLE = "stable"
TIMESTAMP = "timestamp"

class ValidationRuleSkips(str, enum.Enum):
UniqueFragmentNames = "UniqueFragmentNames"
NoUnusedFragments = "NoUnusedFragments"

def get_validation_rule(rule: ValidationRuleSkips):
if rule == ValidationRuleSkips.UniqueFragmentNames:
return UniqueFragmentNamesRule
elif rule == ValidationRuleSkips.NoUnusedFragments:
return NoUnusedFragmentsRule
else:
raise ValueError(f"Unknown validation rule: {rule}")


class Strategy(str, enum.Enum):
CLIENT = "client"
Expand Down Expand Up @@ -70,6 +84,7 @@ class ClientSettings(BaseSettings):
include_all_enums: bool = True
async_client: bool = True
opentelemetry_client: bool = False
skip_validation_rules: List[ValidationRuleSkips] = field(default_factory=lambda: [ValidationRuleSkips.UniqueFragmentNames,])
files_to_include: List[str] = field(default_factory=list)
scalars: Dict[str, ScalarData] = field(default_factory=dict)

Expand Down
96 changes: 96 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ariadne_codegen.settings import get_validation_rule
import httpx
import pytest
from graphql import GraphQLSchema, OperationDefinitionNode, build_schema
Expand All @@ -15,6 +16,7 @@
read_graphql_file,
walk_graphql_files,
)
from ariadne_codegen.settings import ValidationRuleSkips


@pytest.fixture
Expand Down Expand Up @@ -63,6 +65,49 @@ def test_query_2_str():
}
"""

@pytest.fixture
def test_fragment_str():
return """
fragment fragmentA on Custom {
node
}
query testQuery2 {
test {
default
...fragmentA
}
}
"""

@pytest.fixture
def test_duplicate_fragment_str():
return """
fragment fragmentA on Custom {
node
}
fragment fragmentA on Custom {
node
}
query testQuery2 {
test {
default
...fragmentA
}
}
"""

@pytest.fixture
def test_unused_fragment_str():
return """
fragment fragmentA on Custom {
node
}
query testQuery2 {
test {
default
}
}
"""

@pytest.fixture
def single_file_schema(tmp_path_factory, schema_str):
Expand Down Expand Up @@ -132,6 +177,24 @@ def single_file_query(tmp_path_factory, test_query_str):
file_.write_text(test_query_str, encoding="utf-8")
return file_

@pytest.fixture
def single_file_query_with_fragment(tmp_path_factory, test_query_str, test_fragment_str):
file_ = tmp_path_factory.mktemp("queries").joinpath("query1_fragment.graphql")
file_.write_text(test_query_str + test_fragment_str, encoding="utf-8")
return file_

@pytest.fixture
def single_file_query_with_duplicate_fragment(tmp_path_factory, test_query_str, test_duplicate_fragment_str):
file_ = tmp_path_factory.mktemp("queries").joinpath("query1_duplicate_fragment.graphql")
file_.write_text(test_query_str + test_duplicate_fragment_str, encoding="utf-8")
return file_

@pytest.fixture
def single_file_query_with_unused_fragment(tmp_path_factory, test_query_str, test_unused_fragment_str):
file_ = tmp_path_factory.mktemp("queries").joinpath("query1_unused_fragment.graphql")
file_.write_text(test_query_str + test_unused_fragment_str, encoding="utf-8")
return file_


@pytest.fixture
def invalid_syntax_query_file(tmp_path_factory):
Expand Down Expand Up @@ -434,3 +497,36 @@ def test_get_graphql_queries_with_invalid_query_for_schema_raises_invalid_operat
get_graphql_queries(
invalid_query_for_schema_file.as_posix(), build_schema(schema_str)
)


def test_get_graphql_queries_with_fragment_returns_schema_definitions(
single_file_query_with_fragment, schema_str
):
queries = get_graphql_queries(
single_file_query_with_fragment.as_posix(), build_schema(schema_str)
)

assert len(queries) == 3

def test_get_graphql_queries_with_duplicate_fragment_raises_invalid_operation(
single_file_query_with_duplicate_fragment, schema_str
):
with pytest.raises(InvalidOperationForSchema):
get_graphql_queries(
single_file_query_with_duplicate_fragment.as_posix(), build_schema(schema_str)
)

def test_get_graphql_queries_with_unused_fragment_and_no_skip_rules_raises_invalid_operation(
single_file_query_with_unused_fragment, schema_str
):
with pytest.raises(InvalidOperationForSchema):
get_graphql_queries(
single_file_query_with_unused_fragment.as_posix(), build_schema(schema_str), []
)

def test_get_graphql_queries_with_skip_unique_fragment_names_and_duplicate_fragment_returns_schema_definition(
single_file_query_with_duplicate_fragment, schema_str
):
get_graphql_queries(
single_file_query_with_duplicate_fragment.as_posix(), build_schema(schema_str), [get_validation_rule(ValidationRuleSkips.NoUnusedFragments),get_validation_rule(ValidationRuleSkips.UniqueFragmentNames)]
)