|
| 1 | +# Copyright 2023 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | + |
def test_bqml_getting_started(random_model_id):
    """Run the BigQuery DataFrames "BQML getting started" tutorial sample.

    Loads Google Analytics sample sessions with ``read_gbq``, derives the
    label and feature columns, fits a ``LogisticRegression`` model, and saves
    it with ``to_gbq`` under the ID passed in as ``random_model_id``.
    """
    your_model_id = random_model_id

    # [START bigquery_dataframes_bqml_getting_started_tutorial]
    from bigframes.ml.linear_model import LogisticRegression
    import bigframes.pandas as bpd

    # Begin by choosing the training data. `read_gbq` takes either a SQL
    # query or a table ID. This example reads from several tables through a
    # wildcard, so the data is defined with SQL. Follow issue
    # https://github.com/googleapis/python-bigquery-dataframes/issues/169
    # for updates on wildcard table support in `read_gbq`.

    df = bpd.read_gbq(
        """
        -- Since the order of rows isn't useful for the model training,
        -- generate a random ID to use as the index for the DataFrame.
        SELECT GENERATE_UUID() AS rowindex, *
        FROM
          `bigquery-public-data.google_analytics_sample.ga_sessions_*`
        WHERE
          _TABLE_SUFFIX BETWEEN '20160801' AND '20170630'
        """,
        index_col="rowindex",
    )

    # The columns used below ("totals", "device", "geoNetwork") are STRUCT
    # columns, so individual fields are read with Series.struct.field(<name>).
    # Reference documentation:
    # https://cloud.google.com/python/docs/reference/bigframes/latest/bigframes.operations.structs.StructAccessor#bigframes_operations_structs_StructAccessor_field
    totals = df["totals"]
    device = df["device"]

    # Total number of transactions within the Google Analytics session.
    transactions = totals.struct.field("transactions")

    # The "label" values are the outcome the model predicts: whether the
    # Google Analytics session contains any ecommerce transactions.
    # A NULL transaction count yields a label of 0; any other value yields 1.
    has_transactions = transactions.notnull()
    label = has_transactions.map({True: 1, False: 0})

    # Operating system of the visitor's device, with NULLs replaced by "".
    operating_system = device.struct.field("operatingSystem").fillna("")

    # Whether the visitor's device is a mobile device.
    is_mobile = device.struct.field("isMobile")

    # Country the session originated from (based on the IP address),
    # with NULLs replaced by "".
    country = df["geoNetwork"].struct.field("country").fillna("")

    # Total number of page views within the session, with NULLs replaced by 0.
    pageviews = totals.struct.field("pageviews").fillna(0)

    # Gather every feature column into one DataFrame
    # that serves as the training data.
    feature_columns = {
        "os": operating_system,
        "is_mobile": is_mobile,
        "country": country,
        "pageviews": pageviews,
    }
    features = bpd.DataFrame(feature_columns)

    # A logistic regression model splits the data into two classes and
    # gives a confidence score that the data is in one of the classes.
    model = LogisticRegression()
    model.fit(features, label)

    # The model.fit() call above created a temporary model.
    # Call to_gbq() to write it to a permanent location.
    model.to_gbq(
        your_model_id,  # For example: "bqml_tutorial.sample_model",
        replace=True,
    )
    # [END bigquery_dataframes_bqml_getting_started_tutorial]
0 commit comments