#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Pipelines that use the DataFrame API to process NYC taxiride CSV data."""
# pytype: skip-file
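
# Example invocation (a sketch, not taken from this file; the output location
# is a placeholder, and standard Beam pipeline options such as --runner can be
# appended):
#
#   python taxiride.py \
#       --pipeline location_id_agg \
#       --output gs://<your-bucket>/taxi_agg
#
# Use --pipeline borough_enrich to run the enrichment pipeline instead.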

from __future__ import absolute_import

import argparse
import logging

import apache_beam as beam
from apache_beam.dataframe.io import read_csv
from apache_beam.options.pipeline_options import PipelineOptions

ZONE_LOOKUP_PATH = (
    "gs://apache-beam-samples/nyc_taxi/misc/taxi+_zone_lookup.csv")


def run_aggregation_pipeline(pipeline, input_path, output_path):
  # The pipeline will be run on exiting the with block.
  # [START DataFrame_taxiride_aggregation]
  with pipeline as p:
    rides = p | read_csv(input_path)

    # Count the number of passengers dropped off per LocationID
    agg = rides.groupby('DOLocationID').passenger_count.sum()
    agg.to_csv(output_path)
  # [END DataFrame_taxiride_aggregation]
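
  # The deferred DataFrame operations above mirror plain pandas, so a rough
  # local sanity check on a small sample could look like the sketch below
  # ('sample.csv' and 'agg.csv' are hypothetical local paths):
  #
  #   import pandas as pd
  #   rides = pd.read_csv('sample.csv')
  #   rides.groupby('DOLocationID').passenger_count.sum().to_csv('agg.csv')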


def run_enrich_pipeline(
    pipeline, input_path, output_path, zone_lookup_path=ZONE_LOOKUP_PATH):
  """Enrich taxi ride data with a zone lookup table and perform a grouped
  aggregation."""
  # The pipeline will be run on exiting the with block.
  # [START DataFrame_taxiride_enrich]
  with pipeline as p:
    rides = p | "Read taxi rides" >> read_csv(input_path)
    zones = p | "Read zone lookup" >> read_csv(zone_lookup_path)

    # Enrich taxi ride data with boroughs from the zone lookup table.
    # Joins on zones.LocationID and rides.DOLocationID by first making the
    # former the index for zones.
    rides = rides.merge(
        zones.set_index('LocationID').Borough,
        right_index=True,
        left_on='DOLocationID',
        how='left')
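
    # After this left join each ride row carries a Borough value (NaN where
    # the drop-off LocationID has no entry in the lookup table).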

    # Sum passengers dropped off per Borough
    agg = rides.groupby('Borough').passenger_count.sum()
    agg.to_csv(output_path)
  # [END DataFrame_taxiride_enrich]

  # A more intuitive alternative to the above merge call, but this option
  # doesn't preserve the index and thus requires non-parallel execution.
  #rides = rides.merge(zones[['LocationID','Borough']],
  #                    how="left",
  #                    left_on='DOLocationID',
  #                    right_on='LocationID')
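  # (Roughly speaking, the Beam DataFrame API distributes work by partitioning
  # frames on their index, so a merge that drops the index, as this variant
  # does, forces the data to be brought together for a single, non-parallel
  # step.)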


def run(argv=None):
  """Main entry point."""
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument(
      '--input',
      dest='input',
      default='gs://apache-beam-samples/nyc_taxi/misc/sample.csv',
      help='Input file to process.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Output file to write results to.')
  parser.add_argument(
      '--zone_lookup',
      dest='zone_lookup_path',
      default=ZONE_LOOKUP_PATH,
      help='Location for taxi zone lookup CSV.')
  parser.add_argument(
      '--pipeline',
      dest='pipeline',
      default='location_id_agg',
      help=(
          "Choice of pipeline to run. Must be one of "
          "(location_id_agg, borough_enrich)."))
  known_args, pipeline_args = parser.parse_known_args(argv)

  pipeline = beam.Pipeline(options=PipelineOptions(pipeline_args))
  if known_args.pipeline == 'location_id_agg':
    run_aggregation_pipeline(pipeline, known_args.input, known_args.output)
  elif known_args.pipeline == 'borough_enrich':
    run_enrich_pipeline(
        pipeline,
        known_args.input,
        known_args.output,
        known_args.zone_lookup_path)
  else:
    raise ValueError(
        f"Unrecognized value for --pipeline: {known_args.pipeline!r}. "
        "Must be one of ('location_id_agg', 'borough_enrich')")


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()