Skip to content

Commit

Permalink
Merge pull request #245 from mirumee/operation_name
Browse files Browse the repository at this point in the history
Add operationName
  • Loading branch information
mat-sop authored Dec 1, 2023
2 parents 46bd337 + 108c0cd commit 4cb4af7
Show file tree
Hide file tree
Showing 41 changed files with 1,157 additions and 269 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- Added `NoReimportsPlugin` that makes the `__init__.py` of generated client package empty.
- Added `include_all_inputs` config flag to generate only inputs used in supplied operations.
- Added `include_all_enums` config flag to generate only enums used in supplied operations.
- Added `operationName` to payload sent by generated client's methods.


## 0.10.0 (2023-11-15)
Expand Down
23 changes: 18 additions & 5 deletions EXAMPLE.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ class Client(AsyncBaseClient):
"""
)
variables: Dict[str, object] = {"userData": user_data}
response = await self.execute(query=query, variables=variables, **kwargs)
response = await self.execute(
query=query, operation_name="CreateUser", variables=variables, **kwargs
)
data = self.get_data(response)
return CreateUser.model_validate(data)

Expand All @@ -231,7 +233,9 @@ class Client(AsyncBaseClient):
"""
)
variables: Dict[str, object] = {}
response = await self.execute(query=query, variables=variables, **kwargs)
response = await self.execute(
query=query, operation_name="ListAllUsers", variables=variables, **kwargs
)
data = self.get_data(response)
return ListAllUsers.model_validate(data)

Expand Down Expand Up @@ -260,7 +264,12 @@ class Client(AsyncBaseClient):
"""
)
variables: Dict[str, object] = {"country": country}
response = await self.execute(query=query, variables=variables, **kwargs)
response = await self.execute(
query=query,
operation_name="ListUsersByCountry",
variables=variables,
**kwargs
)
data = self.get_data(response)
return ListUsersByCountry.model_validate(data)

Expand All @@ -273,7 +282,9 @@ class Client(AsyncBaseClient):
"""
)
variables: Dict[str, object] = {}
async for data in self.execute_ws(query=query, variables=variables, **kwargs):
async for data in self.execute_ws(
query=query, operation_name="GetUsersCounter", variables=variables, **kwargs
):
yield GetUsersCounter.model_validate(data)

async def upload_file(self, file: Upload, **kwargs: Any) -> UploadFile:
Expand All @@ -285,7 +296,9 @@ class Client(AsyncBaseClient):
"""
)
variables: Dict[str, object] = {"file": file}
response = await self.execute(query=query, variables=variables, **kwargs)
response = await self.execute(
query=query, operation_name="uploadFile", variables=variables, **kwargs
)
data = self.get_data(response)
return UploadFile.model_validate(data)
```
Expand Down
37 changes: 28 additions & 9 deletions ariadne_codegen/client_generators/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def add_method(
arguments, arguments_dict = self.arguments_generator.generate(
definition.variable_definitions
)
operation_name = definition.name.value if definition.name else ""
if definition.operation == OperationType.SUBSCRIPTION:
if not async_:
raise NotSupported(
Expand All @@ -144,6 +145,7 @@ def add_method(
ast.FunctionDef, ast.AsyncFunctionDef
] = self._generate_subscription_method_def(
name=name,
operation_name=operation_name,
return_type=return_type,
arguments=arguments,
arguments_dict=arguments_dict,
Expand All @@ -156,6 +158,7 @@ def add_method(
arguments=arguments,
arguments_dict=arguments_dict,
operation_str=operation_str,
operation_name=operation_name,
)
else:
method_def = self._generate_method(
Expand All @@ -164,6 +167,7 @@ def add_method(
arguments=arguments,
arguments_dict=arguments_dict,
operation_str=operation_str,
operation_name=operation_name,
)

method_def.lineno = len(self._class_def.body) + 1
Expand All @@ -188,6 +192,7 @@ def _add_import(self, import_: Optional[ast.ImportFrom] = None):
def _generate_subscription_method_def(
self,
name: str,
operation_name: str,
return_type: str,
arguments: ast.arguments,
arguments_dict: ast.Dict,
Expand All @@ -202,7 +207,7 @@ def _generate_subscription_method_def(
body=[
self._generate_operation_str_assign(operation_str, 1),
self._generate_variables_assign(arguments_dict, 2),
self._generate_async_generator_loop(return_type, 3),
self._generate_async_generator_loop(operation_name, return_type, 3),
],
)

Expand All @@ -213,6 +218,7 @@ def _generate_async_method(
arguments: ast.arguments,
arguments_dict: ast.Dict,
operation_str: str,
operation_name: str,
) -> ast.AsyncFunctionDef:
return generate_async_method_definition(
name=name,
Expand All @@ -221,7 +227,7 @@ def _generate_async_method(
body=[
self._generate_operation_str_assign(operation_str, 1),
self._generate_variables_assign(arguments_dict, 2),
self._generate_async_response_assign(3),
self._generate_async_response_assign(operation_name, 3),
self._generate_data_retrieval(),
self._generate_return_parsed_obj(return_type),
],
Expand All @@ -234,6 +240,7 @@ def _generate_method(
arguments: ast.arguments,
arguments_dict: ast.Dict,
operation_str: str,
operation_name: str,
) -> ast.FunctionDef:
return generate_method_definition(
name=name,
Expand All @@ -242,7 +249,7 @@ def _generate_method(
body=[
self._generate_operation_str_assign(operation_str, 1),
self._generate_variables_assign(arguments_dict, 2),
self._generate_response_assign(3),
self._generate_response_assign(operation_name, 3),
self._generate_data_retrieval(),
self._generate_return_parsed_obj(return_type),
],
Expand Down Expand Up @@ -275,27 +282,36 @@ def _generate_variables_assign(
lineno=lineno,
)

def _generate_async_response_assign(self, lineno: int = 1) -> ast.Assign:
def _generate_async_response_assign(
self, operation_name: str, lineno: int = 1
) -> ast.Assign:
return generate_assign(
targets=[self._response_variable],
value=generate_await(self._generate_execute_call()),
value=generate_await(
self._generate_execute_call(operation_name=operation_name)
),
lineno=lineno,
)

def _generate_response_assign(self, lineno: int = 1) -> ast.Assign:
def _generate_response_assign(
self, operation_name: str, lineno: int = 1
) -> ast.Assign:
return generate_assign(
targets=[self._response_variable],
value=self._generate_execute_call(),
value=self._generate_execute_call(operation_name=operation_name),
lineno=lineno,
)

def _generate_execute_call(self) -> ast.Call:
def _generate_execute_call(self, operation_name: str) -> ast.Call:
return generate_call(
func=generate_attribute(generate_name("self"), "execute"),
keywords=[
generate_keyword(
value=generate_name(self._operation_str_variable), arg="query"
),
generate_keyword(
value=generate_constant(operation_name), arg="operation_name"
),
generate_keyword(
value=generate_name(self._variables_dict_variable), arg="variables"
),
Expand Down Expand Up @@ -323,7 +339,7 @@ def _generate_return_parsed_obj(self, return_type: str) -> ast.Return:
)

def _generate_async_generator_loop(
self, return_type: str, lineno: int = 1
self, operation_name: str, return_type: str, lineno: int = 1
) -> ast.AsyncFor:
return generate_async_for(
target=generate_name(self._data_variable),
Expand All @@ -333,6 +349,9 @@ def _generate_async_generator_loop(
generate_keyword(
value=generate_name(self._operation_str_variable), arg="query"
),
generate_keyword(
value=generate_constant(operation_name), arg="operation_name"
),
generate_keyword(
value=generate_name(self._variables_dict_variable),
arg="variables",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,29 @@ async def __aexit__(
await self.http_client.aclose()

async def execute(
self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any
self,
query: str,
operation_name: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> httpx.Response:
processed_variables, files, files_map = self._process_variables(variables)

if files and files_map:
return await self._execute_multipart(
query=query,
operation_name=operation_name,
variables=processed_variables,
files=files,
files_map=files_map,
**kwargs,
)

return await self._execute_json(
query=query, variables=processed_variables, **kwargs
query=query,
operation_name=operation_name,
variables=processed_variables,
**kwargs,
)

def get_data(self, response: httpx.Response) -> Dict[str, Any]:
Expand Down Expand Up @@ -126,7 +134,11 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]:
return cast(Dict[str, Any], data)

async def execute_ws(
self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any
self,
query: str,
operation_name: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
headers = self.ws_headers.copy()
headers.update(kwargs.get("extra_headers", {}))
Expand All @@ -146,6 +158,7 @@ async def execute_ws(
websocket,
operation_id=operation_id,
query=query,
operation_name=operation_name,
variables=variables,
)

Expand Down Expand Up @@ -226,14 +239,20 @@ def separate_files(path: str, obj: Any) -> Any:
async def _execute_multipart(
self,
query: str,
operation_name: Optional[str],
variables: Dict[str, Any],
files: Dict[str, Tuple[str, IO[bytes], str]],
files_map: Dict[str, List[str]],
**kwargs: Any,
) -> httpx.Response:
data = {
"operations": json.dumps(
{"query": query, "variables": variables}, default=to_jsonable_python
{
"query": query,
"operationName": operation_name,
"variables": variables,
},
default=to_jsonable_python,
),
"map": json.dumps(files_map, default=to_jsonable_python),
}
Expand All @@ -243,7 +262,11 @@ async def _execute_multipart(
)

async def _execute_json(
self, query: str, variables: Dict[str, Any], **kwargs: Any
self,
query: str,
operation_name: Optional[str],
variables: Dict[str, Any],
**kwargs: Any,
) -> httpx.Response:
headers: Dict[str, str] = {"Content-Type": "application/json"}
headers.update(kwargs.get("headers", {}))
Expand All @@ -254,7 +277,12 @@ async def _execute_json(
return await self.http_client.post(
url=self.url,
content=json.dumps(
{"query": query, "variables": variables}, default=to_jsonable_python
{
"query": query,
"operationName": operation_name,
"variables": variables,
},
default=to_jsonable_python,
),
**merged_kwargs,
)
Expand All @@ -272,12 +300,13 @@ async def _send_subscribe(
websocket: WebSocketClientProtocol,
operation_id: str,
query: str,
operation_name: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
) -> None:
payload: Dict[str, Any] = {
"id": operation_id,
"type": GraphQLTransportWSMessageType.SUBSCRIBE.value,
"payload": {"query": query},
"payload": {"query": query, "operationName": operation_name},
}
if variables:
payload["payload"]["variables"] = self._convert_dict_to_json_serializable(
Expand Down
Loading

0 comments on commit 4cb4af7

Please sign in to comment.