Setup
In [ ]:
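# Setup - run only once per kernel app
%conda install openjdk -y
# install PySpark
%pip install pyspark==3.4.0
# install Spark NLP
%pip install spark-nlp==5.1.3
# Restart the kernel so the newly installed packages are picked up.
# This JavaScript hook relies on the classic Notebook `Jupyter` object; in JupyterLab-based
# UIs such as SageMaker Studio, restart from the Kernel menu instead.
from IPython.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")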
Collecting package metadata (current_repodata.json): done
Solving environment: done
The following NEW packages will be INSTALLED:
  openjdk            pkgs/main/linux-64::openjdk-11.0.13-h87a67e3_0
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
Successfully installed py4j-0.10.9.7 pyspark-3.4.0
Successfully installed spark-nlp-5.1.3
Note: you may need to restart the kernel to use updated packages.
Out[ ]:
Set Up the Spark Session
In [ ]:
import sagemaker
session = sagemaker.Session()
bucket = session.default_bucket()
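# Import pyspark and build the Spark session
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("PySparkApp")
    .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.2.2")
    # Hadoop properties need the `spark.hadoop.` prefix; without it Spark ignores the
    # setting and prints an "Ignoring non-Spark config property" warning
    .config(
        "spark.hadoop.fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.ContainerCredentialsProvider",
    )
    .getOrCreate()
)
print(spark.version)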
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
Warning: Ignoring non-Spark config property: fs.s3a.aws.credentials.provider
:: loading settings :: url = jar:file:/opt/conda/lib/python3.10/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml
Ivy Default Cache set to: /root/.ivy2/cache
The jars for the packages stored in: /root/.ivy2/jars
org.apache.hadoop#hadoop-aws added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-6d8ece39-5cae-4e0e-aad3-a1bb4dd7d1e2;1.0
    confs: [default]
    found org.apache.hadoop#hadoop-aws;3.2.2 in central
    found com.amazonaws#aws-java-sdk-bundle;1.11.563 in central
:: resolution report :: resolve 537ms :: artifacts dl 42ms
    :: modules in use:
    com.amazonaws#aws-java-sdk-bundle;1.11.563 from central in [default]
    org.apache.hadoop#hadoop-aws;3.2.2 from central in [default]
    ---------------------------------------------------------------------
    |                  |            modules            ||   artifacts   |
    |       conf       | number| search|dwnlded|evicted|| number|dwnlded|
    ---------------------------------------------------------------------
    |      default     |   2   |   0   |   0   |   0   ||   2   |   0   |
    ---------------------------------------------------------------------
:: retrieving :: org.apache.spark#spark-submit-parent-6d8ece39-5cae-4e0e-aad3-a1bb4dd7d1e2
    confs: [default]
    0 artifacts copied, 2 already retrieved (0kB/25ms)
23/11/30 02:50:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
3.4.0
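The `Ignoring non-Spark config property` warning above comes from the unprefixed key `fs.s3a.aws.credentials.provider`: Spark only copies properties named `spark.hadoop.*` into the Hadoop configuration that the s3a filesystem reads. A minimal sketch of the renaming, using a hypothetical helper `with_hadoop_prefix` (not part of PySpark); each resulting pair could then be passed to `SparkSession.builder.config(key, value)`.

```python
def with_hadoop_prefix(hadoop_props: dict) -> dict:
    """Hypothetical helper: add the `spark.hadoop.` prefix so Spark forwards
    each property into the Hadoop Configuration (keys already prefixed are kept)."""
    prefix = "spark.hadoop."
    return {
        key if key.startswith(prefix) else prefix + key: value
        for key, value in hadoop_props.items()
    }


props = with_hadoop_prefix(
    {"fs.s3a.aws.credentials.provider": "com.amazonaws.auth.ContainerCredentialsProvider"}
)
for key, value in props.items():
    print(key, "=", value)
```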
Load the Comments DataFrame
In [ ]:
# Tegveer's S3 -- DO NOT CHANGE
s3_directory_comms = f"s3a://sagemaker-us-east-1-433974840707/project/nlp_cleaned_comments/"
# Read all the Parquet files in the directory into a DataFrame
df_comments = spark.read.parquet(s3_directory_comms)
23/11/29 19:19:23 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
Pre-process features
In [ ]:
# Cast `controversiality` from int to string, extract year/month/day from `created_utc`,
# map moderator-distinguished comments to "yes" (nulls become "no", other values are kept),
# and drop [removed]/[deleted] comments
import pyspark.sql.functions as F
from pyspark.sql.types import StringType

df_comments = (
    df_comments.withColumn("controversiality", F.col("controversiality").cast(StringType()))
    .withColumn("year", F.year("created_utc"))
    .withColumn("month", F.month("created_utc"))
    .withColumn("day", F.dayofmonth("created_utc"))
    .withColumn(
        "distinguished",
        F.when(F.col("distinguished") == "moderator", "yes").otherwise(F.col("distinguished")),
    )
)
df_comments = df_comments.fillna({"distinguished": "no"})

# Drop comments whose body or author contains "[removed]" or "[deleted]"
# (rows with a null body or author are dropped too, because the condition evaluates to null)
pattern = r"\[removed\]|\[deleted\]"
df_comments = df_comments.filter(~(F.col("body").rlike(pattern) | F.col("author").rlike(pattern)))

# Keep only the columns needed for ML
df_comments = df_comments.select("controversiality", "distinguished", "subreddit", "year", "month", "day", "gilded", "score")
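To make the row-level rules of the cell above explicit, here is a plain-Python sketch that applies the same logic to a single record. It is illustrative only: `preprocess_row` and `ML_COLUMNS` are hypothetical names, records are assumed to be dicts with a `datetime` in `created_utc`, and Spark's session time zone handling for `year`/`month`/`dayofmonth` is not modelled.

```python
import re
from datetime import datetime
from typing import Optional

# Same pattern as the Spark filter: "[removed]" or "[deleted]" anywhere in the string
REMOVED_PATTERN = re.compile(r"\[removed\]|\[deleted\]")

# Output columns, in the order of the final select() in the notebook
ML_COLUMNS = ["controversiality", "distinguished", "subreddit", "year", "month", "day", "gilded", "score"]


def preprocess_row(row: dict) -> Optional[dict]:
    """Hypothetical mirror of the Spark steps; returns None when the row is filtered out."""
    body, author = row.get("body"), row.get("author")
    # In Spark a null body or author makes the negated condition false or null,
    # and filter() drops the row in both cases
    if body is None or author is None:
        return None
    if REMOVED_PATTERN.search(body) or REMOVED_PATTERN.search(author):
        return None

    created: datetime = row["created_utc"]
    distinguished = row.get("distinguished")
    if distinguished == "moderator":
        distinguished = "yes"
    elif distinguished is None:
        distinguished = "no"  # the fillna step; other values pass through unchanged

    out = {
        "controversiality": str(row["controversiality"]),
        "distinguished": distinguished,
        "subreddit": row["subreddit"],
        "year": created.year,
        "month": created.month,
        "day": created.day,
        "gilded": row["gilded"],
        "score": row["score"],
    }
    return {col: out[col] for col in ML_COLUMNS}


example = {
    "body": "Great point", "author": "some_user", "controversiality": 1,
    "distinguished": "moderator", "subreddit": "news",
    "created_utc": datetime(2022, 3, 14, 9, 30), "gilded": 0, "score": 5,
}
print(preprocess_row(example))
print(preprocess_row({**example, "body": "[removed]"}))  # None
```

The null-handling branch is the least obvious part: because `~(null | false)` is null in Spark SQL, comments with a missing body or author never reach the ML table.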
In [ ]:
print("Number of records in sampled and filtered df: ", df_comments.count())
Number of records in sampled and filtered df: 13242001
In [ ]:
df_comments.groupby('controversiality').count().show()
+----------------+--------+
|controversiality|   count|
+----------------+--------+
|               0|12311780|
|               1|  930221|
+----------------+--------+
In [ ]:
df_comments.printSchema()
root
 |-- controversiality: string (nullable = true)
 |-- distinguished: string (nullable = false)
 |-- subreddit: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- gilded: integer (nullable = true)
 |-- score: integer (nullable = true)
Adding weight column
In [ ]:
import numpy as np

# Balanced class weights: w_c = N / (K * n_c), so every class carries the same total weight N / K
y_collect = df_comments.groupBy("controversiality").count().collect()
unique_y = [row["controversiality"] for row in y_collect]
bin_count = [row["count"] for row in y_collect]
total_y = sum(bin_count)
unique_y_count = len(y_collect)
class_weights_spark = {
    label: float(weight)
    for label, weight in zip(unique_y, total_y / (unique_y_count * np.array(bin_count)))
}
print(class_weights_spark)
{'0': 0.5377776812126273, '1': 7.117663974474882}
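The weights above follow the "balanced" heuristic w_c = N / (K · n_c). A stdlib-only sketch (the helper name `balanced_class_weights` is hypothetical) reproduces them from the class counts shown earlier and checks the property that motivates the formula: each class contributes N / K in total, even though label "1" is only about 7% of the rows.

```python
def balanced_class_weights(counts: dict) -> dict:
    """w_c = N / (K * n_c): N rows in total, K classes, n_c rows in class c."""
    total = sum(counts.values())
    k = len(counts)
    return {label: total / (k * n) for label, n in counts.items()}


# Class counts from the groupBy("controversiality").count() output above
counts = {"0": 12311780, "1": 930221}
weights = balanced_class_weights(counts)
print(weights)

# Total weight per class: both should equal N / K = 13242001 / 2
for label, n in counts.items():
    print(label, weights[label] * n)
```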
In [ ]:
from itertools import chain

# Build a literal map column {label: weight} and look up each row's label in it
# (column[key] replaces the deprecated getItem(Column) form)
mapping_expr = F.create_map([F.lit(x) for x in chain(*class_weights_spark.items())])
df_comments = df_comments.withColumn("weight", mapping_expr[F.col("controversiality")])
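`F.create_map` takes its arguments as an alternating sequence key1, value1, key2, value2, and so on; `chain(*d.items())` is what flattens the weights dict into that shape. A plain-Python illustration of the flattening and of the per-row lookup the map column performs (the weights are copied from the output above; `rebuilt` and `labels` are illustrative names):

```python
from itertools import chain

class_weights = {"0": 0.5377776812126273, "1": 7.117663974474882}

# create_map expects key1, value1, key2, value2, ...
flat = list(chain(*class_weights.items()))
print(flat)  # ['0', 0.5377776812126273, '1', 7.117663974474882]

# Pairing even and odd positions recovers the dict; indexing it with each row's
# label is what the map-column lookup does for every row of df_comments
rebuilt = dict(zip(flat[0::2], flat[1::2]))
labels = ["0", "1", "0"]
print([rebuilt[label] for label in labels])
```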
In [ ]:
# Save DataFrame to Tegveer's S3
s3_bucket = f"s3a://sagemaker-us-east-1-433974840707/project/ml_comments/"
# Write DataFrame to S3 in Parquet format
df_comments.write.mode("overwrite").parquet(s3_bucket)
In [ ]:
# sanity check
! aws s3 ls s3://sagemaker-us-east-1-433974840707/project/
                           PRE cleaned/
                           PRE comments/
                           PRE ml/
                           PRE ml_comments/
                           PRE nlp/
                           PRE nlp_cleaned_comments/
                           PRE nlp_cleaned_submissions/
                           PRE sentiment/
                           PRE submissions/
2023-11-16 19:15:27  708534094 spark-nlp-assembly-5.1.3.jar
In [ ]:
# sanity check
! aws s3 ls s3://sagemaker-us-east-1-433974840707/project/ml_comments/
2023-11-29 19:26:24          0 _SUCCESS
2023-11-29 19:25:43     594198 part-00000-8f86faa0-d114-46be-9005-e82cbea8abbd-c000.snappy.parquet
2023-11-29 19:25:44     589362 part-00001-8f86faa0-d114-46be-9005-e82cbea8abbd-c000.snappy.parquet
...
2023-11-29 19:26:22     586560 part-00062-8f86faa0-d114-46be-9005-e82cbea8abbd-c000.snappy.parquet
2023-11-29 19:26:23     589307 part-00063-8f86faa0-d114-46be-9005-e82cbea8abbd-c000.snappy.parquet
Preparing features for controversiality classification - Comments dataset only (submissions don't have a controversiality field)
In [ ]:
%%writefile ./controversial_processing.py
import sys
import logging
import argparse

import sparknlp
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression, GBTClassifier, LinearSVC

logging.basicConfig(
    format="%(asctime)s,%(levelname)s,%(module)s,%(filename)s,%(lineno)d,%(message)s",
    level=logging.DEBUG,
)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# Also echo log records to stdout
logger.addHandler(logging.StreamHandler(sys.stdout))
def main():
    parser = argparse.ArgumentParser(description="app inputs and outputs")
    parser.add_argument("--ml_model", type=str, help="Model used for classification: rf, lr, gbt, or svm")
    parser.add_argument("--s3_dataset_path", type=str, help="Path of dataset in S3")
    parser.add_argument("--s3_output_bucket", type=str, help="S3 output bucket")
    parser.add_argument("--s3_output_key_prefix", type=str, help="S3 output key prefix")
    args = parser.parse_args()
    logger.info(f"args={args}")

    spark = (
        SparkSession.builder.appName("Spark ML")
        .config("spark.driver.memory", "16G")
        .config("spark.driver.maxResultSize", "0")
        .config("spark.kryoserializer.buffer.max", "2000M")
        .getOrCreate()
    )
    logger.info(f"Spark version: {spark.version}")
    logger.info(f"sparknlp version: {sparknlp.version()}")

    # This is needed to save RDDs, which is the only way to write nested DataFrames in CSV format
    sc = spark.sparkContext
    sc._jsc.hadoopConfiguration().set(
        "mapred.output.committer.class", "org.apache.hadoop.mapred.FileOutputCommitter"
    )

    ml_model = args.ml_model

    # Read the prepared comments from S3 (Parquet stores its own schema, so no header option is needed)
    logger.info(f"going to read {args.s3_dataset_path}")
    df = spark.read.parquet(args.s3_dataset_path)

    # Index the categorical columns. `controversiality_str` holds the numeric label index: with
    # StringIndexer's default frequencyDesc order, the majority label "0" becomes 0.0 and "1" becomes 1.0
    stringIndexer_controversiality = StringIndexer(inputCol="controversiality", outputCol="controversiality_str")
    stringIndexer_distinguished = StringIndexer(inputCol="distinguished", outputCol="distinguished_ix")
    stringIndexer_subreddit = StringIndexer(inputCol="subreddit", outputCol="subreddit_ix")
    onehot_subreddit = OneHotEncoder(inputCol="subreddit_ix", outputCol="subreddit_vec")

    # subreddit_vec is not assembled: the Spark processing job throws an AlgorithmError when it is
    # included, so subreddit enters the model through its StringIndexer index instead
    vectorAssembler_features = VectorAssembler(
        inputCols=["distinguished_ix", "year", "month", "day", "score", "gilded", "subreddit_ix"],
        outputCol="features",
    )

    # Every classifier uses the class-balancing `weight` column added in the notebook
    logger.info(f"Creating {ml_model}")
    if ml_model == "rf":
        classifier = RandomForestClassifier(labelCol="controversiality_str", featuresCol="features", numTrees=30, weightCol="weight")
    elif ml_model == "lr":
        classifier = LogisticRegression(labelCol="controversiality_str", featuresCol="features", maxIter=10, weightCol="weight")
    elif ml_model == "gbt":
        classifier = GBTClassifier(labelCol="controversiality_str", featuresCol="features", maxIter=10, weightCol="weight")
    elif ml_model == "svm":
        classifier = LinearSVC(labelCol="controversiality_str", featuresCol="features", maxIter=10, weightCol="weight")
    else:
        raise ValueError(f"Unknown --ml_model {ml_model!r}; expected one of rf, lr, gbt, svm")

    # Map the numeric prediction back to the original "0"/"1" label strings
    logger.info("Creating Label Converter")
    labelConverter = IndexToString(inputCol="prediction", outputCol="predictedControversiality", labels=["0", "1"])

    # All indexers and the encoder are fit inside the pipeline, i.e. on the training split only
    logger.info("Creating Pipeline")
    pipeline = Pipeline(stages=[
        stringIndexer_controversiality,
        stringIndexer_distinguished,
        stringIndexer_subreddit,
        onehot_subreddit,
        vectorAssembler_features,
        classifier,
        labelConverter,
    ])

    logger.info("Split data into train and test sets")
    train_data, test_data = df.randomSplit([0.75, 0.25], seed=24)
    logger.info(f"Number of training records: {train_data.count()}")
    logger.info(f"Number of testing records: {test_data.count()}")

    logger.info("going to fit pipeline on train dataframe")
    model = pipeline.fit(train_data)

    # Save the fitted PipelineModel; trimming slashes avoids a "//" when the prefix ends with "/"
    s3_path = f"s3://{args.s3_output_bucket}/{args.s3_output_key_prefix.strip('/')}"
    logger.info(f"going to save model in {s3_path}")
    model.write().overwrite().save(f"{s3_path}/{ml_model}.model")
if __name__ == "__main__":
    main()
Overwriting ./controversial_processing.py
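The pipeline ends with an `IndexToString` stage whose labels are fixed to `["0", "1"]`, which is only correct if `StringIndexer` mapped "0" to 0.0 and "1" to 1.0. With the default `stringOrderType="frequencyDesc"`, the most frequent label gets index 0.0 and ties are broken alphabetically. A plain-Python sketch of that assignment (the helper `string_indexer_mapping` is hypothetical and covers only two of the documented order types) shows why the current data is safe and how `alphabetAsc` would remove the dependence on class counts:

```python
def string_indexer_mapping(counts: dict, order: str = "frequencyDesc") -> dict:
    """Hypothetical mirror of how StringIndexer assigns indices to labels.

    frequencyDesc: most frequent label gets 0.0 (ties broken alphabetically)
    alphabetAsc:   labels sorted lexicographically
    """
    if order == "frequencyDesc":
        labels = sorted(counts, key=lambda label: (-counts[label], label))
    elif order == "alphabetAsc":
        labels = sorted(counts)
    else:
        raise ValueError(f"unsupported order in this sketch: {order}")
    return {label: float(index) for index, label in enumerate(labels)}


# Full-dataset class counts from earlier in the notebook
counts = {"0": 12311780, "1": 930221}
print(string_indexer_mapping(counts))  # {'0': 0.0, '1': 1.0}

# If "1" were ever the more frequent label, frequencyDesc would flip the indices,
# while alphabetAsc keeps "0" -> 0.0 regardless of the counts
flipped = {"0": 10, "1": 20}
print(string_indexer_mapping(flipped))                       # {'1': 0.0, '0': 1.0}
print(string_indexer_mapping(flipped, order="alphabetAsc"))  # {'0': 0.0, '1': 1.0}
```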
In [ ]:
import boto3
import sagemaker
from sagemaker.spark.processing import PySparkProcessor

account_id = boto3.client("sts").get_caller_identity()["Account"]

# Set up the PySpark processor that runs the job. SageMaker provisions `instance_count`
# instances of `instance_type` for the Spark cluster.
role = sagemaker.get_execution_role()
spark_processor = PySparkProcessor(
    base_job_name="sm-spark-project-ml",
    image_uri=f"{account_id}.dkr.ecr.us-east-1.amazonaws.com/sagemaker-spark:latest",
    role=role,
    instance_count=8,
    instance_type="ml.m5.xlarge",
    max_runtime_in_seconds=7200,
)

# S3 prefixes for the saved models and the Spark event logs
output_prefix = "project/ml_updated"
output_prefix_logs = "spark_logs/ml_updated"
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
In [ ]:
import time

models = ["rf", "lr", "gbt", "svm"]
for model in models:
    # comments dataset
    print(f"Working on Comments for model {model}")
    spark_processor.run(
        submit_app="./controversial_processing.py",
        arguments=[
            "--ml_model", model,
            "--s3_dataset_path", "s3://sagemaker-us-east-1-433974840707/project/ml_comments/",
            "--s3_output_bucket", "sagemaker-us-east-1-433974840707",
            # no trailing slash, so the script does not build ".../<model>//<model>.model"
            "--s3_output_key_prefix", f"{output_prefix}/{model}",
        ],
        spark_event_logs_s3_uri=f"s3://{bucket}/{output_prefix_logs}/spark_event_logs",
        logs=False,
    )
    # run() blocks until the job finishes (wait=True by default); pause before starting the next one
    time.sleep(60)
Working on Comments for model rf
INFO:sagemaker:Creating processing-job with name sm-spark-project-ml-2023-11-30-04-06-37-362
.........................................................................................................!
INFO:sagemaker:Creating processing-job with name sm-spark-project-ml-2023-11-30-04-16-32-586
Working on Comments for model lr ...........................................................................................................!
INFO:sagemaker:Creating processing-job with name sm-spark-project-ml-2023-11-30-04-26-37-647
Working on Comments for model gbt .........................................................................................................!
INFO:sagemaker:Creating processing-job with name sm-spark-project-ml-2023-11-30-04-36-32-377
Working on Comments for model svm ................................................................................................!
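If `--s3_output_key_prefix` arrives with a trailing `/`, joining it with f-strings of the form `f"s3://{bucket}/{prefix}"` and then `f"{s3_path}/{ml_model}.model"` yields a doubled slash, i.e. an S3 key segment with an empty name. A small sketch comparing that join with a slash-trimming version; both helper names are hypothetical and only the string handling is shown, no S3 calls are made.

```python
def naive_model_path(bucket: str, key_prefix: str, ml_model: str) -> str:
    # Plain f-string joining, as in the training script's s3_path construction
    s3_path = f"s3://{bucket}/{key_prefix}"
    return f"{s3_path}/{ml_model}.model"


def model_output_path(bucket: str, key_prefix: str, ml_model: str) -> str:
    # Hypothetical helper: trim slashes from the prefix before joining
    return f"s3://{bucket}/{key_prefix.strip('/')}/{ml_model}.model"


bucket_name = "sagemaker-us-east-1-433974840707"
print(naive_model_path(bucket_name, "project/ml_updated/rf/", "rf"))
print(model_output_path(bucket_name, "project/ml_updated/rf/", "rf"))
```

Trimming on the script side makes the saved location independent of how the caller formats the prefix.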
In [ ]:
# sanity check
! aws s3 ls s3://sagemaker-us-east-1-433974840707/project/
                           PRE cleaned/
                           PRE comments/
                           PRE ml/
                           PRE ml_comments/
                           PRE nlp/
                           PRE nlp_cleaned_comments/
                           PRE nlp_cleaned_submissions/
                           PRE sentiment/
                           PRE submissions/
2023-11-16 19:15:27  708534094 spark-nlp-assembly-5.1.3.jar
In [ ]:
# sanity check (note: the jobs above write under project/ml_updated/, not project/ml/)
! aws s3 ls s3://sagemaker-us-east-1-433974840707/project/ml/
                           PRE gbt/
                           PRE lr/
                           PRE rf/
                           PRE svm/