# Setup - Run only once per Kernel App
%conda install openjdk -y
# install PySpark
%pip install pyspark==3.4.0
# install spark-nlp
%pip install spark-nlp==5.1.3
# restart kernel
from IPython.core.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")
Retrieving notices: ...working... done Collecting package metadata (current_repodata.json): done Solving environment: done ==> WARNING: A newer version of conda exists. <== current version: 23.3.1 latest version: 23.10.0 Please update conda by running $ conda update -n base -c defaults conda Or to minimize the number of packages updated during conda update use conda install conda=23.10.0 ## Package Plan ## environment location: /opt/conda added / updated specs: - openjdk The following packages will be downloaded: package | build ---------------------------|----------------- ca-certificates-2023.08.22 | h06a4308_0 123 KB certifi-2023.11.17 | py310h06a4308_0 158 KB openjdk-11.0.13 | h87a67e3_0 341.0 MB ------------------------------------------------------------ Total: 341.3 MB The following NEW packages will be INSTALLED: openjdk pkgs/main/linux-64::openjdk-11.0.13-h87a67e3_0 The following packages will be UPDATED: ca-certificates conda-forge::ca-certificates-2023.7.2~ --> pkgs/main::ca-certificates-2023.08.22-h06a4308_0 certifi conda-forge/noarch::certifi-2023.7.22~ --> pkgs/main/linux-64::certifi-2023.11.17-py310h06a4308_0 Downloading and Extracting Packages certifi-2023.11.17 | 158 KB | | 0% ca-certificates-2023 | 123 KB | | 0% openjdk-11.0.13 | 341.0 MB | | 0% certifi-2023.11.17 | 158 KB | ##################################### | 100% openjdk-11.0.13 | 341.0 MB | 3 | 1% openjdk-11.0.13 | 341.0 MB | #8 | 5% openjdk-11.0.13 | 341.0 MB | ###5 | 10% openjdk-11.0.13 | 341.0 MB | ####9 | 13% openjdk-11.0.13 | 341.0 MB | ######8 | 18% openjdk-11.0.13 | 341.0 MB | ########7 | 24% openjdk-11.0.13 | 341.0 MB | ##########7 | 29% openjdk-11.0.13 | 341.0 MB | ############6 | 34% openjdk-11.0.13 | 341.0 MB | ##############5 | 39% openjdk-11.0.13 | 341.0 MB | ################3 | 44% openjdk-11.0.13 | 341.0 MB | ##################2 | 49% openjdk-11.0.13 | 341.0 MB | ####################2 | 55% openjdk-11.0.13 | 341.0 MB | ######################2 | 60% openjdk-11.0.13 | 341.0 MB 
| ########################1 | 65% openjdk-11.0.13 | 341.0 MB | ########################## | 70% openjdk-11.0.13 | 341.0 MB | ############################ | 76% openjdk-11.0.13 | 341.0 MB | #############################9 | 81% openjdk-11.0.13 | 341.0 MB | ###############################8 | 86% openjdk-11.0.13 | 341.0 MB | #################################7 | 91% openjdk-11.0.13 | 341.0 MB | ###################################6 | 96% Preparing transaction: done Verifying transaction: done Executing transaction: done Note: you may need to restart the kernel to use updated packages. Collecting pyspark==3.4.0 Using cached pyspark-3.4.0-py2.py3-none-any.whl Collecting py4j==0.10.9.7 (from pyspark==3.4.0) Using cached py4j-0.10.9.7-py2.py3-none-any.whl (200 kB) Installing collected packages: py4j, pyspark Successfully installed py4j-0.10.9.7 pyspark-3.4.0 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv [notice] A new release of pip is available: 23.2.1 -> 23.3.1 [notice] To update, run: pip install --upgrade pip Note: you may need to restart the kernel to use updated packages. Collecting spark-nlp==5.1.3 Obtaining dependency information for spark-nlp==5.1.3 from https://files.pythonhosted.org/packages/cd/7d/bc0eca4c9ec4c9c1d9b28c42c2f07942af70980a7d912d0aceebf8db32dd/spark_nlp-5.1.3-py2.py3-none-any.whl.metadata Using cached spark_nlp-5.1.3-py2.py3-none-any.whl.metadata (53 kB) Using cached spark_nlp-5.1.3-py2.py3-none-any.whl (537 kB) Installing collected packages: spark-nlp Successfully installed spark-nlp-5.1.3 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv [notice] A new release of pip is available: 23.2.1 -> 23.3.1 [notice] To update, run: pip install --upgrade pip Note: you may need to restart the kernel to use updated packages.
# Import pyspark and build the local Spark session used by this notebook.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("Spark NLP")
    .master("local[*]")
    .config("spark.driver.memory", "16G")
    # NOTE(review): with master("local[*]") executors run inside the driver
    # JVM, so these two executor settings most likely have no effect — confirm.
    .config("spark.executor.memory", "12g")
    .config("spark.executor.cores", "3")
    # Maven packages resolved at startup:
    # - spark-nlp_2.12:5.1.3 pairs with the spark-nlp==5.1.3 Python package.
    # - hadoop-aws provides the s3a:// filesystem and must be the same version
    #   as the Hadoop jars bundled with PySpark (3.3.4 for pyspark 3.4.0 —
    #   re-check if pyspark is upgraded); mixing Hadoop releases on one
    #   classpath can fail at runtime with NoSuchMethodError/ClassNotFoundException.
    .config(
        "spark.jars.packages",
        "com.johnsnowlabs.nlp:spark-nlp_2.12:5.1.3,org.apache.hadoop:hadoop-aws:3.3.4",
    )
    # Hadoop properties reach the Hadoop configuration only when prefixed with
    # "spark.hadoop."; the bare "fs.s3a..." key was discarded with
    # "Warning: Ignoring non-Spark config property".
    # NOTE(review): this provider reads container credentials; confirm they are
    # available in this environment now that the setting actually takes effect.
    .config(
        "spark.hadoop.fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.ContainerCredentialsProvider",
    )
    .getOrCreate()
)
print(spark.version)
Warning: Ignoring non-Spark config property: fs.s3a.aws.credentials.provider
:: loading settings :: url = jar:file:/opt/conda/lib/python3.10/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml
Ivy Default Cache set to: /root/.ivy2/cache The jars for the packages stored in: /root/.ivy2/jars com.johnsnowlabs.nlp#spark-nlp_2.12 added as a dependency org.apache.hadoop#hadoop-aws added as a dependency :: resolving dependencies :: org.apache.spark#spark-submit-parent-0a72fff1-6a0f-413c-a4ed-44cb433aeee1;1.0 confs: [default] found com.johnsnowlabs.nlp#spark-nlp_2.12;5.1.3 in central found com.typesafe#config;1.4.2 in central found org.rocksdb#rocksdbjni;6.29.5 in central found com.amazonaws#aws-java-sdk-bundle;1.11.828 in central found com.github.universal-automata#liblevenshtein;3.0.0 in central found com.google.protobuf#protobuf-java-util;3.0.0-beta-3 in central found com.google.protobuf#protobuf-java;3.0.0-beta-3 in central found com.google.code.gson#gson;2.3 in central found it.unimi.dsi#fastutil;7.0.12 in central found org.projectlombok#lombok;1.16.8 in central found com.google.cloud#google-cloud-storage;2.20.1 in central found com.google.guava#guava;31.1-jre in central found com.google.guava#failureaccess;1.0.1 in central found com.google.guava#listenablefuture;9999.0-empty-to-avoid-conflict-with-guava in central found com.google.errorprone#error_prone_annotations;2.18.0 in central found com.google.j2objc#j2objc-annotations;1.3 in central found com.google.http-client#google-http-client;1.43.0 in central found io.opencensus#opencensus-contrib-http-util;0.31.1 in central found com.google.http-client#google-http-client-jackson2;1.43.0 in central found com.google.http-client#google-http-client-gson;1.43.0 in central found com.google.api-client#google-api-client;2.2.0 in central found commons-codec#commons-codec;1.15 in central found com.google.oauth-client#google-oauth-client;1.34.1 in central found com.google.http-client#google-http-client-apache-v2;1.43.0 in central found com.google.apis#google-api-services-storage;v1-rev20220705-2.0.0 in central found com.google.code.gson#gson;2.10.1 in central found com.google.cloud#google-cloud-core;2.12.0 in central 
found io.grpc#grpc-context;1.53.0 in central found com.google.auto.value#auto-value-annotations;1.10.1 in central found com.google.auto.value#auto-value;1.10.1 in central found javax.annotation#javax.annotation-api;1.3.2 in central found commons-logging#commons-logging;1.2 in central found com.google.cloud#google-cloud-core-http;2.12.0 in central found com.google.http-client#google-http-client-appengine;1.43.0 in central found com.google.api#gax-httpjson;0.108.2 in central found com.google.cloud#google-cloud-core-grpc;2.12.0 in central found io.grpc#grpc-alts;1.53.0 in central found io.grpc#grpc-grpclb;1.53.0 in central found org.conscrypt#conscrypt-openjdk-uber;2.5.2 in central found io.grpc#grpc-auth;1.53.0 in central found io.grpc#grpc-protobuf;1.53.0 in central found io.grpc#grpc-protobuf-lite;1.53.0 in central found io.grpc#grpc-core;1.53.0 in central found com.google.api#gax;2.23.2 in central found com.google.api#gax-grpc;2.23.2 in central found com.google.auth#google-auth-library-credentials;1.16.0 in central found com.google.auth#google-auth-library-oauth2-http;1.16.0 in central found com.google.api#api-common;2.6.2 in central found io.opencensus#opencensus-api;0.31.1 in central found com.google.api.grpc#proto-google-iam-v1;1.9.2 in central found com.google.protobuf#protobuf-java;3.21.12 in central found com.google.protobuf#protobuf-java-util;3.21.12 in central found com.google.api.grpc#proto-google-common-protos;2.14.2 in central found org.threeten#threetenbp;1.6.5 in central found com.google.api.grpc#proto-google-cloud-storage-v2;2.20.1-alpha in central found com.google.api.grpc#grpc-google-cloud-storage-v2;2.20.1-alpha in central found com.google.api.grpc#gapic-google-cloud-storage-v2;2.20.1-alpha in central found com.fasterxml.jackson.core#jackson-core;2.14.2 in central found com.google.code.findbugs#jsr305;3.0.2 in central found io.grpc#grpc-api;1.53.0 in central found io.grpc#grpc-stub;1.53.0 in central found org.checkerframework#checker-qual;3.31.0 
in central found io.perfmark#perfmark-api;0.26.0 in central found com.google.android#annotations;4.1.1.4 in central found org.codehaus.mojo#animal-sniffer-annotations;1.22 in central found io.opencensus#opencensus-proto;0.2.0 in central found io.grpc#grpc-services;1.53.0 in central found com.google.re2j#re2j;1.6 in central found io.grpc#grpc-netty-shaded;1.53.0 in central found io.grpc#grpc-googleapis;1.53.0 in central found io.grpc#grpc-xds;1.53.0 in central found com.navigamez#greex;1.0 in central found dk.brics.automaton#automaton;1.11-8 in central found com.johnsnowlabs.nlp#tensorflow-cpu_2.12;0.4.4 in central found com.microsoft.onnxruntime#onnxruntime;1.15.0 in central found org.apache.hadoop#hadoop-aws;3.2.2 in central :: resolution report :: resolve 3909ms :: artifacts dl 492ms :: modules in use: com.amazonaws#aws-java-sdk-bundle;1.11.828 from central in [default] com.fasterxml.jackson.core#jackson-core;2.14.2 from central in [default] com.github.universal-automata#liblevenshtein;3.0.0 from central in [default] com.google.android#annotations;4.1.1.4 from central in [default] com.google.api#api-common;2.6.2 from central in [default] com.google.api#gax;2.23.2 from central in [default] com.google.api#gax-grpc;2.23.2 from central in [default] com.google.api#gax-httpjson;0.108.2 from central in [default] com.google.api-client#google-api-client;2.2.0 from central in [default] com.google.api.grpc#gapic-google-cloud-storage-v2;2.20.1-alpha from central in [default] com.google.api.grpc#grpc-google-cloud-storage-v2;2.20.1-alpha from central in [default] com.google.api.grpc#proto-google-cloud-storage-v2;2.20.1-alpha from central in [default] com.google.api.grpc#proto-google-common-protos;2.14.2 from central in [default] com.google.api.grpc#proto-google-iam-v1;1.9.2 from central in [default] com.google.apis#google-api-services-storage;v1-rev20220705-2.0.0 from central in [default] com.google.auth#google-auth-library-credentials;1.16.0 from central in [default] 
com.google.auth#google-auth-library-oauth2-http;1.16.0 from central in [default] com.google.auto.value#auto-value;1.10.1 from central in [default] com.google.auto.value#auto-value-annotations;1.10.1 from central in [default] com.google.cloud#google-cloud-core;2.12.0 from central in [default] com.google.cloud#google-cloud-core-grpc;2.12.0 from central in [default] com.google.cloud#google-cloud-core-http;2.12.0 from central in [default] com.google.cloud#google-cloud-storage;2.20.1 from central in [default] com.google.code.findbugs#jsr305;3.0.2 from central in [default] com.google.code.gson#gson;2.10.1 from central in [default] com.google.errorprone#error_prone_annotations;2.18.0 from central in [default] com.google.guava#failureaccess;1.0.1 from central in [default] com.google.guava#guava;31.1-jre from central in [default] com.google.guava#listenablefuture;9999.0-empty-to-avoid-conflict-with-guava from central in [default] com.google.http-client#google-http-client;1.43.0 from central in [default] com.google.http-client#google-http-client-apache-v2;1.43.0 from central in [default] com.google.http-client#google-http-client-appengine;1.43.0 from central in [default] com.google.http-client#google-http-client-gson;1.43.0 from central in [default] com.google.http-client#google-http-client-jackson2;1.43.0 from central in [default] com.google.j2objc#j2objc-annotations;1.3 from central in [default] com.google.oauth-client#google-oauth-client;1.34.1 from central in [default] com.google.protobuf#protobuf-java;3.21.12 from central in [default] com.google.protobuf#protobuf-java-util;3.21.12 from central in [default] com.google.re2j#re2j;1.6 from central in [default] com.johnsnowlabs.nlp#spark-nlp_2.12;5.1.3 from central in [default] com.johnsnowlabs.nlp#tensorflow-cpu_2.12;0.4.4 from central in [default] com.microsoft.onnxruntime#onnxruntime;1.15.0 from central in [default] com.navigamez#greex;1.0 from central in [default] com.typesafe#config;1.4.2 from central in [default] 
commons-codec#commons-codec;1.15 from central in [default] commons-logging#commons-logging;1.2 from central in [default] dk.brics.automaton#automaton;1.11-8 from central in [default] io.grpc#grpc-alts;1.53.0 from central in [default] io.grpc#grpc-api;1.53.0 from central in [default] io.grpc#grpc-auth;1.53.0 from central in [default] io.grpc#grpc-context;1.53.0 from central in [default] io.grpc#grpc-core;1.53.0 from central in [default] io.grpc#grpc-googleapis;1.53.0 from central in [default] io.grpc#grpc-grpclb;1.53.0 from central in [default] io.grpc#grpc-netty-shaded;1.53.0 from central in [default] io.grpc#grpc-protobuf;1.53.0 from central in [default] io.grpc#grpc-protobuf-lite;1.53.0 from central in [default] io.grpc#grpc-services;1.53.0 from central in [default] io.grpc#grpc-stub;1.53.0 from central in [default] io.grpc#grpc-xds;1.53.0 from central in [default] io.opencensus#opencensus-api;0.31.1 from central in [default] io.opencensus#opencensus-contrib-http-util;0.31.1 from central in [default] io.opencensus#opencensus-proto;0.2.0 from central in [default] io.perfmark#perfmark-api;0.26.0 from central in [default] it.unimi.dsi#fastutil;7.0.12 from central in [default] javax.annotation#javax.annotation-api;1.3.2 from central in [default] org.apache.hadoop#hadoop-aws;3.2.2 from central in [default] org.checkerframework#checker-qual;3.31.0 from central in [default] org.codehaus.mojo#animal-sniffer-annotations;1.22 from central in [default] org.conscrypt#conscrypt-openjdk-uber;2.5.2 from central in [default] org.projectlombok#lombok;1.16.8 from central in [default] org.rocksdb#rocksdbjni;6.29.5 from central in [default] org.threeten#threetenbp;1.6.5 from central in [default] :: evicted modules: com.google.protobuf#protobuf-java-util;3.0.0-beta-3 by [com.google.protobuf#protobuf-java-util;3.21.12] in [default] com.google.protobuf#protobuf-java;3.0.0-beta-3 by [com.google.protobuf#protobuf-java;3.21.12] in [default] com.google.code.gson#gson;2.3 by 
[com.google.code.gson#gson;2.10.1] in [default] com.amazonaws#aws-java-sdk-bundle;1.11.563 by [com.amazonaws#aws-java-sdk-bundle;1.11.828] in [default] --------------------------------------------------------------------- | | modules || artifacts | | conf | number| search|dwnlded|evicted|| number|dwnlded| --------------------------------------------------------------------- | default | 77 | 0 | 0 | 4 || 73 | 0 | --------------------------------------------------------------------- :: retrieving :: org.apache.spark#spark-submit-parent-0a72fff1-6a0f-413c-a4ed-44cb433aeee1 confs: [default] 0 artifacts copied, 73 already retrieved (0kB/160ms) 23/11/28 17:44:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
3.4.0
import sagemaker
from pyspark.sql.functions import lower, regexp_replace, col, concat_ws
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from sparknlp.annotator import *
from sparknlp.base import *
import sparknlp
from sparknlp.pretrained import PretrainedPipeline
from sparknlp.base import Finisher, DocumentAssembler
from pyspark.sql.functions import length
%%time
bucket = "project-group34"
session = sagemaker.Session()
output_prefix_data_submissions = "project/submissions/yyyy=*"
s3_path = f"s3a://{bucket}/{output_prefix_data_submissions}"
print(f"reading comments from {s3_path}")
submissions = spark.read.parquet(s3_path, header=True)
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml reading comments from s3a://project-group34/project/submissions/yyyy=*
23/11/27 21:29:41 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
CPU times: user 199 ms, sys: 16.9 ms, total: 216 ms Wall time: 6.49 s
23/11/27 21:29:47 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
# Report rows x columns; count() is a Spark action, so this runs a full job.
print("Shape of the submissions data is: {} x {}".format(submissions.count(), len(submissions.columns)))
[Stage 1:=======================================================>(99 + 1) / 100]
Shape of the submissions data is: 875969 x 68
from pyspark.sql import functions as F

# Keep only the raw fields used below.
submissions = submissions.select(
    "subreddit", "title", "selftext", "score", "num_comments", "over_18",
    "is_self", "is_video", "domain", "created_utc", "author",
    "author_flair_text", "media",
)

# Character length of title + selftext. In Spark, length(NULL) is NULL and
# NULL + n is NULL, so a missing part is counted as 0 instead of nulling out
# the whole feature.
submissions = submissions.withColumn(
    "post_length",
    F.coalesce(F.length("title"), F.lit(0)) + F.coalesce(F.length("selftext"), F.lit(0)),
)

# Parse created_utc into a timestamp before deriving calendar features from it.
submissions = submissions.withColumn("created_utc", F.to_timestamp("created_utc"))

# Time-based and media features, added in a single projection.
# NOTE(review): these are evaluated in the Spark session time zone
# (spark.sql.session.timeZone), which may not be UTC — confirm.
submissions = submissions.withColumns({
    "hour_of_day": F.hour("created_utc"),
    "day_of_week": F.dayofweek("created_utc"),  # 1 (Sunday) to 7 (Saturday)
    "day_of_week_str": F.date_format("created_utc", "EEEE"),  # "Sunday" ... "Saturday"
    "day_of_month": F.dayofmonth("created_utc"),
    "month": F.month("created_utc"),
    "year": F.year("created_utc"),
    "has_media": F.col("media").isNotNull(),
})

# Final feature set; this also drops media, created_utc, author and author_flair_text.
submissions = submissions.select(
    "subreddit", "title", "selftext", "score", "num_comments", "over_18",
    "is_self", "is_video", "domain", "post_length", "hour_of_day",
    "day_of_week", "day_of_week_str", "day_of_month", "month", "year",
    "has_media",
)

# Merge title and selftext into one text column `body` (concat_ws skips NULLs).
submissions = submissions.withColumn("body", F.concat_ws(" ", F.col("title"), F.col("selftext")))
submissions = submissions.drop("title", "selftext")
submissions.show(5)
[Stage 1:> (0 + 1) / 1]
+----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+---------------+------------+-----+----+---------+--------------------+ | subreddit|score|num_comments|over_18|is_self|is_video| domain|post_length|hour_of_day|day_of_week|day_of_week_str|day_of_month|month|year|has_media| body| +----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+---------------+------------+-----+----+---------+--------------------+ |television| 0| 9| false| true| false|self.television| 605| 22| 4| Wednesday| 27| 1|2021| false|Is there a websit...| | anime| 0| 3| false| false| false| i.redd.it| 50| 22| 4| Wednesday| 27| 1|2021| false|Does anyone know ...| |television| 4| 11| false| false| false| deadline.com| 86| 22| 4| Wednesday| 27| 1|2021| false|‘Doogie Kameāloha...| | movies| 0| 4| false| true| false| self.movies| 42| 22| 4| Wednesday| 27| 1|2021| false|4K movies on desk...| | anime| 0| 9| false| true| false| self.anime| 64| 22| 4| Wednesday| 27| 1|2021| false|Where can I buy a...| +----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+---------------+------------+-----+----+---------+--------------------+ only showing top 5 rows
from pyspark.sql.functions import col, count, when

# NULL count per column in a single aggregation: when() is non-NULL only on
# rows where the column is NULL, and count() counts non-NULL values.
missing_vals = submissions.select(
    *[count(when(col(name).isNull(), name)).alias(name) for name in submissions.columns]
)
missing_vals.show()
[Stage 11:=====================================================> (98 + 2) / 100]
+---------+-----+------------+-------+-------+--------+------+-----------+-----------+-----------+---------------+------------+-----+----+---------+----+ |subreddit|score|num_comments|over_18|is_self|is_video|domain|post_length|hour_of_day|day_of_week|day_of_week_str|day_of_month|month|year|has_media|body| +---------+-----+------------+-------+-------+--------+------+-----------+-----------+-----------+---------------+------------+-----+----+---------+----+ | 0| 0| 0| 0| 0| 0| 8002| 0| 0| 0| 0| 0| 0| 0| 0| 0| +---------+-----+------------+-------+-------+--------+------+-----------+-----------+-----------+---------------+------------+-----+----+---------+----+
# Drop rows whose `domain` is NULL.
submissions = submissions.na.drop(subset=["domain"])

from pyspark.sql.functions import col, lower, regexp_replace, trim
from pyspark.ml.feature import HashingTF, IDF, RegexTokenizer, StopWordsRemover

submissions = submissions.withColumn("body", lower(col("body")))
# Strip everything except ASCII letters, digits and whitespace (raw string so
# the regex escape is not an invalid Python escape sequence).
# NOTE(review): this also deletes accented letters inside words
# (e.g. "kameāloha" -> "kameloha").
submissions = submissions.withColumn("body", regexp_replace(col("body"), r"[^a-zA-Z0-9\s]", ""))
# Collapse every whitespace run (newlines, carriage returns, tabs, gaps left by
# removed punctuation) into a single space and trim the ends.
submissions = submissions.withColumn("body", trim(regexp_replace(col("body"), r"\s+", " ")))

# Tokenize on whitespace runs. The plain Tokenizer splits on each single
# whitespace character, which turned consecutive spaces into empty "" tokens
# that were hashed into the features; RegexTokenizer (lowercases by default)
# with minTokenLength=1 discards empty tokens.
tokenizer = RegexTokenizer(inputCol="body", outputCol="words", pattern="\\s+", gaps=True, minTokenLength=1)
tokenized_df = tokenizer.transform(submissions)
# Remove stop words (StopWordsRemover's default list).
remover = StopWordsRemover(inputCol="words", outputCol="filtered_words")
df_no_stopwords = remover.transform(tokenized_df)
# Hash terms into a fixed-size term-frequency vector.
hashingTF = HashingTF(inputCol="filtered_words", outputCol="rawFeatures")
featurizedData = hashingTF.transform(df_no_stopwords)
# Rescale term frequencies by inverse document frequency fitted on this data.
idf = IDF(inputCol="rawFeatures", outputCol="features")
rescaledData = idf.fit(featurizedData).transform(featurizedData)
# Drop the weekday string and the intermediate text columns.
rescaledData = rescaledData.drop("day_of_week_str", "words", "filtered_words", "rawFeatures")
rescaledData.printSchema()
root |-- subreddit: string (nullable = true) |-- score: long (nullable = true) |-- num_comments: long (nullable = true) |-- over_18: boolean (nullable = true) |-- is_self: boolean (nullable = true) |-- is_video: boolean (nullable = true) |-- domain: string (nullable = true) |-- post_length: integer (nullable = true) |-- hour_of_day: integer (nullable = true) |-- day_of_week: integer (nullable = true) |-- day_of_month: integer (nullable = true) |-- month: integer (nullable = true) |-- year: integer (nullable = true) |-- has_media: boolean (nullable = false) |-- body: string (nullable = false) |-- features: vector (nullable = true)
# Preview the first 20 rows with long cell values truncated.
rescaledData.show(n=20, truncate=True)
23/11/27 21:34:10 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
+----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+------------+-----+----+---------+--------------------+--------------------+ | subreddit|score|num_comments|over_18|is_self|is_video| domain|post_length|hour_of_day|day_of_week|day_of_month|month|year|has_media| body| features| +----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+------------+-----+----+---------+--------------------+--------------------+ |television| 0| 9| false| true| false|self.television| 605| 22| 4| 27| 1|2021| false|is there a websit...|(262144,[1546,158...| | anime| 0| 3| false| false| false| i.redd.it| 50| 22| 4| 27| 1|2021| false|does anyone know ...|(262144,[101370,1...| |television| 4| 11| false| false| false| deadline.com| 86| 22| 4| 27| 1|2021| false|doogie kameloha m...|(262144,[16384,73...| | movies| 0| 4| false| true| false| self.movies| 42| 22| 4| 27| 1|2021| false|4k movies on desk...|(262144,[1206,202...| | anime| 0| 9| false| true| false| self.anime| 64| 22| 4| 27| 1|2021| false|where can i buy a...|(262144,[53596,92...| | anime| 0| 7| false| true| false| self.anime| 1249| 22| 4| 27| 1|2021| false|ever suddenly fin...|(262144,[2437,392...| | movies| 1| 0| false| false| false| apple.news| 64| 22| 4| 27| 1|2021| false|cloris leachman m...|(262144,[27708,48...| |television| 2| 0| false| false| false| variety.com| 91| 22| 4| 27| 1|2021| false|netflix alan yang...|(262144,[24303,27...| | anime| 1| 6| false| false| false| youtu.be| 67| 22| 4| 27| 1|2021| true|the way eren spea...|(262144,[51471,53...| |television| 3| 2| false| true| false|self.television| 75| 22| 4| 27| 1|2021| false|whats the best wa...|(262144,[8145,271...| | anime| 1| 0| false| false| false| youtu.be| 31| 22| 4| 27| 1|2021| true|anime mix amv mmm...|(262144,[7231,820...| | anime| 1| 1| false| false| false| youtu.be| 28| 22| 4| 27| 1|2021| true|eromanga sensei ...|(262144,[8408,112...| | anime| 1| 1| 
false| false| false| youtu.be| 45| 22| 4| 27| 1|2021| true|god of destructio...|(262144,[35501,55...| | anime| 1| 1| false| true| false| self.anime| 25| 22| 4| 27| 1|2021| false|manga collectors ...|(262144,[41314,61...| | anime| 1| 1| false| true| false| self.anime| 29| 22| 4| 27| 1|2021| false|wholesome animes ...|(262144,[49869,61...| | movies| 1| 0| false| true| false| self.movies| 64| 22| 4| 27| 1|2021| false|a crime in a buil...|(262144,[33803,61...| | anime| 1| 3| false| true| false| self.anime| 38| 22| 4| 27| 1|2021| false|question for mang...|(262144,[41314,61...| | movies| 1| 17| false| true| false| self.movies| 58| 22| 4| 27| 1|2021| false|daredevil or ghos...|(262144,[61710,13...| | anime| 1| 5| false| true| false| self.anime| 69| 22| 4| 27| 1|2021| false|idk what the tita...|(262144,[61710,79...| | movies| 15| 54| false| true| false| self.movies| 595| 22| 4| 27| 1|2021| false|with hollywood ma...|(262144,[4978,853...| +----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+------------+-----+----+---------+--------------------+--------------------+ only showing top 20 rows
# Print one full TF-IDF feature vector without truncation.
rescaledData.select(["features"]).show(n=1, truncate=False)
23/11/27 21:34:11 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |features | 
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 
|(262144,[1546,15867,27139,51471,52351,53570,54245,60776,61756,64358,91878,94214,111767,113004,116485,138201,140461,141589,143918,146337,156724,158661,161756,168211,179943,183553,190256,206312,206910,208258,213767,229604,231350,232735,234233,243658,245599,245731,249180,249943,254072,254600],[5.69251854502339,5.760023411747378,7.579618575735043,3.5506848158879545,4.968082016336675,13.530412171330834,5.253448124146994,6.604035998144887,4.909856857255696,4.2268389298045514,6.532216203833996,4.525126060326382,3.6691755070510474,4.716142116632298,6.872627092131839,11.461674818651053,12.80489709590626,7.257177844091133,5.128712738777624,8.83762821965198,7.575835844437219,5.314775239845497,9.66657694137099,5.447604138587951,12.287615765483569,7.473400952560769,10.178630837871902,4.128098018339647,5.057958163163958,2.268569872313168,4.6447319837013845,5.203915567000776,11.594468584923623,5.498081117888862,8.820816413026089,3.522859618212378,3.808643786887529,3.026130860184214,29.77500852413356,5.293682790260379,9.722666408022032,7.643224866342196])| 
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ only showing top 1 row
from pyspark.sql.functions import col

# Store the boolean flag columns as strings before writing.
# NOTE(review): presumably so downstream ML code can treat them as categorical
# values (e.g. via StringIndexer) — confirm against the training script.
rescaledData = rescaledData.withColumns(
    {flag: col(flag).cast("string") for flag in ("over_18", "is_self", "is_video", "has_media")}
)

# Persist the cleaned feature table to S3, replacing any previous output.
rescaledData.write.mode("overwrite").parquet("s3a://project-group34/project/submissions/cleaned_ML/")
23/11/27 21:34:13 WARN DAGScheduler: Broadcasting large task binary with size 4.3 MiB
%%time
import time
import sagemaker
from sagemaker.spark.processing import PySparkProcessor

# IAM role that the processing jobs run under.
role = sagemaker.get_execution_role()

# Spark processor shared by every job below: each run provisions
# `instance_count` instances of `instance_type` and is capped at 6 hours.
spark_processor = PySparkProcessor(
    base_job_name="sm-spark-project",
    framework_version="3.3",
    role=role,
    instance_count=8,
    instance_type="ml.m5.xlarge",
    max_runtime_in_seconds=21600,
)

session = sagemaker.Session()

# Key prefix (inside the bucket) for Spark event logs.
output_prefix_logs = "spark_logs"

# spark-defaults overrides passed to each processing job.
configuration = [
    {
        "Classification": "spark-defaults",
        "Properties": {
            "spark.executor.memory": "12g",
            "spark.executor.cores": "4",
        },
    }
]
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml CPU times: user 1.37 s, sys: 539 ms, total: 1.91 s Wall time: 1.47 s
!mkdir -p ./code
%%writefile ./code/preprocess_ml.py
import sys
import os
import logging
import argparse
# Import pyspark and build Spark session
from pyspark.sql import SparkSession
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, MultilayerPerceptronClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml import Pipeline, Model
import sagemaker
from pyspark.sql.functions import lower, regexp_replace, col, concat_ws
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from pyspark.sql.functions import length

logging.basicConfig(format='%(asctime)s,%(levelname)s,%(module)s,%(filename)s,%(lineno)d,%(message)s', level=logging.DEBUG)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# NOTE(review): basicConfig already attached a stderr handler to the root
# logger, so each record is printed twice (stderr + stdout).
logger.addHandler(logging.StreamHandler(sys.stdout))

# Default S3 prefix for the fitted splits; the training scripts read
# <prefix>/train/ and <prefix>/test/.
DEFAULT_OUTPUT_PATH = "s3a://project-group34/project/submissions/preprocessed_ML"

# Subreddits kept for the binary classification task.
SUBREDDITS = ["movies", "anime"]

# Flag columns (cast to string in the notebook before this job) that are
# string-indexed and then one-hot encoded.
FLAG_COLUMNS = ["over_18", "is_self", "is_video", "has_media"]

# Numeric columns appended to the feature vector after the one-hot flags.
NUMERIC_COLUMNS = ["score", "num_comments", "post_length", "hour_of_day",
                   "day_of_week", "day_of_month", "month", "year"]


def build_preprocessing_pipeline():
    """Return the unfitted feature-engineering pipeline.

    Stage order: a StringIndexer for ``subreddit`` (label column
    ``subreddit_ix``), one StringIndexer per flag column, one OneHotEncoder
    per flag column, then a VectorAssembler that concatenates ``features``,
    the one-hot flag vectors and the numeric columns into
    ``combined_features``.
    """
    indexers = [StringIndexer(inputCol="subreddit", outputCol="subreddit_ix")]
    indexers += [StringIndexer(inputCol=c, outputCol=f"{c}_ix") for c in FLAG_COLUMNS]
    encoders = [OneHotEncoder(inputCol=f"{c}_ix", outputCol=f"{c}_vec") for c in FLAG_COLUMNS]
    assembler = VectorAssembler(
        inputCols=["features"] + [f"{c}_vec" for c in FLAG_COLUMNS] + NUMERIC_COLUMNS,
        outputCol="combined_features",
    )
    return Pipeline(stages=indexers + encoders + [assembler])


def main():
    """Filter, split and featurize the cleaned data; write train/test Parquet to S3."""
    parser = argparse.ArgumentParser(description="app inputs and outputs")
    parser.add_argument("--s3_dataset_path", type=str, help="Path of dataset in S3")
    parser.add_argument("--s3_output_path", type=str, default=DEFAULT_OUTPUT_PATH,
                        help="S3 prefix for the preprocessed train/ and test/ outputs")
    args = parser.parse_args()

    # NOTE(review): spark.driver.memory set from inside the application may be
    # ignored if the driver JVM is already running — confirm for this launcher.
    spark = (
        SparkSession.builder
        .appName("Spark ML")
        .config("spark.driver.memory", "16G")
        .config("spark.driver.maxResultSize", "0")
        .config("spark.kryoserializer.buffer.max", "2000M")
        .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.1.3")
        .getOrCreate()
    )
    logger.info(f"Spark version: {spark.version}")

    # Set the old-API (mapred) output committer to Hadoop's FileOutputCommitter.
    # NOTE(review): carried over from a CSV/RDD template; this job only writes
    # Parquet through the DataFrame API — confirm it is still needed.
    sc = spark.sparkContext
    sc._jsc.hadoopConfiguration().set(
        "mapred.output.committer.class", "org.apache.hadoop.mapred.FileOutputCommitter"
    )

    logger.info(f"going to read {args.s3_dataset_path}")
    # "header" is a CSV-only option and has no effect on Parquet, so it is not passed.
    rescaledData = spark.read.parquet(args.s3_dataset_path)
    rescaledData = rescaledData.where(col("subreddit").isin(SUBREDDITS))

    # Fixed-seed 80/15/5 split. The validation split is only counted here;
    # it is not written out.
    train_data, test_data, val_data = rescaledData.randomSplit([0.8, 0.15, 0.05], seed=1220)
    logger.info("Number of training records: " + str(train_data.count()))
    logger.info("Number of testing records: " + str(test_data.count()))
    logger.info("Number of validation records: " + str(val_data.count()))

    # Fit the indexers/encoders on the training split only, then apply to both splits.
    pipeline_fit = build_preprocessing_pipeline().fit(train_data)
    transformed_train_data = pipeline_fit.transform(train_data)
    transformed_test_data = pipeline_fit.transform(test_data)

    # mode="overwrite": the default mode errors when the path already exists,
    # which made the job fail on every re-run.
    output_path = args.s3_output_path.rstrip("/")
    transformed_train_data.write.parquet(f"{output_path}/train/", mode="overwrite")
    transformed_test_data.write.parquet(f"{output_path}/test/", mode="overwrite")
    logger.info("all done...")


if __name__ == "__main__":
    main()
Overwriting ./code/preprocess_ml.py
%%time
bucket = "project-group34"
s3_path = "s3a://project-group34/project/submissions/cleaned_ML/"

# Launch the preprocessing job; `arguments` are passed to the script as its
# command line (parsed there with argparse).
spark_processor.run(
    submit_app="./code/preprocess_ml.py",
    arguments=["--s3_dataset_path", s3_path],
    spark_event_logs_s3_uri=f"s3://{bucket}/{output_prefix_logs}/spark_event_logs",
    logs=False,
    configuration=configuration,
)
INFO:sagemaker:Creating processing-job with name sm-spark-project-2023-11-28-04-04-56-000
.....................................................................................................!CPU times: user 359 ms, sys: 45.3 ms, total: 404 ms Wall time: 8min 34s
%%writefile ./code/ml_train_RF1.py
import sys
import os
import logging
import argparse
# Import pyspark and build Spark session
from pyspark.sql import SparkSession
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, MultilayerPerceptronClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml import Pipeline, Model
import sagemaker
from pyspark.sql.functions import lower, regexp_replace, col, concat_ws
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from pyspark.sql.functions import length

logging.basicConfig(format='%(asctime)s,%(levelname)s,%(module)s,%(filename)s,%(lineno)d,%(message)s', level=logging.DEBUG)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))

# Output of the preprocessing job (train/ and test/ Parquet splits).
PREPROCESSED_PATH = "s3a://project-group34/project/submissions/preprocessed_ML"
# Destination for this run's predictions and fitted pipeline.
OUTPUT_PATH = "s3a://project-group34/project/submissions/RandomForest/default"


def subreddit_labels(df):
    """Return subreddit names ordered by their ``subreddit_ix`` value.

    ``subreddit_ix`` was produced by a StringIndexer in the preprocessing job
    with the default ordering (by descending frequency), so which subreddit
    gets index 0.0 depends on the data. Reading the mapping back keeps
    IndexToString consistent with it instead of hard-coding a guess.
    """
    rows = (
        df.select("subreddit", "subreddit_ix")
        .distinct()
        .orderBy("subreddit_ix")
        .collect()
    )
    return [row["subreddit"] for row in rows]


def main():
    """Train a RandomForest with default hyperparameters and save predictions + model."""
    spark = (
        SparkSession.builder
        .appName("PySparkApp")
        .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.2.2")
        # NOTE(review): without the "spark.hadoop." prefix Spark reports
        # "Ignoring non-Spark config property" for this key (seen in the
        # notebook session) — confirm which S3A credential provider is used.
        .config(
            "fs.s3a.aws.credentials.provider",
            "com.amazonaws.auth.ContainerCredentialsProvider",
        )
        .getOrCreate()
    )
    logger.info(f"Spark version: {spark.version}")

    # Set the old-API (mapred) output committer to Hadoop's FileOutputCommitter.
    # NOTE(review): carried over from a template; confirm it is still needed.
    sc = spark.sparkContext
    sc._jsc.hadoopConfiguration().set(
        "mapred.output.committer.class", "org.apache.hadoop.mapred.FileOutputCommitter"
    )

    transformed_train_data = spark.read.parquet(f"{PREPROCESSED_PATH}/train/")
    transformed_test_data = spark.read.parquet(f"{PREPROCESSED_PATH}/test/")

    # RandomForest with Spark's default hyperparameters.
    rf_classifier = RandomForestClassifier(labelCol="subreddit_ix",
                                           featuresCol="combined_features")

    # Map numeric predictions back to subreddit names using the index mapping
    # actually present in the data (the old hard-coded ["anime", "movie"]
    # misspelled "movies" and assumed anime was index 0.0).
    labels = subreddit_labels(transformed_train_data)
    logger.info(f"subreddit labels by index: {labels}")
    labelConverter = IndexToString(inputCol="prediction",
                                   outputCol="predictedSubreddit",
                                   labels=labels)

    pipeline = Pipeline(stages=[rf_classifier, labelConverter])
    pipeline_fit = pipeline.fit(transformed_train_data)

    train_predictions = pipeline_fit.transform(transformed_train_data)
    test_predictions = pipeline_fit.transform(transformed_test_data)

    train_predictions.write.parquet(f"{OUTPUT_PATH}/train_pred/", mode="overwrite")
    test_predictions.write.parquet(f"{OUTPUT_PATH}/test_pred/", mode="overwrite")
    # overwrite() so a re-run does not fail because the model directory exists,
    # consistent with the Parquet writes above.
    pipeline_fit.write().overwrite().save(f"{OUTPUT_PATH}/model/")
    logger.info("all done...")


if __name__ == "__main__":
    main()
Overwriting ./code/ml_train_RF1.py
%%time
bucket = "project-group34"

# Launch the default-hyperparameter Random Forest job (the script takes no
# command-line arguments).
spark_processor.run(
    submit_app="./code/ml_train_RF1.py",
    spark_event_logs_s3_uri=f"s3://{bucket}/{output_prefix_logs}/spark_event_logs",
    logs=False,
    configuration=configuration,
)
INFO:sagemaker:Creating processing-job with name sm-spark-project-2023-11-30-23-59-44-449
..............................................................................................................................................................................................................................................................................!CPU times: user 1 s, sys: 83.5 ms, total: 1.08 s Wall time: 22min 46s
%%writefile ./code/ml_train_RF2.py
import sys
import os
import logging
import argparse
# Import pyspark and build Spark session
from pyspark.sql import SparkSession
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, MultilayerPerceptronClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml import Pipeline, Model
import sagemaker
from pyspark.sql.functions import lower, regexp_replace, col, concat_ws
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from pyspark.sql.functions import length

logging.basicConfig(format='%(asctime)s,%(levelname)s,%(module)s,%(filename)s,%(lineno)d,%(message)s', level=logging.DEBUG)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))

# Output of the preprocessing job (train/ and test/ Parquet splits).
PREPROCESSED_PATH = "s3a://project-group34/project/submissions/preprocessed_ML"
# Destination for this run's predictions and fitted pipeline.
OUTPUT_PATH = "s3a://project-group34/project/submissions/RandomForest/numTrees=100_maxDepth=10_maxBins=64"


def subreddit_labels(df):
    """Return subreddit names ordered by their ``subreddit_ix`` value.

    ``subreddit_ix`` was produced by a StringIndexer in the preprocessing job
    with the default ordering (by descending frequency), so which subreddit
    gets index 0.0 depends on the data. Reading the mapping back keeps
    IndexToString consistent with it instead of hard-coding a guess.
    """
    rows = (
        df.select("subreddit", "subreddit_ix")
        .distinct()
        .orderBy("subreddit_ix")
        .collect()
    )
    return [row["subreddit"] for row in rows]


def main():
    """Train a RandomForest with fixed hyperparameters and save predictions + model."""
    spark = (
        SparkSession.builder
        .appName("PySparkApp")
        .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.2.2")
        # NOTE(review): without the "spark.hadoop." prefix Spark reports
        # "Ignoring non-Spark config property" for this key (seen in the
        # notebook session) — confirm which S3A credential provider is used.
        .config(
            "fs.s3a.aws.credentials.provider",
            "com.amazonaws.auth.ContainerCredentialsProvider",
        )
        .getOrCreate()
    )
    logger.info(f"Spark version: {spark.version}")

    # Set the old-API (mapred) output committer to Hadoop's FileOutputCommitter.
    # NOTE(review): carried over from a template; confirm it is still needed.
    sc = spark.sparkContext
    sc._jsc.hadoopConfiguration().set(
        "mapred.output.committer.class", "org.apache.hadoop.mapred.FileOutputCommitter"
    )

    transformed_train_data = spark.read.parquet(f"{PREPROCESSED_PATH}/train/")
    transformed_test_data = spark.read.parquet(f"{PREPROCESSED_PATH}/test/")

    # RandomForest with explicitly chosen hyperparameters (no search is run).
    rf_classifier = RandomForestClassifier(labelCol="subreddit_ix",
                                           featuresCol="combined_features",
                                           numTrees=100,
                                           maxDepth=10,
                                           maxBins=64)

    # Map numeric predictions back to subreddit names using the index mapping
    # actually present in the data (the old hard-coded ["anime", "movie"]
    # misspelled "movies" and assumed anime was index 0.0).
    labels = subreddit_labels(transformed_train_data)
    logger.info(f"subreddit labels by index: {labels}")
    labelConverter = IndexToString(inputCol="prediction",
                                   outputCol="predictedSubreddit",
                                   labels=labels)

    pipeline = Pipeline(stages=[rf_classifier, labelConverter])
    pipeline_fit = pipeline.fit(transformed_train_data)

    train_predictions = pipeline_fit.transform(transformed_train_data)
    test_predictions = pipeline_fit.transform(transformed_test_data)

    train_predictions.write.parquet(f"{OUTPUT_PATH}/train_pred/", mode="overwrite")
    test_predictions.write.parquet(f"{OUTPUT_PATH}/test_pred/", mode="overwrite")
    # overwrite() so a re-run does not fail because the model directory exists,
    # consistent with the Parquet writes above.
    pipeline_fit.write().overwrite().save(f"{OUTPUT_PATH}/model/")
    logger.info("all done...")


if __name__ == "__main__":
    main()
Overwriting ./code/ml_train_RF2.py
%%time
bucket = "project-group34"

# Launch the fixed-hyperparameter Random Forest job (the script takes no
# command-line arguments).
spark_processor.run(
    submit_app="./code/ml_train_RF2.py",
    spark_event_logs_s3_uri=f"s3://{bucket}/{output_prefix_logs}/spark_event_logs",
    logs=False,
    configuration=configuration,
)
INFO:sagemaker:Creating processing-job with name sm-spark-project-2023-11-28-23-33-31-511
....................................................................................................................................................................................................................................................................................................................................................................................................................!CPU times: user 1.44 s, sys: 106 ms, total: 1.54 s Wall time: 34min
# Build the local Spark session used for evaluation (Spark NLP + S3A support).
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("Spark NLP")
    .master("local[*]")
    .config("spark.driver.memory", "16G")
    # NOTE(review): in local[*] mode tasks run inside the driver JVM, so these
    # executor settings likely have no effect — confirm before relying on them.
    .config("spark.executor.memory", "12g")
    .config("spark.executor.cores", "3")
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.1.3,org.apache.hadoop:hadoop-aws:3.2.2")
    # Hadoop properties need the "spark.hadoop." prefix. The unprefixed key was
    # dropped with "Warning: Ignoring non-Spark config property".
    .config(
        "spark.hadoop.fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.ContainerCredentialsProvider"
    )
    .getOrCreate()
)
print(spark.version)
Warning: Ignoring non-Spark config property: fs.s3a.aws.credentials.provider
:: loading settings :: url = jar:file:/opt/conda/lib/python3.10/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml
Ivy Default Cache set to: /root/.ivy2/cache The jars for the packages stored in: /root/.ivy2/jars com.johnsnowlabs.nlp#spark-nlp_2.12 added as a dependency org.apache.hadoop#hadoop-aws added as a dependency :: resolving dependencies :: org.apache.spark#spark-submit-parent-ebc720bb-2fd4-4933-9da9-6508d62eae59;1.0 confs: [default] found com.johnsnowlabs.nlp#spark-nlp_2.12;5.1.3 in central found com.typesafe#config;1.4.2 in central found org.rocksdb#rocksdbjni;6.29.5 in central found com.amazonaws#aws-java-sdk-bundle;1.11.828 in central found com.github.universal-automata#liblevenshtein;3.0.0 in central found com.google.protobuf#protobuf-java-util;3.0.0-beta-3 in central found com.google.protobuf#protobuf-java;3.0.0-beta-3 in central found com.google.code.gson#gson;2.3 in central found it.unimi.dsi#fastutil;7.0.12 in central found org.projectlombok#lombok;1.16.8 in central found com.google.cloud#google-cloud-storage;2.20.1 in central found com.google.guava#guava;31.1-jre in central found com.google.guava#failureaccess;1.0.1 in central found com.google.guava#listenablefuture;9999.0-empty-to-avoid-conflict-with-guava in central found com.google.errorprone#error_prone_annotations;2.18.0 in central found com.google.j2objc#j2objc-annotations;1.3 in central found com.google.http-client#google-http-client;1.43.0 in central found io.opencensus#opencensus-contrib-http-util;0.31.1 in central found com.google.http-client#google-http-client-jackson2;1.43.0 in central found com.google.http-client#google-http-client-gson;1.43.0 in central found com.google.api-client#google-api-client;2.2.0 in central found commons-codec#commons-codec;1.15 in central found com.google.oauth-client#google-oauth-client;1.34.1 in central found com.google.http-client#google-http-client-apache-v2;1.43.0 in central found com.google.apis#google-api-services-storage;v1-rev20220705-2.0.0 in central found com.google.code.gson#gson;2.10.1 in central found com.google.cloud#google-cloud-core;2.12.0 in central 
found io.grpc#grpc-context;1.53.0 in central found com.google.auto.value#auto-value-annotations;1.10.1 in central found com.google.auto.value#auto-value;1.10.1 in central found javax.annotation#javax.annotation-api;1.3.2 in central found commons-logging#commons-logging;1.2 in central found com.google.cloud#google-cloud-core-http;2.12.0 in central found com.google.http-client#google-http-client-appengine;1.43.0 in central found com.google.api#gax-httpjson;0.108.2 in central found com.google.cloud#google-cloud-core-grpc;2.12.0 in central found io.grpc#grpc-alts;1.53.0 in central found io.grpc#grpc-grpclb;1.53.0 in central found org.conscrypt#conscrypt-openjdk-uber;2.5.2 in central found io.grpc#grpc-auth;1.53.0 in central found io.grpc#grpc-protobuf;1.53.0 in central found io.grpc#grpc-protobuf-lite;1.53.0 in central found io.grpc#grpc-core;1.53.0 in central found com.google.api#gax;2.23.2 in central found com.google.api#gax-grpc;2.23.2 in central found com.google.auth#google-auth-library-credentials;1.16.0 in central found com.google.auth#google-auth-library-oauth2-http;1.16.0 in central found com.google.api#api-common;2.6.2 in central found io.opencensus#opencensus-api;0.31.1 in central found com.google.api.grpc#proto-google-iam-v1;1.9.2 in central found com.google.protobuf#protobuf-java;3.21.12 in central found com.google.protobuf#protobuf-java-util;3.21.12 in central found com.google.api.grpc#proto-google-common-protos;2.14.2 in central found org.threeten#threetenbp;1.6.5 in central found com.google.api.grpc#proto-google-cloud-storage-v2;2.20.1-alpha in central found com.google.api.grpc#grpc-google-cloud-storage-v2;2.20.1-alpha in central found com.google.api.grpc#gapic-google-cloud-storage-v2;2.20.1-alpha in central found com.fasterxml.jackson.core#jackson-core;2.14.2 in central found com.google.code.findbugs#jsr305;3.0.2 in central found io.grpc#grpc-api;1.53.0 in central found io.grpc#grpc-stub;1.53.0 in central found org.checkerframework#checker-qual;3.31.0 
in central found io.perfmark#perfmark-api;0.26.0 in central found com.google.android#annotations;4.1.1.4 in central found org.codehaus.mojo#animal-sniffer-annotations;1.22 in central found io.opencensus#opencensus-proto;0.2.0 in central found io.grpc#grpc-services;1.53.0 in central found com.google.re2j#re2j;1.6 in central found io.grpc#grpc-netty-shaded;1.53.0 in central found io.grpc#grpc-googleapis;1.53.0 in central found io.grpc#grpc-xds;1.53.0 in central found com.navigamez#greex;1.0 in central found dk.brics.automaton#automaton;1.11-8 in central found com.johnsnowlabs.nlp#tensorflow-cpu_2.12;0.4.4 in central found com.microsoft.onnxruntime#onnxruntime;1.15.0 in central found org.apache.hadoop#hadoop-aws;3.2.2 in central :: resolution report :: resolve 4860ms :: artifacts dl 732ms :: modules in use: com.amazonaws#aws-java-sdk-bundle;1.11.828 from central in [default] com.fasterxml.jackson.core#jackson-core;2.14.2 from central in [default] com.github.universal-automata#liblevenshtein;3.0.0 from central in [default] com.google.android#annotations;4.1.1.4 from central in [default] com.google.api#api-common;2.6.2 from central in [default] com.google.api#gax;2.23.2 from central in [default] com.google.api#gax-grpc;2.23.2 from central in [default] com.google.api#gax-httpjson;0.108.2 from central in [default] com.google.api-client#google-api-client;2.2.0 from central in [default] com.google.api.grpc#gapic-google-cloud-storage-v2;2.20.1-alpha from central in [default] com.google.api.grpc#grpc-google-cloud-storage-v2;2.20.1-alpha from central in [default] com.google.api.grpc#proto-google-cloud-storage-v2;2.20.1-alpha from central in [default] com.google.api.grpc#proto-google-common-protos;2.14.2 from central in [default] com.google.api.grpc#proto-google-iam-v1;1.9.2 from central in [default] com.google.apis#google-api-services-storage;v1-rev20220705-2.0.0 from central in [default] com.google.auth#google-auth-library-credentials;1.16.0 from central in [default] 
com.google.auth#google-auth-library-oauth2-http;1.16.0 from central in [default] com.google.auto.value#auto-value;1.10.1 from central in [default] com.google.auto.value#auto-value-annotations;1.10.1 from central in [default] com.google.cloud#google-cloud-core;2.12.0 from central in [default] com.google.cloud#google-cloud-core-grpc;2.12.0 from central in [default] com.google.cloud#google-cloud-core-http;2.12.0 from central in [default] com.google.cloud#google-cloud-storage;2.20.1 from central in [default] com.google.code.findbugs#jsr305;3.0.2 from central in [default] com.google.code.gson#gson;2.10.1 from central in [default] com.google.errorprone#error_prone_annotations;2.18.0 from central in [default] com.google.guava#failureaccess;1.0.1 from central in [default] com.google.guava#guava;31.1-jre from central in [default] com.google.guava#listenablefuture;9999.0-empty-to-avoid-conflict-with-guava from central in [default] com.google.http-client#google-http-client;1.43.0 from central in [default] com.google.http-client#google-http-client-apache-v2;1.43.0 from central in [default] com.google.http-client#google-http-client-appengine;1.43.0 from central in [default] com.google.http-client#google-http-client-gson;1.43.0 from central in [default] com.google.http-client#google-http-client-jackson2;1.43.0 from central in [default] com.google.j2objc#j2objc-annotations;1.3 from central in [default] com.google.oauth-client#google-oauth-client;1.34.1 from central in [default] com.google.protobuf#protobuf-java;3.21.12 from central in [default] com.google.protobuf#protobuf-java-util;3.21.12 from central in [default] com.google.re2j#re2j;1.6 from central in [default] com.johnsnowlabs.nlp#spark-nlp_2.12;5.1.3 from central in [default] com.johnsnowlabs.nlp#tensorflow-cpu_2.12;0.4.4 from central in [default] com.microsoft.onnxruntime#onnxruntime;1.15.0 from central in [default] com.navigamez#greex;1.0 from central in [default] com.typesafe#config;1.4.2 from central in [default] 
commons-codec#commons-codec;1.15 from central in [default] commons-logging#commons-logging;1.2 from central in [default] dk.brics.automaton#automaton;1.11-8 from central in [default] io.grpc#grpc-alts;1.53.0 from central in [default] io.grpc#grpc-api;1.53.0 from central in [default] io.grpc#grpc-auth;1.53.0 from central in [default] io.grpc#grpc-context;1.53.0 from central in [default] io.grpc#grpc-core;1.53.0 from central in [default] io.grpc#grpc-googleapis;1.53.0 from central in [default] io.grpc#grpc-grpclb;1.53.0 from central in [default] io.grpc#grpc-netty-shaded;1.53.0 from central in [default] io.grpc#grpc-protobuf;1.53.0 from central in [default] io.grpc#grpc-protobuf-lite;1.53.0 from central in [default] io.grpc#grpc-services;1.53.0 from central in [default] io.grpc#grpc-stub;1.53.0 from central in [default] io.grpc#grpc-xds;1.53.0 from central in [default] io.opencensus#opencensus-api;0.31.1 from central in [default] io.opencensus#opencensus-contrib-http-util;0.31.1 from central in [default] io.opencensus#opencensus-proto;0.2.0 from central in [default] io.perfmark#perfmark-api;0.26.0 from central in [default] it.unimi.dsi#fastutil;7.0.12 from central in [default] javax.annotation#javax.annotation-api;1.3.2 from central in [default] org.apache.hadoop#hadoop-aws;3.2.2 from central in [default] org.checkerframework#checker-qual;3.31.0 from central in [default] org.codehaus.mojo#animal-sniffer-annotations;1.22 from central in [default] org.conscrypt#conscrypt-openjdk-uber;2.5.2 from central in [default] org.projectlombok#lombok;1.16.8 from central in [default] org.rocksdb#rocksdbjni;6.29.5 from central in [default] org.threeten#threetenbp;1.6.5 from central in [default] :: evicted modules: com.google.protobuf#protobuf-java-util;3.0.0-beta-3 by [com.google.protobuf#protobuf-java-util;3.21.12] in [default] com.google.protobuf#protobuf-java;3.0.0-beta-3 by [com.google.protobuf#protobuf-java;3.21.12] in [default] com.google.code.gson#gson;2.3 by 
[com.google.code.gson#gson;2.10.1] in [default] com.amazonaws#aws-java-sdk-bundle;1.11.563 by [com.amazonaws#aws-java-sdk-bundle;1.11.828] in [default] --------------------------------------------------------------------- | | modules || artifacts | | conf | number| search|dwnlded|evicted|| number|dwnlded| --------------------------------------------------------------------- | default | 77 | 0 | 0 | 4 || 73 | 0 | --------------------------------------------------------------------- :: retrieving :: org.apache.spark#spark-submit-parent-ebc720bb-2fd4-4933-9da9-6508d62eae59 confs: [default] 0 artifacts copied, 73 already retrieved (0kB/236ms) 23/12/01 00:24:05 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
3.4.0
import sagemaker
from pyspark.sql.functions import lower, regexp_replace, col, concat_ws
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from sparknlp.annotator import *
from sparknlp.base import *
import sparknlp
from sparknlp.pretrained import PretrainedPipeline
from sparknlp.base import Finisher, DocumentAssembler
from pyspark.sql.functions import length
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, MultilayerPerceptronClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml import Pipeline, Model
# Train-split predictions written by the two Random Forest training jobs.
transformed_train_RF = spark.read.parquet(
    "s3a://project-group34/project/submissions/RandomForest/default/train_pred/"
)
transformed_train_RF_withHyperparams = spark.read.parquet(
    "s3a://project-group34/project/submissions/RandomForest/numTrees=100_maxDepth=10_maxBins=64/train_pred/"
)
23/12/01 00:26:18 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
# Test-split predictions written by the two Random Forest training jobs.
transformed_test_RF = spark.read.parquet(
    "s3a://project-group34/project/submissions/RandomForest/default/test_pred/"
)
transformed_test_RF_withHyperparams = spark.read.parquet(
    "s3a://project-group34/project/submissions/RandomForest/numTrees=100_maxDepth=10_maxBins=64/test_pred/"
)

# Plain accuracy for each model on each split, evaluated in this order.
evaluator = MulticlassClassificationEvaluator(
    labelCol="subreddit_ix", predictionCol="prediction", metricName="accuracy"
)
(
    accuracy_RF_train,
    accuracy_RF_test,
    accuracy_RF_withHyperparams_train,
    accuracy_RF_withHyperparams_test,
) = [
    evaluator.evaluate(predictions)
    for predictions in (
        transformed_train_RF,
        transformed_test_RF,
        transformed_train_RF_withHyperparams,
        transformed_test_RF_withHyperparams,
    )
]
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import pandas as pd
# Classification metrics shared by every model comparison below.
def compute_metrics(dataset, label_col, prediction_col):
    """Compute accuracy, F1, weighted precision and weighted recall.

    Args:
        dataset: Spark DataFrame containing at least ``label_col`` and
            ``prediction_col``.
        label_col: name of the true-label column.
        prediction_col: name of the predicted-label column.

    Returns:
        Tuple ``(accuracy, f1, precision, recall)``, where precision and recall
        are the evaluator's ``weightedPrecision`` / ``weightedRecall``.
    """
    evaluator = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol=prediction_col)
    # Every evaluate() call is its own Spark job. Cache only the two columns
    # the evaluator needs so the source (Parquet on S3 here) is scanned once
    # instead of four times.
    scored = dataset.select(label_col, prediction_col).cache()
    try:
        return tuple(
            evaluator.evaluate(scored, {evaluator.metricName: metric})
            for metric in ("accuracy", "f1", "weightedPrecision", "weightedRecall")
        )
    finally:
        scored.unpersist()
# Default-hyperparameter model: metrics on the train and test predictions.
(accuracy_RF_train, f1_RF_train,
 precision_RF_train, recall_RF_train) = compute_metrics(transformed_train_RF, "subreddit_ix", "prediction")
(accuracy_RF_test, f1_RF_test,
 precision_RF_test, recall_RF_test) = compute_metrics(transformed_test_RF, "subreddit_ix", "prediction")

# Fixed-hyperparameter model (numTrees=100, maxDepth=10, maxBins=64).
(accuracy_RF_hp_train, f1_RF_hp_train,
 precision_RF_hp_train, recall_RF_hp_train) = compute_metrics(transformed_train_RF_withHyperparams, "subreddit_ix", "prediction")
(accuracy_RF_hp_test, f1_RF_hp_test,
 precision_RF_hp_test, recall_RF_hp_test) = compute_metrics(transformed_test_RF_withHyperparams, "subreddit_ix", "prediction")

# One row per model; each metric gets a train column followed by a test column.
metrics_data = {'Model': ['RandomForest Default', 'RandomForest Hyperparameters']}
metrics_data.update({
    'Accuracy Train': [accuracy_RF_train, accuracy_RF_hp_train],
    'Accuracy Test': [accuracy_RF_test, accuracy_RF_hp_test],
    'F1 Score Train': [f1_RF_train, f1_RF_hp_train],
    'F1 Score Test': [f1_RF_test, f1_RF_hp_test],
    'Precision Train': [precision_RF_train, precision_RF_hp_train],
    'Precision Test': [precision_RF_test, precision_RF_hp_test],
    'Recall Train': [recall_RF_train, recall_RF_hp_train],
    'Recall Test': [recall_RF_test, recall_RF_hp_test],
})
metrics_df = pd.DataFrame(metrics_data)
metrics_df
Model | Accuracy Train | Accuracy Test | F1 Score Train | F1 Score Test | Precision Train | Precision Test | Recall Train | Recall Test | |
---|---|---|---|---|---|---|---|---|---|
0 | RandomForest Default | 0.676122 | 0.675532 | 0.642950 | 0.642239 | 0.760997 | 0.760339 | 0.676122 | 0.675532 |
1 | RandomForest Hyperparameters | 0.760637 | 0.759841 | 0.745298 | 0.744455 | 0.829303 | 0.828266 | 0.760637 | 0.759841 |
from sklearn.metrics import confusion_matrix
# Collect true labels and predictions in ONE action per model. Two separate
# collect() calls give no guarantee that row i of one list matches row i of
# the other; a single select keeps each (label, prediction) pair together and
# also halves the number of scans.
rows_RF_test = transformed_test_RF.select("subreddit_ix", "prediction").collect()
y_orig_RF_test = [row['subreddit_ix'] for row in rows_RF_test]
y_pred_RF_test = [row['prediction'] for row in rows_RF_test]

rows_RF_hp_test = transformed_test_RF_withHyperparams.select("subreddit_ix", "prediction").collect()
y_orig_RF_hp_test = [row['subreddit_ix'] for row in rows_RF_hp_test]
y_pred_RF_hp_test = [row['prediction'] for row in rows_RF_hp_test]

# Rows = true label, columns = predicted label.
cm_RF = confusion_matrix(y_orig_RF_test, y_pred_RF_test)
cm_RF_hp = confusion_matrix(y_orig_RF_hp_test, y_pred_RF_hp_test)
import matplotlib.pyplot as plt
import seaborn as sns
def _plot_confusion_matrix(cm, class_names, title, path):
    """Draw `cm` as an annotated heatmap, save it to `path` at 300 dpi, and show it."""
    plt.figure(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt="d", cmap=sns.color_palette("YlOrBr", as_cmap=True),
                yticklabels=class_names, xticklabels=class_names)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title(title)
    plt.savefig(path, dpi=300)
    plt.show()


# confusion_matrix orders rows/columns by sorted label value, i.e. by
# subreddit_ix. Which subreddit the StringIndexer mapped to 0.0 depends on the
# data, so read the mapping back instead of hard-coding ["Anime", "Movie"].
# Both models were scored on the same preprocessed test split.
class_names = [
    row["subreddit"].capitalize()
    for row in (
        transformed_test_RF.select("subreddit", "subreddit_ix")
        .distinct()
        .orderBy("subreddit_ix")
        .collect()
    )
]

_plot_confusion_matrix(
    cm_RF, class_names,
    'Confusion Matrix for Random Forest Model with default Hyperparameters',
    "../../data/plots/Heatmap1_CM_Classification_Model1.png",
)
_plot_confusion_matrix(
    cm_RF_hp, class_names,
    'Confusion Matrix for Random Forest Model with Hyperparameters',
    "../../data/plots/Heatmap2_CM_Classification_Model1.png",
)
# Show the columns carried through preprocessing and prediction
# (input features, *_ix / *_vec encodings, rawPrediction, probability, prediction).
transformed_test_RF.printSchema()
root |-- subreddit: string (nullable = true) |-- score: long (nullable = true) |-- num_comments: long (nullable = true) |-- over_18: string (nullable = true) |-- is_self: string (nullable = true) |-- is_video: string (nullable = true) |-- domain: string (nullable = true) |-- post_length: integer (nullable = true) |-- hour_of_day: integer (nullable = true) |-- day_of_week: integer (nullable = true) |-- day_of_month: integer (nullable = true) |-- month: integer (nullable = true) |-- year: integer (nullable = true) |-- has_media: string (nullable = true) |-- body: string (nullable = true) |-- features: vector (nullable = true) |-- subreddit_ix: double (nullable = true) |-- over_18_ix: double (nullable = true) |-- is_self_ix: double (nullable = true) |-- is_video_ix: double (nullable = true) |-- has_media_ix: double (nullable = true) |-- over_18_vec: vector (nullable = true) |-- is_self_vec: vector (nullable = true) |-- is_video_vec: vector (nullable = true) |-- has_media_vec: vector (nullable = true) |-- combined_features: vector (nullable = true) |-- rawPrediction: vector (nullable = true) |-- probability: vector (nullable = true) |-- prediction: double (nullable = true) |-- predictedSubreddit: string (nullable = true)
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
import pandas as pd
from pyspark.ml.functions import vector_to_array


def get_probability(col):
    """Return a Column holding element 1 of the vector column *col*.

    Replaces the previous Python UDF: ``vector_to_array`` runs inside the JVM
    (no per-row Python round-trip) and keeps the value as a double instead of
    truncating it to a 32-bit FloatType. Call sites such as
    ``get_probability("probability")`` are unchanged.
    """
    # NOTE(review): index 1 is treated as the "positive" class; it corresponds
    # to subreddit_ix == 1.0 — confirm which subreddit that is.
    return vector_to_array(col)[1]


# Add a column with the class-1 probability for both RF variants
roc_data_RF = transformed_test_RF.withColumn("positive_probability", get_probability("probability"))
roc_data_RF_hp = transformed_test_RF_withHyperparams.withColumn("positive_probability", get_probability("probability"))
# Collect only the two columns needed for the ROC computation to the driver
roc_data_RF = roc_data_RF.select("positive_probability", "subreddit_ix").toPandas()
roc_data_RF_hp = roc_data_RF_hp.select("positive_probability", "subreddit_ix").toPandas()
# Scores and true (indexed) labels
y_probs_RF = roc_data_RF['positive_probability']
y_orig_RF = roc_data_RF['subreddit_ix']
y_probs_RF_hp = roc_data_RF_hp['positive_probability']
y_orig_RF_hp = roc_data_RF_hp['subreddit_ix']
from sklearn.metrics import roc_curve, auc

# ROC curve for the RF model with default hyperparameters, computed from the
# true labels (y_orig_RF) and class-1 probabilities (y_probs_RF).
fpr, tpr, thresholds = roc_curve(y_orig_RF, y_probs_RF)
# Area under the ROC curve
roc_auc_RF = auc(fpr, tpr)

plt.figure(figsize=(10, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc_RF:.2f})')
# Chance-level reference diagonal
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve for RF Model with default hyperparameters')
plt.legend(loc="lower right")
# Save into the shared plots directory like every other figure in this notebook
# (previously this one landed in the current working directory).
plt.savefig("../../data/plots/ROC1_Classification_Model1.png", dpi=300)
plt.show()
from sklearn.metrics import roc_curve, auc

# ROC analysis for the RF model with hyperparameters: true labels
# y_orig_RF_hp against class-1 probabilities y_probs_RF_hp.
fpr, tpr, thresholds = roc_curve(y_orig_RF_hp, y_probs_RF_hp)
roc_auc_RF_hp = auc(fpr, tpr)  # area under the (fpr, tpr) curve
curve_label = f'ROC curve (area = {roc_auc_RF_hp:.2f})'

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=curve_label)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance line
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])  # slight headroom above TPR = 1
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.title('ROC Curve for RF Model with hyperparameters')
plt.legend(loc="lower right")
plt.savefig("../../data/plots/ROC2_Classification_Model1.png", dpi=300)
plt.show()