1use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2use crate::copy_in::CopyInReceiver;
3use crate::error::DbError;
4use crate::maybe_tls_stream::MaybeTlsStream;
5use crate::{AsyncMessage, Error, Notification};
6use bytes::BytesMut;
7use fallible_iterator::FallibleIterator;
8use futures_channel::mpsc;
9use futures_util::{ready, stream::FusedStream, Sink, Stream, StreamExt};
10use log::{info, trace};
11use postgres_protocol::message::backend::Message;
12use postgres_protocol::message::frontend;
13use std::collections::{HashMap, VecDeque};
14use std::future::Future;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio_util::codec::Framed;
19
20pub enum RequestMessages {
21 Single(FrontendMessage),
22 CopyIn(CopyInReceiver),
23}
24
25pub struct Request {
26 pub messages: RequestMessages,
27 pub sender: mpsc::Sender<BackendMessages>,
28}
29
30pub struct Response {
31 sender: mpsc::Sender<BackendMessages>,
32}
33
/// Shutdown progression of the connection driver.
///
/// The driver starts in `Active`. When the request channel is exhausted and no responses are
/// outstanding, it moves to `Terminating` while a Terminate message is queued. Once that message
/// has been handed to the socket, it moves to `Closing`, and the socket is then shut down.
///
/// The enum is fieldless, so it is `Copy` + `Eq` and can be compared and passed around freely.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum State {
    /// Reading responses and writing requests normally.
    Active,
    /// A Terminate message is being written; responses are no longer read.
    Terminating,
    /// Terminate has been queued; only socket shutdown remains.
    Closing,
}
40
41#[must_use = "futures do nothing unless polled"]
49pub struct Connection<S, T> {
50 stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
51 parameters: HashMap<String, String>,
52 receiver: mpsc::UnboundedReceiver<Request>,
53 pending_request: Option<RequestMessages>,
54 pending_responses: VecDeque<BackendMessage>,
55 responses: VecDeque<Response>,
56 state: State,
57}
58
59impl<S, T> Connection<S, T>
60where
61 S: AsyncRead + AsyncWrite + Unpin,
62 T: AsyncRead + AsyncWrite + Unpin,
63{
64 pub(crate) fn new(
65 stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
66 pending_responses: VecDeque<BackendMessage>,
67 parameters: HashMap<String, String>,
68 receiver: mpsc::UnboundedReceiver<Request>,
69 ) -> Connection<S, T> {
70 Connection {
71 stream,
72 parameters,
73 receiver,
74 pending_request: None,
75 pending_responses,
76 responses: VecDeque::new(),
77 state: State::Active,
78 }
79 }
80
81 fn poll_response(
82 &mut self,
83 cx: &mut Context<'_>,
84 ) -> Poll<Option<Result<BackendMessage, Error>>> {
85 if let Some(message) = self.pending_responses.pop_front() {
86 trace!("retrying pending response");
87 return Poll::Ready(Some(Ok(message)));
88 }
89
90 Pin::new(&mut self.stream)
91 .poll_next(cx)
92 .map(|o| o.map(|r| r.map_err(Error::io)))
93 }
94
95 fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
96 if self.state != State::Active {
97 trace!("poll_read: done");
98 return Ok(None);
99 }
100
101 loop {
102 let message = match self.poll_response(cx)? {
103 Poll::Ready(Some(message)) => message,
104 Poll::Ready(None) => return Err(Error::closed()),
105 Poll::Pending => {
106 trace!("poll_read: waiting on response");
107 return Ok(None);
108 }
109 };
110
111 let (mut messages, request_complete) = match message {
112 BackendMessage::Async(Message::NoticeResponse(body)) => {
113 let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
114 return Ok(Some(AsyncMessage::Notice(error)));
115 }
116 BackendMessage::Async(Message::NotificationResponse(body)) => {
117 let notification = Notification {
118 process_id: body.process_id(),
119 channel: body.channel().map_err(Error::parse)?.to_string(),
120 payload: body.message().map_err(Error::parse)?.to_string(),
121 };
122 return Ok(Some(AsyncMessage::Notification(notification)));
123 }
124 BackendMessage::Async(Message::ParameterStatus(body)) => {
125 self.parameters.insert(
126 body.name().map_err(Error::parse)?.to_string(),
127 body.value().map_err(Error::parse)?.to_string(),
128 );
129 continue;
130 }
131 BackendMessage::Async(_) => unreachable!(),
132 BackendMessage::Normal {
133 messages,
134 request_complete,
135 } => (messages, request_complete),
136 };
137
138 let mut response = match self.responses.pop_front() {
139 Some(response) => response,
140 None => match messages.next().map_err(Error::parse)? {
141 Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
142 _ => return Err(Error::unexpected_message()),
143 },
144 };
145
146 match response.sender.poll_ready(cx) {
147 Poll::Ready(Ok(())) => {
148 let _ = response.sender.start_send(messages);
149 if !request_complete {
150 self.responses.push_front(response);
151 }
152 }
153 Poll::Ready(Err(_)) => {
154 if !request_complete {
156 self.responses.push_front(response);
157 }
158 }
159 Poll::Pending => {
160 self.responses.push_front(response);
161 self.pending_responses.push_back(BackendMessage::Normal {
162 messages,
163 request_complete,
164 });
165 trace!("poll_read: waiting on sender");
166 return Ok(None);
167 }
168 }
169 }
170 }
171
172 fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
173 if let Some(messages) = self.pending_request.take() {
174 trace!("retrying pending request");
175 return Poll::Ready(Some(messages));
176 }
177
178 if self.receiver.is_terminated() {
179 return Poll::Ready(None);
180 }
181
182 match self.receiver.poll_next_unpin(cx) {
183 Poll::Ready(Some(request)) => {
184 trace!("polled new request");
185 self.responses.push_back(Response {
186 sender: request.sender,
187 });
188 Poll::Ready(Some(request.messages))
189 }
190 Poll::Ready(None) => Poll::Ready(None),
191 Poll::Pending => Poll::Pending,
192 }
193 }
194
195 fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
196 loop {
197 if self.state == State::Closing {
198 trace!("poll_write: done");
199 return Ok(false);
200 }
201
202 if Pin::new(&mut self.stream)
203 .poll_ready(cx)
204 .map_err(Error::io)?
205 .is_pending()
206 {
207 trace!("poll_write: waiting on socket");
208 return Ok(false);
209 }
210
211 let request = match self.poll_request(cx) {
212 Poll::Ready(Some(request)) => request,
213 Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
214 trace!("poll_write: at eof, terminating");
215 self.state = State::Terminating;
216 let mut request = BytesMut::new();
217 frontend::terminate(&mut request);
218 RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
219 }
220 Poll::Ready(None) => {
221 trace!(
222 "poll_write: at eof, pending responses {}",
223 self.responses.len()
224 );
225 return Ok(true);
226 }
227 Poll::Pending => {
228 trace!("poll_write: waiting on request");
229 return Ok(true);
230 }
231 };
232
233 match request {
234 RequestMessages::Single(request) => {
235 Pin::new(&mut self.stream)
236 .start_send(request)
237 .map_err(Error::io)?;
238 if self.state == State::Terminating {
239 trace!("poll_write: sent eof, closing");
240 self.state = State::Closing;
241 }
242 }
243 RequestMessages::CopyIn(mut receiver) => {
244 let message = match receiver.poll_next_unpin(cx) {
245 Poll::Ready(Some(message)) => message,
246 Poll::Ready(None) => {
247 trace!("poll_write: finished copy_in request");
248 continue;
249 }
250 Poll::Pending => {
251 trace!("poll_write: waiting on copy_in stream");
252 self.pending_request = Some(RequestMessages::CopyIn(receiver));
253 return Ok(true);
254 }
255 };
256 Pin::new(&mut self.stream)
257 .start_send(message)
258 .map_err(Error::io)?;
259 self.pending_request = Some(RequestMessages::CopyIn(receiver));
260 }
261 }
262 }
263 }
264
265 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
266 match Pin::new(&mut self.stream)
267 .poll_flush(cx)
268 .map_err(Error::io)?
269 {
270 Poll::Ready(()) => trace!("poll_flush: flushed"),
271 Poll::Pending => trace!("poll_flush: waiting on socket"),
272 }
273 Ok(())
274 }
275
276 fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
277 if self.state != State::Closing {
278 return Poll::Pending;
279 }
280
281 match Pin::new(&mut self.stream)
282 .poll_close(cx)
283 .map_err(Error::io)?
284 {
285 Poll::Ready(()) => {
286 trace!("poll_shutdown: complete");
287 Poll::Ready(Ok(()))
288 }
289 Poll::Pending => {
290 trace!("poll_shutdown: waiting on socket");
291 Poll::Pending
292 }
293 }
294 }
295
296 pub fn parameter(&self, name: &str) -> Option<&str> {
298 self.parameters.get(name).map(|s| &**s)
299 }
300
301 pub fn poll_message(
309 &mut self,
310 cx: &mut Context<'_>,
311 ) -> Poll<Option<Result<AsyncMessage, Error>>> {
312 let message = self.poll_read(cx)?;
313 let want_flush = self.poll_write(cx)?;
314 if want_flush {
315 self.poll_flush(cx)?;
316 }
317 match message {
318 Some(message) => Poll::Ready(Some(Ok(message))),
319 None => match self.poll_shutdown(cx) {
320 Poll::Ready(Ok(())) => Poll::Ready(None),
321 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
322 Poll::Pending => Poll::Pending,
323 },
324 }
325 }
326}
327
328impl<S, T> Future for Connection<S, T>
329where
330 S: AsyncRead + AsyncWrite + Unpin,
331 T: AsyncRead + AsyncWrite + Unpin,
332{
333 type Output = Result<(), Error>;
334
335 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
336 while let Some(message) = ready!(self.poll_message(cx)?) {
337 if let AsyncMessage::Notice(notice) = message {
338 info!("{}: {}", notice.severity(), notice.message());
339 }
340 }
341 Poll::Ready(Ok(()))
342 }
343}