// tokio_postgres/connection.rs

1use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2use crate::copy_in::CopyInReceiver;
3use crate::error::DbError;
4use crate::maybe_tls_stream::MaybeTlsStream;
5use crate::{AsyncMessage, Error, Notification};
6use bytes::BytesMut;
7use fallible_iterator::FallibleIterator;
8use futures_channel::mpsc;
9use futures_util::{ready, stream::FusedStream, Sink, Stream, StreamExt};
10use log::{info, trace};
11use postgres_protocol::message::backend::Message;
12use postgres_protocol::message::frontend;
13use std::collections::{HashMap, VecDeque};
14use std::future::Future;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio_util::codec::Framed;
19
20pub enum RequestMessages {
21    Single(FrontendMessage),
22    CopyIn(CopyInReceiver),
23}
24
25pub struct Request {
26    pub messages: RequestMessages,
27    pub sender: mpsc::Sender<BackendMessages>,
28}
29
30pub struct Response {
31    sender: mpsc::Sender<BackendMessages>,
32}
33
/// Lifecycle state of a `Connection`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Reading responses and writing requests normally.
    Active,
    /// The request channel is exhausted and no responses are outstanding; the terminate message is being written.
    Terminating,
    /// The terminate message has been queued; the stream is being shut down.
    Closing,
}
40
41/// A connection to a PostgreSQL database.
42///
43/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
44/// server, and should generally be spawned off onto an executor to run in the background.
45///
46/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
47/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
48#[must_use = "futures do nothing unless polled"]
49pub struct Connection<S, T> {
50    stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
51    parameters: HashMap<String, String>,
52    receiver: mpsc::UnboundedReceiver<Request>,
53    pending_request: Option<RequestMessages>,
54    pending_responses: VecDeque<BackendMessage>,
55    responses: VecDeque<Response>,
56    state: State,
57}
58
59impl<S, T> Connection<S, T>
60where
61    S: AsyncRead + AsyncWrite + Unpin,
62    T: AsyncRead + AsyncWrite + Unpin,
63{
64    pub(crate) fn new(
65        stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
66        pending_responses: VecDeque<BackendMessage>,
67        parameters: HashMap<String, String>,
68        receiver: mpsc::UnboundedReceiver<Request>,
69    ) -> Connection<S, T> {
70        Connection {
71            stream,
72            parameters,
73            receiver,
74            pending_request: None,
75            pending_responses,
76            responses: VecDeque::new(),
77            state: State::Active,
78        }
79    }
80
81    fn poll_response(
82        &mut self,
83        cx: &mut Context<'_>,
84    ) -> Poll<Option<Result<BackendMessage, Error>>> {
85        if let Some(message) = self.pending_responses.pop_front() {
86            trace!("retrying pending response");
87            return Poll::Ready(Some(Ok(message)));
88        }
89
90        Pin::new(&mut self.stream)
91            .poll_next(cx)
92            .map(|o| o.map(|r| r.map_err(Error::io)))
93    }
94
95    fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
96        if self.state != State::Active {
97            trace!("poll_read: done");
98            return Ok(None);
99        }
100
101        loop {
102            let message = match self.poll_response(cx)? {
103                Poll::Ready(Some(message)) => message,
104                Poll::Ready(None) => return Err(Error::closed()),
105                Poll::Pending => {
106                    trace!("poll_read: waiting on response");
107                    return Ok(None);
108                }
109            };
110
111            let (mut messages, request_complete) = match message {
112                BackendMessage::Async(Message::NoticeResponse(body)) => {
113                    let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
114                    return Ok(Some(AsyncMessage::Notice(error)));
115                }
116                BackendMessage::Async(Message::NotificationResponse(body)) => {
117                    let notification = Notification {
118                        process_id: body.process_id(),
119                        channel: body.channel().map_err(Error::parse)?.to_string(),
120                        payload: body.message().map_err(Error::parse)?.to_string(),
121                    };
122                    return Ok(Some(AsyncMessage::Notification(notification)));
123                }
124                BackendMessage::Async(Message::ParameterStatus(body)) => {
125                    self.parameters.insert(
126                        body.name().map_err(Error::parse)?.to_string(),
127                        body.value().map_err(Error::parse)?.to_string(),
128                    );
129                    continue;
130                }
131                BackendMessage::Async(_) => unreachable!(),
132                BackendMessage::Normal {
133                    messages,
134                    request_complete,
135                } => (messages, request_complete),
136            };
137
138            let mut response = match self.responses.pop_front() {
139                Some(response) => response,
140                None => match messages.next().map_err(Error::parse)? {
141                    Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
142                    _ => return Err(Error::unexpected_message()),
143                },
144            };
145
146            match response.sender.poll_ready(cx) {
147                Poll::Ready(Ok(())) => {
148                    let _ = response.sender.start_send(messages);
149                    if !request_complete {
150                        self.responses.push_front(response);
151                    }
152                }
153                Poll::Ready(Err(_)) => {
154                    // we need to keep paging through the rest of the messages even if the receiver's hung up
155                    if !request_complete {
156                        self.responses.push_front(response);
157                    }
158                }
159                Poll::Pending => {
160                    self.responses.push_front(response);
161                    self.pending_responses.push_back(BackendMessage::Normal {
162                        messages,
163                        request_complete,
164                    });
165                    trace!("poll_read: waiting on sender");
166                    return Ok(None);
167                }
168            }
169        }
170    }
171
172    fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
173        if let Some(messages) = self.pending_request.take() {
174            trace!("retrying pending request");
175            return Poll::Ready(Some(messages));
176        }
177
178        if self.receiver.is_terminated() {
179            return Poll::Ready(None);
180        }
181
182        match self.receiver.poll_next_unpin(cx) {
183            Poll::Ready(Some(request)) => {
184                trace!("polled new request");
185                self.responses.push_back(Response {
186                    sender: request.sender,
187                });
188                Poll::Ready(Some(request.messages))
189            }
190            Poll::Ready(None) => Poll::Ready(None),
191            Poll::Pending => Poll::Pending,
192        }
193    }
194
195    fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
196        loop {
197            if self.state == State::Closing {
198                trace!("poll_write: done");
199                return Ok(false);
200            }
201
202            if Pin::new(&mut self.stream)
203                .poll_ready(cx)
204                .map_err(Error::io)?
205                .is_pending()
206            {
207                trace!("poll_write: waiting on socket");
208                return Ok(false);
209            }
210
211            let request = match self.poll_request(cx) {
212                Poll::Ready(Some(request)) => request,
213                Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
214                    trace!("poll_write: at eof, terminating");
215                    self.state = State::Terminating;
216                    let mut request = BytesMut::new();
217                    frontend::terminate(&mut request);
218                    RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
219                }
220                Poll::Ready(None) => {
221                    trace!(
222                        "poll_write: at eof, pending responses {}",
223                        self.responses.len()
224                    );
225                    return Ok(true);
226                }
227                Poll::Pending => {
228                    trace!("poll_write: waiting on request");
229                    return Ok(true);
230                }
231            };
232
233            match request {
234                RequestMessages::Single(request) => {
235                    Pin::new(&mut self.stream)
236                        .start_send(request)
237                        .map_err(Error::io)?;
238                    if self.state == State::Terminating {
239                        trace!("poll_write: sent eof, closing");
240                        self.state = State::Closing;
241                    }
242                }
243                RequestMessages::CopyIn(mut receiver) => {
244                    let message = match receiver.poll_next_unpin(cx) {
245                        Poll::Ready(Some(message)) => message,
246                        Poll::Ready(None) => {
247                            trace!("poll_write: finished copy_in request");
248                            continue;
249                        }
250                        Poll::Pending => {
251                            trace!("poll_write: waiting on copy_in stream");
252                            self.pending_request = Some(RequestMessages::CopyIn(receiver));
253                            return Ok(true);
254                        }
255                    };
256                    Pin::new(&mut self.stream)
257                        .start_send(message)
258                        .map_err(Error::io)?;
259                    self.pending_request = Some(RequestMessages::CopyIn(receiver));
260                }
261            }
262        }
263    }
264
265    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
266        match Pin::new(&mut self.stream)
267            .poll_flush(cx)
268            .map_err(Error::io)?
269        {
270            Poll::Ready(()) => trace!("poll_flush: flushed"),
271            Poll::Pending => trace!("poll_flush: waiting on socket"),
272        }
273        Ok(())
274    }
275
276    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
277        if self.state != State::Closing {
278            return Poll::Pending;
279        }
280
281        match Pin::new(&mut self.stream)
282            .poll_close(cx)
283            .map_err(Error::io)?
284        {
285            Poll::Ready(()) => {
286                trace!("poll_shutdown: complete");
287                Poll::Ready(Ok(()))
288            }
289            Poll::Pending => {
290                trace!("poll_shutdown: waiting on socket");
291                Poll::Pending
292            }
293        }
294    }
295
296    /// Returns the value of a runtime parameter for this connection.
297    pub fn parameter(&self, name: &str) -> Option<&str> {
298        self.parameters.get(name).map(|s| &**s)
299    }
300
301    /// Polls for asynchronous messages from the server.
302    ///
303    /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
304    /// examine those messages should use this method to drive the connection rather than its `Future` implementation.
305    ///
306    /// Return values of `None` or `Some(Err(_))` are "terminal"; callers should not invoke this method again after
307    /// receiving one of those values.
308    pub fn poll_message(
309        &mut self,
310        cx: &mut Context<'_>,
311    ) -> Poll<Option<Result<AsyncMessage, Error>>> {
312        let message = self.poll_read(cx)?;
313        let want_flush = self.poll_write(cx)?;
314        if want_flush {
315            self.poll_flush(cx)?;
316        }
317        match message {
318            Some(message) => Poll::Ready(Some(Ok(message))),
319            None => match self.poll_shutdown(cx) {
320                Poll::Ready(Ok(())) => Poll::Ready(None),
321                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
322                Poll::Pending => Poll::Pending,
323            },
324        }
325    }
326}
327
328impl<S, T> Future for Connection<S, T>
329where
330    S: AsyncRead + AsyncWrite + Unpin,
331    T: AsyncRead + AsyncWrite + Unpin,
332{
333    type Output = Result<(), Error>;
334
335    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
336        while let Some(message) = ready!(self.poll_message(cx)?) {
337            if let AsyncMessage::Notice(notice) = message {
338                info!("{}: {}", notice.severity(), notice.message());
339            }
340        }
341        Poll::Ready(Ok(()))
342    }
343}