Coverage for flogin/jsonrpc/client.py: 28%

119 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-03 22:51 +0000

1from __future__ import annotations 

2 

3import asyncio 

4import json 

5import logging 

6from typing import TYPE_CHECKING, Any 

7 

8from .errors import ( 

9 InternalError, 

10 MethodNotFound, 

11) 

12from .errors import ( 

13 get_exception_from_json as _get_jsonrpc_error_from_json, 

14) 

15from .requests import Request 

16from .responses import BaseResponse, ErrorResponse 

17 

18log = logging.getLogger(__name__) 

19 

20if TYPE_CHECKING: 

21 from asyncio.streams import StreamReader, StreamWriter 

22 

23 from .._types.jsonrpc.request import ( 

24 Request as RequestPayload, 

25 ) 

26 from .._types.jsonrpc.request import ( 

27 RequestResult as RequestResultPayload, 

28 ) 

29 from ..plugin import Plugin 

30 

31__all__ = ("JsonRPCClient",) 

32 

33 

class JsonRPCClient:
    """JSON-RPC endpoint that talks to the host over a pair of streams.

    Incoming lines are decoded as JSON and routed to notification, request,
    result or error handlers. Incoming requests are dispatched to ``plugin``.
    Requests the plugin sends itself (see :meth:`request`) are tracked by id
    until a matching result or error arrives.
    """

    # Both are assigned by start_listening().
    reader: StreamReader
    writer: StreamWriter

    def __init__(self, plugin: Plugin[Any]) -> None:
        # Tasks running incoming requests, keyed by request id; used so a
        # "$/cancelRequest" notification can cancel them.
        self.tasks: dict[int, asyncio.Task[BaseResponse[Any]]] = {}
        # Futures for our own outgoing requests, keyed by request id.
        self.requests: dict[int, asyncio.Future[Any | ErrorResponse]] = {}
        self._current_request_id = 1
        self.plugin = plugin
        self.ignore_cancellations: bool = plugin.options.get(
            "ignore_cancellation_requests", False
        )

    @property
    def request_id(self) -> int:
        """Return the next outgoing request id.

        Reading this property advances the counter, so every access yields a
        new id.
        """
        self._current_request_id += 1
        return self._current_request_id

    @request_id.setter
    def request_id(self, value: int) -> None:
        """Reset the counter; the next read of ``request_id`` returns ``value + 1``."""
        self._current_request_id = value

    async def request(self, method: str, params: list[Any] | None = None) -> Any:
        """Send a request and wait for the other side to answer it.

        Args:
            method: JSON-RPC method name.
            params: Positional parameters; ``None`` means an empty list.

        Returns:
            Whatever :meth:`handle_result` resolves the future with (the
            received result message).

        Raises:
            The exception built by :meth:`handle_error` if an error response
            arrives for this request.
        """
        if params is None:
            params = []

        fut: asyncio.Future[Any] = asyncio.get_running_loop().create_future()
        rid = self.request_id
        msg = Request(method, rid, params).to_message(rid)
        self.requests[rid] = fut
        try:
            await self.write(msg, drain=False)
            return await fut
        finally:
            # Drop the entry if the write failed or the caller was cancelled.
            # The identity check avoids evicting a newer request that reused
            # the same id after the handlers already popped ours.
            if self.requests.get(rid) is fut:
                del self.requests[rid]

    async def handle_cancellation(self, id: int) -> None:
        """Cancel the in-flight task for incoming request ``id``.

        Does nothing but log when ``ignore_cancellation_requests`` was set in
        the plugin options.
        """
        if self.ignore_cancellations:
            log.debug("Ignoring cancellation request of %r", id)
            return

        task = self.tasks.pop(id, None)
        if task is None:
            # log.error, not log.exception: there is no active exception here.
            log.error("Failed to cancel task with id of %r, could not find task.", id)
            return

        if task.cancel():
            log.debug("Successfully cancelled task with id %r", id)
        else:
            log.error("Failed to cancel task with id of %r, task=%r", id, task)

    async def handle_result(self, result: RequestResultPayload[Any]) -> None:
        """Resolve the pending future of our request with the received message."""
        rid = result["id"]

        log.debug("Result: %r, %r", rid, result)
        fut = self.requests.pop(rid, None)
        if fut is None:
            log.error(
                "Result from unknown request given. id=%r, result=%r", rid, result
            )
            return
        # The awaiting caller may already have been cancelled.
        if not fut.done():
            fut.set_result(result)

    async def handle_error(self, id: int | None, error: ErrorResponse) -> None:
        """Fail the pending future of request ``id`` with the mapped exception."""
        fut = self.requests.pop(id, None)
        if fut is None:
            log.error("Error response received for unknown request, id=%r", id)
            return
        # Mirror handle_result: a cancelled future cannot take an exception.
        if not fut.done():
            fut.set_exception(_get_jsonrpc_error_from_json(error.to_dict()["error"]))

    async def handle_notification(self, method: str, params: dict[str, Any]) -> None:
        """Handle a message that expects no response.

        Only ``$/cancelRequest`` is understood; anything else is logged.
        """
        if method == "$/cancelRequest":
            await self.handle_cancellation(params["id"])
        else:
            err = MethodNotFound(f"Notification Method {method!r} Not Found")

            log.exception(
                "Unknown notification method received: %r",
                method,
                exc_info=err,
            )

    async def handle_request(self, request: RequestPayload) -> None:
        """Run an incoming request through the plugin and write the response.

        ``flogin.action*`` methods are offered to ``plugin.process_action``
        first; otherwise (or if that returns ``None``) ``plugin.dispatch`` is
        used. An unknown method or a non-``BaseResponse`` result is answered
        with a JSON-RPC error response.
        """
        method: str = request["method"]
        params: list[Any] = request.get("params", [])
        rid = request["id"]

        # NOTE(review): outgoing request ids are re-based on the incoming id.
        # Presumably intentional; confirm this cannot collide with ids still
        # pending in self.requests.
        self.request_id = rid

        task = None
        if method.startswith("flogin.action"):
            task = self.plugin.process_action(method)

        if task is None:
            task = self.plugin.dispatch(method, *params)
            if not task:
                err = MethodNotFound(f"Request method {method!r} was not found")
                log.exception(
                    "Unknown request method received: %r",
                    method,
                    exc_info=err,
                )
                await self.write(err.to_response().to_message(id=rid))
                return

        self.tasks[rid] = task
        try:
            result = await task
        finally:
            # Forget the task once it settles so self.tasks does not grow
            # without bound. handle_cancellation may already have removed it,
            # and a newer request may have reused the id.
            if self.tasks.get(rid) is task:
                del self.tasks[rid]

        if isinstance(result, BaseResponse):
            await self.write(result.to_message(id=rid))
            return

        err = InternalError("Internal Error: Invalid Response Object", repr(result))
        log.exception(
            "Invalid Response Object: %r",
            result,
            exc_info=err,
        )
        await self.write(err.to_response().to_message(id=rid))

    def _log_unknown_message(self, line: str) -> None:
        """Log a message that matches none of the shapes this client handles."""
        err = InternalError("Unknown message type received", line)
        log.exception(
            "Unknown message type received",
            exc_info=err,
        )

    async def process_input(self, line: str) -> None:
        """Decode one JSON line and route it to the matching handler.

        Routing:
          * ``method`` without ``id`` -> :meth:`handle_notification`
          * ``method`` with ``id``    -> :meth:`handle_request`
          * ``id`` and ``result``     -> :meth:`handle_result`
          * ``error``                 -> :meth:`handle_error`

        Undecodable or unrecognised messages are logged and dropped instead
        of crashing the processing task.
        """
        try:
            message = json.loads(line)
        except json.JSONDecodeError:
            log.exception("Received invalid JSON: %r", line)
            return

        if not isinstance(message, dict):
            self._log_unknown_message(line)
            return

        if "method" in message:
            if "id" in message:
                log.debug("Received request: %r", message)
                await self.handle_request(message)
            else:
                log.debug("Received notification: %r", message)
                # "params" may be omitted in JSON-RPC.
                await self.handle_notification(
                    message["method"], message.get("params", {})
                )
        elif "id" in message and "result" in message:
            log.debug("Received result: %r", message)
            await self.handle_result(message)
        elif "error" in message:
            log.error("Received error: %r", message)
            await self.handle_error(
                message.get("id"), ErrorResponse.from_dict(message["error"])
            )
        else:
            self._log_unknown_message(line)

    async def start_listening(self, reader: StreamReader, writer: StreamWriter) -> None:
        """Read lines from ``reader`` until EOF, processing each in its own task.

        Returns once the reader reaches EOF. Tasks already started keep
        running on the event loop.
        """
        self.reader = reader
        self.writer = writer

        stream_log = logging.getLogger("flogin.stream_reader")
        # Strong references so in-flight tasks are not garbage-collected.
        tasks: set[asyncio.Task[None]] = set()

        # `async for` over a StreamReader stops only at EOF, and EOF is
        # permanent. Looping again (the former `while 1:`) re-entered a
        # finished iterator without ever suspending, which blocked the
        # event loop forever.
        async for raw in reader:
            stream_log.debug("Received line: %r", raw)
            line = raw.decode("utf-8")
            # Lines keep their terminator, so a blank line arrives as "\n".
            if not line.strip():
                continue

            task = asyncio.create_task(self.process_input(line))
            tasks.add(task)
            task.add_done_callback(tasks.discard)

    async def write(self, msg: bytes, drain: bool = True) -> None:
        """Write an already-encoded message to the output stream.

        When ``drain`` is true, also wait for the writer's buffer to flush.
        """
        log.debug("Sending: %r", msg)
        self.writer.write(msg)
        if drain:
            await self.writer.drain()