{ "cells": [ { "cell_type": "markdown", "id": "65a280c6-168a-4b92-92b3-e402e2104e99", "metadata": {}, "source": [ "# Data Preparation and Cleaning\n", "\n", "This notebook showcases a few general methods to clean up your training data. The better the data, the more accurate the model tends to be. What counts as \"good data\" depends on your application, but some general guidelines include:\n", "- Keeping your data formatting consistent (binary boolean values, integers vs. floats, case sensitivity)\n", "- Clearing duplicates\n", "- Filtering out non-applicable outliers\n", "- etc.\n", "***" ] }, { "cell_type": "code", "execution_count": 3, "id": "160c85aa-26e5-4d4c-a502-7eddc8734ac6", "metadata": {}, "outputs": [], "source": [ "##### Import packages\n", "\n", "import pandas as pd\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 4, "id": "bd9e76ff-2ab0-4f6e-89cf-a9bbd92870b0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " R G B colour\n", "0 81.0 42.0 173.0 PURPLE\n", "1 222.0 9.0 73.0 RED\n", "2 59.0 188.0 227.0 BLUE\n", "3 14.0 158.0 29.0 GREEN\n", "4 222.0 222.0 82.0 YELLOW\n", ".. ... ... ... ...\n", "504 139.0 140.0 122.0 Grey\n", "505 189.0 236.0 182.0 Light Green\n", "506 198.0 166.0 100.0 Brown\n", "507 59.0 131.0 189.0 Blue\n", "508 130.0 137.0 143.0 Grey\n", "\n", "[509 rows x 4 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "##### Import the dataset from a partially cleaned CSV file\n", "# A pandas DataFrame is easier to work with than a raw CSV file.\n", "\n", "dataset = pd.read_csv(\"combined_data.csv\")\n", "dataset" ] }, { "cell_type": "code", "execution_count": 5, "id": "b7dc68aa-5060-4a9c-a1b5-c503a58d22ec", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "colour\n", "green 17\n", "BLUE 15\n", "blue 15\n", "Green 13\n", "PURPLE 12\n", " ..\n", "pastel green 1\n", "Aphroditean Fuchsia 1\n", "Gorgonzola Blue 1\n", "Whiskey Sour 1\n", "Light Green 1\n", "Name: count, Length: 256, dtype: int64" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "##### Check the number of occurrences of each color in the dataset\n", "# This helps us identify how we want to start cleaning our dataset\n", "\n", "dataset[\"colour\"].value_counts()" ] }, { "cell_type": "markdown", "id": "0a3dc818-f5b9-4770-8cdc-a8e344fa0573", "metadata": {}, "source": [ "We first notice that colors are double-counted because of case discrepancies (for example, `BLUE` and `blue`). Let's standardize our color names by making them all lowercase." ] }, { "cell_type": "code", "execution_count": 6, "id": "fd81ad6c-2075-44f9-8f26-46223d0379cc", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " R G B colour\n", "0 81.0 42.0 173.0 purple\n", "1 222.0 9.0 73.0 red\n", "2 59.0 188.0 227.0 blue\n", "3 14.0 158.0 29.0 green\n", "4 222.0 222.0 82.0 yellow\n", ".. ... ... ... ...\n", "504 139.0 140.0 122.0 grey\n", "505 189.0 236.0 182.0 light green\n", "506 198.0 166.0 100.0 brown\n", "507 59.0 131.0 189.0 blue\n", "508 130.0 137.0 143.0 grey\n", "\n", "[509 rows x 4 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "##### Case Folding -> to lowercase\n", "# We select the 'colour' column and apply pandas' str.lower() method.\n", "# This serves to standardize the format of our color names.\n", "\n", "dataset['colour'] = dataset['colour'].str.lower()\n", "dataset" ] }, { "cell_type": "markdown", "id": "54c8d336-fbe7-4c43-882f-2df55afe1b15", "metadata": {}, "source": [ "***\n", "We have an issue: some of our data is in **floating point RGB (0-1)** notation and not **integer RGB (0-255)**! Although these entries are in the wrong format, we don't want to simply discard them. The former is a *normalized* RGB color vector, which can be converted by multiplying each floating-point value by 255 to bring it to the same scale.\n", "\n", "Ideally, a transformation should not depend on hard-coded row numbers; the row range used below is a set of magic numbers. As an exercise, try to implement a cleaning algorithm that converts normalized RGB values without knowing which rows they are in.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "44c94de2-059f-42b0-b2a6-16384411d3a7", "metadata": {}, "outputs": [], "source": [ "##### Denormalize the RGB values across the selected rows\n", "# This keeps our RGB values consistent throughout the whole dataset.\n", "# Note: .loc[221:261] is label-based and includes row 261, while dataset[221:261] is positional and stops at row 260.\n", "# Row 261 therefore has no matching value on the right-hand side and ends up as NaN.\n", "\n", "dataset.loc[221:261, ['R','G','B'] ] = dataset[221:261][ ['R','G','B'] ].apply(lambda x: x * 255)" ] }, { "cell_type": "code", "execution_count": 8, "id": "81ef46b8-8c92-47e4-b828-ad9e6dd1bb2e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " R G B colour\n", "221 153.000 151.980 80.070 moss green\n", "222 171.105 109.905 60.945 browny orange \n", "223 255.000 186.915 33.915 mustard yellow \n", "224 95.115 255.000 33.915 flashy green \n", "225 32.895 83.895 13.005 lime green \n", "226 58.905 82.110 49.980 navy \n", "227 173.910 225.930 236.895 light blue \n", "228 204.000 176.970 236.895 light purple \n", "229 235.110 4.080 186.915 flashy pink \n", "230 54.060 31.110 48.960 dark purple \n", "231 134.895 236.895 213.945 cyan \n", "232 255.000 150.960 88.995 light orange \n", "233 255.000 183.090 198.900 light pink \n", "234 236.895 87.975 121.125 watermelon pink \n", "235 168.045 32.895 60.945 maroon\n", "236 172.890 250.920 82.110 lime green\n", "237 255.000 88.995 235.110 magenta\n", "238 31.110 166.005 82.110 dark green\n", "239 31.110 26.010 147.900 dark blue\n", "240 241.995 185.895 0.000 yellow\n", "241 241.995 13.005 37.995 red\n", "242 83.895 198.900 255.000 light blue\n", "243 255.000 232.050 194.055 light tan\n", "244 47.940 45.900 41.055 black\n", "245 255.000 255.000 160.905 light yellow\n", "246 255.000 255.000 255.000 white\n", "247 251.940 249.900 245.055 white\n", "248 160.905 166.005 162.945 grey\n", "249 45.900 98.940 29.070 conifer green\n", "250 37.995 6.885 172.890 deep purple\n", "251 186.150 0.000 0.000 blood red\n", "252 93.075 198.900 255.000 sky blue\n", "253 205.020 237.915 255.000 ice blue\n", "254 255.000 122.400 0.000 cheeto orange\n", "255 0.000 58.905 109.905 marine blue\n", "256 92.055 103.020 111.945 dark gray\n", "257 222.105 208.080 175.950 sand color\n", "258 64.005 52.020 46.920 dark brown\n", "259 224.910 186.915 255.000 lavender\n", "260 255.000 0.000 158.100 fuschiar" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[221:261]" ] }, { "cell_type": "markdown", "id": "41a2afd4-1dbc-4e99-84ca-61b1e4ad3190", "metadata": {}, "source": [ "***\n", "We want to decrease the number of possible 
classes.\n", "-> Notice the pattern in our dataset: composite color names often have a basic color as their last word (e.g. Cheeto ***Orange***).\n", "We can therefore truncate every composite color name to *only* its last word." ] }, { "cell_type": "code", "execution_count": 9, "id": "7b1dd4c2-36ca-4871-849d-5d1a4a897557", "metadata": {}, "outputs": [], "source": [ "##### Keep last word of every color\n", "# This simplifies the composite names\n", "# We overwrite the column colour with the corrected entries\n", "\n", "dataset['colour'] = dataset['colour'].str.split().str[-1]" ] }, { "cell_type": "markdown", "id": "8b0d4da3-86d2-4447-8f5d-56cb16d5d540", "metadata": {}, "source": [ "***\n", "Let's take a look at the number of unique colors now!" ] }, { "cell_type": "code", "execution_count": 10, "id": "9c82d26b-15b4-40a0-954a-4e0f0fabb230", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "colour\n", "green 75\n", "blue 73\n", "pink 36\n", "purple 35\n", "red 35\n", " ..\n", "drab 1\n", "peru 1\n", "sunflower 1\n", "whip 1\n", "gurgundy 1\n", "Name: count, Length: 92, dtype: int64" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[\"colour\"].value_counts()" ] }, { "cell_type": "markdown", "id": "15d38bf7-5146-4343-9ba9-f721babe4e33", "metadata": {}, "source": [ "That leaves 92 unique colours, but we can do better. It could be beneficial to remove the colors that appear only once, because many of these singletons (such as `whip` or `drab`) carry little meaning as class labels.\n", "\n", "Going further, we keep only the colors that appear more than **7** times, so that only the 'recognizable' colors remain." 
] }, { "cell_type": "code", "execution_count": 11, "id": "695ec114-17cb-4f45-8e82-a719e862cb5c", "metadata": {}, "outputs": [], "source": [ "##### Count how often every color appears in our dataset\n", "# value_counts() returns a pandas Series indexed by color name, which we use to select our desired colors\n", "\n", "value_counts = dataset['colour'].value_counts()" ] }, { "cell_type": "code", "execution_count": 12, "id": "3e603bed-9a3e-434b-b83e-fb3af4893479", "metadata": {}, "outputs": [], "source": [ "##### Keep every color in the Series that appears >7 times\n", "# valid_values is the Index of color names that pass this threshold\n", "\n", "valid_values = value_counts[value_counts > 7].index" ] }, { "cell_type": "code", "execution_count": 13, "id": "01c35880-e214-4233-b759-1d5393e2fac9", "metadata": {}, "outputs": [], "source": [ "##### Get the newly filtered dataset\n", "# We obtain it by keeping only the rows whose colour appears in the list of valid colors.\n", "\n", "filtered_dataset = dataset[dataset['colour'].isin(valid_values)]" ] }, { "cell_type": "code", "execution_count": 14, "id": "6a2c5862-344d-4cc5-aeaf-e30d6e820a09", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " R G B colour\n", "0 81.000 42.000 173.000 purple\n", "1 222.000 9.000 73.000 red\n", "2 59.000 188.000 227.000 blue\n", "3 14.000 158.000 29.000 green\n", "4 222.000 222.000 82.000 yellow\n", "5 222.000 154.000 82.000 orange\n", "6 240.000 137.000 26.000 orange\n", "7 54.000 181.000 126.000 green\n", "8 16.000 115.000 130.000 blue\n", "9 96.000 126.000 181.000 blue\n", "10 222.000 160.000 194.000 pink\n", "11 205.000 143.000 219.000 purple\n", "12 143.000 219.000 177.000 green\n", "13 143.000 176.000 219.000 blue\n", "14 215.000 219.000 143.000 yellow\n", "15 135.000 206.000 235.000 blue\n", "16 255.000 97.000 56.000 orange\n", "17 255.000 99.000 71.000 red\n", "18 120.000 81.000 169.000 purple\n", "19 255.000 20.000 147.000 pink\n", "20 199.000 77.000 222.000 pink\n", "21 222.000 200.000 77.000 
yellow\n", "22 133.000 266.000 233.000 blue\n", "23 52.000 29.000 107.000 purple\n", "24 50.000 20.000 100.000 purple\n", "25 100.000 20.000 70.000 pink\n", "26 89.000 102.000 201.000 purple\n", "27 3.000 61.000 252.000 blue\n", "28 30.000 77.000 2.000 green\n", "29 176.000 125.000 174.000 pink\n", "30 57.000 77.000 11.000 green\n", "31 44.000 33.000 88.000 purple\n", "32 88.000 11.000 333.000 purple\n", "33 123.000 45.000 67.000 red\n", "34 89.000 156.000 233.000 blue\n", "35 210.000 85.000 30.000 orange\n", "36 255.000 165.000 0.000 orange\n", "37 255.000 192.000 203.000 pink\n", "38 60.000 179.000 113.000 green\n", "39 72.000 61.000 139.000 blue\n", "40 48.000 160.000 99.000 green\n", "41 16.000 201.000 100.000 green\n", "42 23.000 204.000 218.000 blue\n", "43 76.000 34.000 106.000 purple\n", "44 132.000 43.000 195.000 purple\n", "45 43.000 195.000 190.000 turquoise\n", "46 81.000 216.000 211.000 turquoise\n", "47 9.000 104.000 101.000 turquoise\n", "48 128.000 210.000 207.000 turquoise\n", "49 128.000 155.000 210.000 blue\n", "50 16.000 56.000 133.000 blue\n", "51 151.000 172.000 212.000 blue\n", "52 162.000 207.000 0.000 green\n", "53 174.000 197.000 90.000 green\n", "54 207.000 202.000 13.000 yellow\n", "55 224.000 221.000 116.000 yellow\n", "56 220.000 228.000 18.000 yellow\n", "57 228.000 67.000 18.000 red\n", "58 243.000 136.000 104.000 red\n", "59 245.000 148.000 240.000 pink\n", "60 169.000 252.000 3.000 green\n", "61 3.000 198.000 252.000 blue\n", "62 177.000 3.000 252.000 violet\n", "63 186.000 58.000 154.000 pink\n", "64 186.000 133.000 58.000 brown\n", "65 181.000 68.000 51.000 red\n", "66 181.000 131.000 51.000 brown\n", "67 181.000 51.000 177.000 purple\n", "68 181.000 51.000 51.000 red\n", "69 51.000 151.000 181.000 blue\n", "70 146.000 181.000 51.000 green\n", "71 181.000 181.000 51.000 yellow\n", "72 250.000 151.000 2.000 orange\n", "73 250.000 68.000 2.000 orange\n", "74 250.000 19.000 2.000 red\n", "75 2.000 250.000 11.000 green\n", "76 2.000 
159.000 250.000 blue\n", "77 2.000 93.000 250.000 blue\n", "78 130.000 2.000 250.000 violet\n", "79 221.000 2.000 250.000 purple\n", "80 154.000 161.000 18.000 green\n", "81 67.000 98.000 94.000 green\n", "82 123.000 172.000 196.000 blue\n", "83 117.000 117.000 111.000 grey\n", "84 211.000 108.000 75.000 orange\n", "85 220.000 20.000 16.000 red\n", "86 43.000 76.000 112.000 blue\n", "87 237.000 210.000 87.000 yellow\n", "88 112.000 62.000 43.000 brown\n", "89 130.000 103.000 137.000 purple\n", "90 133.000 8.000 138.000 purple\n", "91 79.000 79.000 79.000 grey\n", "93 65.000 105.000 225.000 blue\n", "94 233.000 240.000 105.000 yellow\n", "95 223.000 210.000 214.000 pink\n", "96 252.000 140.000 3.000 orange\n", "97 240.000 252.000 3.000 yellow\n", "98 20.000 162.000 69.000 green\n", "99 192.000 208.000 174.000 green\n", "100 176.000 81.000 81.000 red\n", "101 57.000 27.000 112.000 violet\n", "102 223.000 52.000 224.000 pink\n", "103 58.000 51.000 71.000 grey\n", "104 76.000 55.000 139.000 blue\n", "105 124.000 22.000 17.000 red\n", "106 52.000 144.000 135.000 green\n", "107 116.000 93.000 86.000 brown\n", "108 39.000 249.000 108.000 green\n", "109 56.000 149.000 161.000 blue\n", "110 6.000 6.000 6.000 black\n", "111 249.000 249.000 249.000 white\n", "112 157.000 237.000 236.000 blue\n", "115 50.000 205.000 50.000 green\n", "116 255.000 20.000 147.000 pink\n", "117 70.000 130.000 180.000 blue\n", "118 186.000 85.000 211.000 purple\n", "120 64.000 224.000 208.000 turquoise\n", "123 218.000 112.000 214.000 purple\n", "124 46.000 139.000 87.000 green\n", "125 30.000 144.000 255.000 blue\n", "126 205.000 133.000 63.000 yellow\n", "128 138.000 68.000 46.000 brown\n", "129 225.000 99.000 255.000 purple\n", "130 187.000 117.000 9.000 brown\n", "131 66.000 168.000 78.000 green\n", "132 25.000 235.000 79.000 green\n", "133 224.000 49.000 234.000 purple\n", "134 169.000 175.000 255.000 blue\n", "135 126.000 136.000 171.000 blue\n", "136 250.000 2.000 242.000 pink\n", "137 7.000 
218.000 1.000 green\n", "139 114.000 0.000 47.000 red\n", "140 40.000 114.000 51.000 green\n", "141 32.000 33.000 79.000 blue\n", "142 115.000 66.000 34.000 brown\n", "143 110.000 28.000 52.000 violet\n", "145 198.000 166.000 100.000 yellow\n", "146 137.000 172.000 118.000 green\n", "147 202.000 196.000 176.000 grey\n", "148 27.000 85.000 131.000 blue\n", "149 250.000 139.000 55.000 orange\n", "150 235.000 86.000 73.000 red\n", "151 57.000 250.000 115.000 green\n", "152 225.000 79.000 155.000 pink\n", "153 135.000 73.000 23.000 brown\n", "155 125.000 132.000 113.000 grey\n", "156 105.000 73.000 47.000 brown\n", "157 195.000 88.000 49.000 orange\n", "158 228.000 160.000 16.000 yellow\n", "160 165.000 32.000 25.000 red\n", "161 52.000 59.000 41.000 green\n", "162 0.000 143.000 57.000 green\n", "163 32.000 96.000 61.000 green\n", "164 229.000 190.000 1.000 yellow\n", "165 110.000 28.000 52.000 violet\n", "166 237.000 239.000 240.000 white\n", "167 2.000 14.000 20.000 black\n", "168 241.000 64.000 26.000 red\n", "172 70.000 223.000 195.000 turquoise\n", "173 80.000 142.000 131.000 blue\n", "175 132.000 117.000 247.000 blue\n", "177 186.000 232.000 255.000 blue\n", "178 203.000 208.000 204.000 grey\n", "179 203.000 50.000 52.000 red\n", "180 216.000 75.000 32.000 orange\n", "182 144.000 70.000 132.000 violet\n", "183 33.000 33.000 33.000 brown\n", "184 236.000 124.000 38.000 orange\n", "185 65.000 34.000 39.000 red\n", "186 76.000 81.000 74.000 grey\n", "187 127.000 181.000 181.000 turquoise\n", "189 139.000 140.000 122.000 grey\n", "191 217.000 80.000 48.000 pink\n", "192 202.000 18.000 167.000 violet\n", "196 79.000 36.000 40.000 red\n", "197 255.000 35.000 1.000 orange\n", "198 162.000 35.000 29.000 red\n", "199 149.000 95.000 32.000 brown\n", "206 31.000 161.000 60.000 green\n", "207 88.000 227.000 119.000 green\n", "209 70.000 77.000 211.000 blue\n", "211 74.000 206.000 133.000 green\n", "213 255.000 140.000 174.000 pink\n", "216 69.000 70.000 108.000 green\n", 
"217 89.000 198.000 174.000 blue\n", "218 15.000 206.000 164.000 green\n", "219 211.000 110.000 112.000 pink\n", "220 229.000 190.000 1.000 yellow\n", "221 153.000 151.980 80.070 green\n", "222 171.105 109.905 60.945 orange\n", "223 255.000 186.915 33.915 yellow\n", "224 95.115 255.000 33.915 green\n", "225 32.895 83.895 13.005 green\n", "227 173.910 225.930 236.895 blue\n", "228 204.000 176.970 236.895 purple\n", "229 235.110 4.080 186.915 pink\n", "230 54.060 31.110 48.960 purple\n", "232 255.000 150.960 88.995 orange\n", "233 255.000 183.090 198.900 pink\n", "234 236.895 87.975 121.125 pink\n", "236 172.890 250.920 82.110 green\n", "238 31.110 166.005 82.110 green\n", "239 31.110 26.010 147.900 blue\n", "240 241.995 185.895 0.000 yellow\n", "241 241.995 13.005 37.995 red\n", "242 83.895 198.900 255.000 blue\n", "244 47.940 45.900 41.055 black\n", "245 255.000 255.000 160.905 yellow\n", "246 255.000 255.000 255.000 white\n", "247 251.940 249.900 245.055 white\n", "248 160.905 166.005 162.945 grey\n", "249 45.900 98.940 29.070 green\n", "250 37.995 6.885 172.890 purple\n", "251 186.150 0.000 0.000 red\n", "252 93.075 198.900 255.000 blue\n", "253 205.020 237.915 255.000 blue\n", "254 255.000 122.400 0.000 orange\n", "255 0.000 58.905 109.905 blue\n", "258 64.005 52.020 46.920 brown\n", "261 NaN NaN NaN green\n", "265 0.000 208.000 20.000 green\n", "266 61.000 82.000 255.000 blue\n", "267 79.000 0.000 129.000 purple\n", "270 106.000 90.000 205.000 blue\n", "273 65.000 105.000 225.000 blue\n", "275 255.000 140.000 0.000 orange\n", "276 70.000 130.000 180.000 blue\n", "277 60.000 179.000 113.000 green\n", "279 255.000 105.000 180.000 pink\n", "280 139.000 69.000 19.000 brown\n", "281 135.000 206.000 250.000 blue\n", "282 50.000 205.000 50.000 green\n", "283 255.000 20.000 147.000 pink\n", "284 148.000 0.000 211.000 violet\n", "288 255.000 78.000 245.000 pink\n", "289 0.000 145.000 186.000 blue\n", "290 135.000 206.000 235.000 blue\n", "291 255.000 69.000 0.000 
orange\n", "292 219.000 112.000 147.000 red\n", "294 30.000 144.000 255.000 blue\n", "295 250.000 250.000 210.000 yellow\n", "296 95.000 158.000 160.000 blue\n", "299 255.000 182.000 193.000 pink\n", "304 128.000 0.000 128.000 purple\n", "305 244.000 164.000 96.000 brown\n", "308 139.000 0.000 0.000 red\n", "310 141.000 84.000 0.000 brown\n", "311 255.000 31.000 128.000 pink\n", "312 197.000 134.000 0.000 orange\n", "314 146.000 221.000 24.000 green\n", "315 211.000 255.000 252.000 blue\n", "316 192.000 94.000 255.000 purple\n", "317 16.000 94.000 255.000 blue\n", "319 242.000 207.000 255.000 pink\n", "320 0.000 66.000 37.000 green\n", "323 234.000 255.000 180.000 yellow\n", "325 209.000 216.000 143.000 green\n", "328 178.000 225.000 222.000 blue\n", "330 186.000 104.000 147.000 pink\n", "332 255.000 244.000 179.000 yellow\n", "333 89.000 104.000 160.000 blue\n", "334 250.000 144.000 127.000 orange\n", "335 126.000 61.000 113.000 purple\n", "337 67.000 90.000 56.000 green\n", "338 165.000 196.000 213.000 blue\n", "339 190.000 177.000 225.000 purple\n", "340 56.000 79.000 64.000 green\n", "341 255.000 195.000 160.000 pink\n", "342 37.000 45.000 55.000 blue\n", "344 121.000 135.000 145.000 blue\n", "345 246.000 243.000 228.000 white\n", "346 211.000 124.000 92.000 orange\n", "347 143.000 102.000 189.000 violet\n", "348 252.000 169.000 205.000 pink\n", "350 130.000 41.000 0.000 red\n", "351 217.000 235.000 145.000 green\n", "352 145.000 230.000 235.000 blue\n", "353 166.000 145.000 235.000 purple\n", "354 133.000 48.000 124.000 pink\n", "355 111.000 133.000 48.000 green\n", "356 15.000 99.000 46.000 green\n", "357 65.000 27.000 117.000 purple\n", "358 130.000 8.000 47.000 red\n", "359 163.000 189.000 128.000 green\n", "360 128.000 177.000 189.000 blue\n", "361 50.000 97.000 115.000 blue\n", "363 159.000 163.000 206.000 blue\n", "365 32.000 35.000 24.000 green\n", "368 255.000 0.000 0.000 red\n", "369 255.000 255.000 0.000 yellow\n", "375 76.000 145.000 58.000 
green\n", "376 110.000 41.000 50.000 purple\n", "377 41.000 74.000 150.000 blue\n", "378 207.000 204.000 52.000 yellow\n", "379 135.000 128.000 128.000 grey\n", "382 26.000 25.000 31.000 black\n", "385 239.000 138.000 23.000 orange\n", "391 63.000 97.000 45.000 green\n", "392 175.000 208.000 191.000 grey\n", "394 246.000 121.000 229.000 violet\n", "395 3.000 121.000 113.000 green\n", "400 38.000 94.000 28.000 green\n", "403 255.000 210.000 117.000 yellow\n", "404 219.000 90.000 66.000 red\n", "405 237.000 37.000 78.000 red\n", "406 1.000 25.000 54.000 blue\n", "407 70.000 83.000 98.000 grey\n", "408 107.000 170.000 117.000 green\n", "410 117.000 70.000 104.000 purple\n", "411 77.000 170.000 87.000 green\n", "412 88.000 125.000 113.000 green\n", "413 248.000 243.000 43.000 yellow\n", "414 255.000 237.000 223.000 pink\n", "415 197.000 216.000 109.000 green\n", "416 161.000 124.000 107.000 brown\n", "417 60.000 136.000 126.000 green\n", "418 72.000 77.000 109.000 purple\n", "419 238.000 99.000 82.000 red\n", "420 8.000 178.000 227.000 blue\n", "422 117.000 185.000 190.000 blue\n", "424 13.000 240.000 10.000 green\n", "426 237.000 14.000 200.000 pink\n", "427 9.000 188.000 138.000 green\n", "429 21.000 122.000 110.000 green\n", "431 152.000 159.000 206.000 purple\n", "432 135.000 61.000 72.000 red\n", "433 70.000 69.000 49.000 brown\n", "434 120.000 133.000 139.000 grey\n", "435 37.000 34.000 27.000 black\n", "436 195.000 195.000 195.000 grey\n", "438 139.000 140.000 122.000 grey\n", "439 10.000 10.000 13.000 black\n", "440 164.000 125.000 144.000 purple\n", "441 71.000 64.000 46.000 brown\n", "442 49.000 102.000 80.000 green\n", "443 25.000 26.000 36.000 purple\n", "444 29.000 30.000 51.000 blue\n", "446 166.000 94.000 46.000 brown\n", "447 248.000 243.000 53.000 yellow\n", "448 162.000 35.000 29.000 red\n", "449 127.000 181.000 181.000 turquoise\n", "450 69.000 50.000 46.000 brown\n", "451 104.000 108.000 94.000 grey\n", "452 228.000 160.000 16.000 yellow\n", "453 
180.000 76.000 67.000 pink\n", "455 169.000 131.000 7.000 yellow\n", "456 16.000 44.000 84.000 blue\n", "457 47.000 69.000 56.000 green\n", "458 246.000 246.000 246.000 white\n", "459 234.000 137.000 154.000 pink\n", "460 213.000 48.000 50.000 red\n", "461 125.000 132.000 113.000 grey\n", "462 66.000 70.000 50.000 green\n", "464 189.000 236.000 182.000 pink\n", "465 236.000 124.000 38.000 orange\n", "466 109.000 101.000 82.000 grey\n", "467 45.000 87.000 44.000 green\n", "468 207.000 211.000 205.000 white\n", "469 37.000 109.000 123.000 blue\n", "471 237.000 255.000 33.000 yellow\n", "472 42.000 100.000 120.000 blue\n", "473 40.000 40.000 40.000 black\n", "475 228.000 160.000 16.000 yellow\n", "476 134.000 115.000 161.000 violet\n", "477 2.000 86.000 105.000 blue\n", "478 93.000 155.000 155.000 blue\n", "479 237.000 118.000 14.000 orange\n", "480 74.000 25.000 44.000 violet\n", "481 237.000 255.000 33.000 yellow\n", "482 204.000 6.000 5.000 red\n", "483 138.000 149.000 151.000 grey\n", "484 112.000 83.000 53.000 brown\n", "485 165.000 32.000 25.000 red\n", "486 35.000 26.000 36.000 blue\n", "487 217.000 80.000 48.000 pink\n", "488 40.000 114.000 51.000 green\n", "489 160.000 52.000 114.000 purple\n", "490 89.000 53.000 31.000 brown\n", "492 250.000 244.000 227.000 white\n", "494 222.000 76.000 138.000 pink\n", "495 236.000 124.000 38.000 orange\n", "496 48.000 132.000 70.000 green\n", "497 162.000 35.000 29.000 red\n", "498 255.000 164.000 32.000 yellow\n", "499 42.000 46.000 75.000 blue\n", "500 203.000 208.000 204.000 white\n", "501 62.000 59.000 50.000 black\n", "502 77.000 86.000 69.000 grey\n", "503 44.000 85.000 69.000 green\n", "504 139.000 140.000 122.000 grey\n", "505 189.000 236.000 182.000 green\n", "506 198.000 166.000 100.000 brown\n", "507 59.000 131.000 189.000 blue\n", "508 130.000 137.000 143.000 grey\n" ] } ], "source": [ "print(filtered_dataset.to_string())" ] }, { "cell_type": "markdown", "id": "b946e5f4-ecae-4a95-9359-a163fd4633fd", "metadata": 
{}, "source": [ "And here we have our fully cleaned dataset!\n", "\n", "To summarize, we:\n", "- made the color names lowercase\n", "- only kept the relevant part of compound names\n", "- filtered out rarely used color names\n", "- standardized our RGB values\n", "- etc.\n", "\n", "However, a model trained on this dataset still isn't very accurate, likely because we made lots of oversimplifications in our data.\n" ] }, { "cell_type": "markdown", "id": "58f1551b-5bee-4415-83da-29e33a63a1df", "metadata": {}, "source": [ "\n", "***\n", ">\n", "# Exercise: training a KNN classifier for color names\n", "\n", "Input: R, G, B\n", "Output (prediction): color name\n", "\n", "- do a train/test split on `filtered_dataset`\n", "- train a KNN (choose whatever K you want to start with)\n", "- make predictions on the test set\n", "- compute the accuracy score\n", "- repeat for different K if necessary" ] }, { "cell_type": "code", "execution_count": 35, "id": "a1518e0d-112e-41ea-9e2d-e921258970a8", "metadata": {}, "outputs": [], "source": [ "\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 16, "id": "5129fb88-b47c-4f18-8ed5-155b065bdae5", "metadata": {}, "outputs": [], "source": [ "# Drop row 261, whose R, G and B values are NaN\n", "filtered_dataset = filtered_dataset.drop(index=261)" ] }, { "cell_type": "code", "execution_count": 17, "id": "9ec42c3c-baf9-4f03-9cc2-41ac470980c2", "metadata": {}, "outputs": [], "source": [ "# Hold out 20% of the rows as a test set; the split is random, so scores vary between runs\n", "X_train, X_test, y_train, y_test = \\\n", " train_test_split(filtered_dataset[['R','G','B']], filtered_dataset['colour'], test_size=0.2)" ] }, { "cell_type": "code", "execution_count": 18, "id": "e4289601-5c08-4800-9ef6-cdab9be8fde1", "metadata": {}, "outputs": [], "source": [ "knn = KNeighborsClassifier(n_neighbors=3)" ] }, { "cell_type": "code", "execution_count": 19, "id": "19e137a3-580a-446d-97e6-915d339aabb4", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "KNeighborsClassifier(n_neighbors=3)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "knn.fit(X_train, y_train)\n", "\n", "# Remove NaN" ] }, { "cell_type": "code", "execution_count": 20, "id": "60751884-c502-44f4-8a15-ea2253cfd17e", "metadata": {}, "outputs": [], "source": [ "y_predicted = knn.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 21, "id": "29be4004-4088-4c79-b675-45273d53c738", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.6582278481012658" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# KNN accuracy score\n", "accuracy_score(y_test, y_predicted)" ] }, { "cell_type": "code", "execution_count": 22, "id": "6155783e-d46a-411f-b48b-cde57c7565d9", "metadata": {}, "outputs": [], "source": [ "from sklearn.tree import DecisionTreeClassifier" ] }, { "cell_type": "code", "execution_count": 23, "id": "80a94642-3886-4711-8a2a-bd3ad21a4243", "metadata": {}, "outputs": [], "source": [ "dt = DecisionTreeClassifier(max_depth=7)" ] }, { "cell_type": "code", "execution_count": 24, "id": "9137287b-28f8-4e36-9741-d39787ef895f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
DecisionTreeClassifier(max_depth=7)
" ], "text/plain": [ "DecisionTreeClassifier(max_depth=7)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dt.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 25, "id": "495049c4-78de-4b01-a419-5d4c8da2fdd1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.6329113924050633" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Decision Tree accuracy score\n", "y_predicted = dt.predict(X_test)\n", "accuracy_score(y_test, y_predicted)" ] }, { "cell_type": "code", "execution_count": 26, "id": "1ff56166-1eb6-45b5-84e6-3933d9366d0c", "metadata": {}, "outputs": [], "source": [ "from sklearn import tree\n" ] }, { "cell_type": "code", "execution_count": 27, "id": "80856df0-ca78-49f7-9157-5e373f9545a9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Text(0.5528846153846154, 0.9375, 'x[0] <= 175.5\\ngini = 0.888\\nsamples = 314\\nvalue = [7.0, 55.0, 20.0, 60.0, 20.0, 20.0, 29.0, 27.0, 27.0\\n5.0, 10.0, 7.0, 27.0]'),\n", " Text(0.32967032967032966, 0.8125, 'x[2] <= 138.5\\ngini = 0.812\\nsamples = 194\\nvalue = [7.0, 54.0, 15.0, 55.0, 17.0, 1.0, 1.0, 20.0, 9.0\\n5.0, 9.0, 0.0, 1.0]'),\n", " Text(0.4412774725274725, 0.875, 'True '),\n", " Text(0.18681318681318682, 0.6875, 'x[1] <= 141.5\\ngini = 0.788\\nsamples = 135\\nvalue = [7, 13, 15, 54, 13, 1, 1, 15, 9, 1, 5, 0, 1]'),\n", " Text(0.17582417582417584, 0.5625, 'x[0] <= 69.5\\ngini = 0.86\\nsamples = 106\\nvalue = [7, 13, 15, 25, 13, 1, 1, 15, 9, 1, 5, 0, 1]'),\n", " Text(0.08791208791208792, 0.4375, 'x[1] <= 58.952\\ngini = 0.717\\nsamples = 54\\nvalue = [7, 13, 1, 24, 1, 0, 0, 5, 1, 1, 1, 0, 0]'),\n", " Text(0.04395604395604396, 0.3125, 'x[2] <= 33.5\\ngini = 0.783\\nsamples = 23\\nvalue = [6, 7, 1, 1, 1, 0, 0, 5, 1, 0, 1, 0, 0]'),\n", " Text(0.02197802197802198, 0.1875, 'x[1] <= 34.5\\ngini = 0.278\\nsamples = 6\\nvalue = [5, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " 
Text(0.01098901098901099, 0.0625, 'gini = 0.0\\nsamples = 5\\nvalue = [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.03296703296703297, 0.0625, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.06593406593406594, 0.1875, 'x[0] <= 43.0\\ngini = 0.727\\nsamples = 17\\nvalue = [1, 7, 1, 0, 1, 0, 0, 5, 1, 0, 1, 0, 0]'),\n", " Text(0.054945054945054944, 0.0625, 'gini = 0.219\\nsamples = 8\\nvalue = [0, 7, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]'),\n", " Text(0.07692307692307693, 0.0625, 'gini = 0.741\\nsamples = 9\\nvalue = [1, 0, 1, 0, 1, 0, 0, 4, 1, 0, 1, 0, 0]'),\n", " Text(0.13186813186813187, 0.3125, 'x[2] <= 97.5\\ngini = 0.41\\nsamples = 31\\nvalue = [1, 6, 0, 23, 0, 0, 0, 0, 0, 1, 0, 0, 0]'),\n", " Text(0.10989010989010989, 0.1875, 'x[1] <= 64.0\\ngini = 0.095\\nsamples = 20\\nvalue = [1, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.0989010989010989, 0.0625, 'gini = 0.5\\nsamples = 2\\nvalue = [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.12087912087912088, 0.0625, 'gini = 0.0\\nsamples = 18\\nvalue = [0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.15384615384615385, 0.1875, 'x[1] <= 118.0\\ngini = 0.562\\nsamples = 11\\nvalue = [0, 6, 0, 4, 0, 0, 0, 0, 0, 1, 0, 0, 0]'),\n", " Text(0.14285714285714285, 0.0625, 'gini = 0.406\\nsamples = 8\\nvalue = [0, 6, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0]'),\n", " Text(0.16483516483516483, 0.0625, 'gini = 0.0\\nsamples = 3\\nvalue = [0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.26373626373626374, 0.4375, 'x[2] <= 49.5\\ngini = 0.806\\nsamples = 52\\nvalue = [0, 0, 14, 1, 12, 1, 1, 10, 8, 0, 4, 0, 1]'),\n", " Text(0.21978021978021978, 0.3125, 'x[1] <= 44.5\\ngini = 0.57\\nsamples = 20\\nvalue = [0, 0, 11, 0, 0, 0, 0, 0, 7, 0, 1, 0, 1]'),\n", " Text(0.1978021978021978, 0.1875, 'x[0] <= 76.5\\ngini = 0.219\\nsamples = 8\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 1, 0, 0]'),\n", " Text(0.18681318681318682, 0.0625, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]'),\n", " Text(0.2087912087912088, 0.0625, 'gini = 0.0\\nsamples = 7\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0]'),\n", " Text(0.24175824175824176, 0.1875, 'x[1] <= 113.0\\ngini = 0.153\\nsamples = 12\\nvalue = [0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.23076923076923078, 0.0625, 'gini = 0.0\\nsamples = 11\\nvalue = [0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.25274725274725274, 0.0625, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.3076923076923077, 0.3125, 'x[1] <= 79.0\\ngini = 0.74\\nsamples = 32\\nvalue = [0, 0, 3, 1, 12, 1, 1, 10, 1, 0, 3, 0, 0]'),\n", " Text(0.2857142857142857, 0.1875, 'x[2] <= 88.0\\ngini = 0.531\\nsamples = 14\\nvalue = [0, 0, 0, 0, 0, 0, 1, 9, 1, 0, 3, 0, 0]'),\n", " Text(0.27472527472527475, 0.0625, 'gini = 0.625\\nsamples = 4\\nvalue = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 0, 0]'),\n", " Text(0.2967032967032967, 0.0625, 'gini = 0.34\\nsamples = 10\\nvalue = [0, 0, 0, 0, 0, 0, 1, 8, 0, 0, 1, 0, 0]'),\n", " Text(0.32967032967032966, 0.1875, 'x[0] <= 150.0\\ngini = 0.519\\nsamples = 18\\nvalue = [0, 0, 3, 1, 12, 1, 0, 1, 0, 0, 0, 0, 0]'),\n", " Text(0.31868131868131866, 0.0625, 'gini = 0.414\\nsamples = 16\\nvalue = [0, 0, 2, 1, 12, 0, 0, 1, 0, 0, 0, 0, 0]'),\n", " Text(0.34065934065934067, 0.0625, 'gini = 0.5\\nsamples = 2\\nvalue = [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.1978021978021978, 0.5625, 'gini = 0.0\\nsamples = 29\\nvalue = [0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.4725274725274725, 0.6875, 'x[1] <= 14.005\\ngini = 0.496\\nsamples = 59\\nvalue = [0, 41, 0, 1, 4, 0, 0, 5, 0, 4, 4, 0, 0]'),\n", " Text(0.46153846153846156, 0.5625, 'gini = 0.0\\nsamples = 2\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0]'),\n", " Text(0.4835164835164835, 0.5625, 'x[2] <= 211.5\\ngini = 0.464\\nsamples = 57\\nvalue = [0, 41, 0, 1, 4, 0, 0, 5, 0, 4, 2, 0, 0]'),\n", " Text(0.43956043956043955, 0.4375, 'x[0] <= 
129.0\\ngini = 0.642\\nsamples = 34\\nvalue = [0, 19, 0, 1, 4, 0, 0, 4, 0, 4, 2, 0, 0]'),\n", " Text(0.3956043956043956, 0.3125, 'x[1] <= 190.0\\ngini = 0.448\\nsamples = 25\\nvalue = [0, 18, 0, 1, 0, 0, 0, 2, 0, 4, 0, 0, 0]'),\n", " Text(0.37362637362637363, 0.1875, 'x[1] <= 103.0\\ngini = 0.188\\nsamples = 19\\nvalue = [0, 17, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0]'),\n", " Text(0.3626373626373626, 0.0625, 'gini = 0.48\\nsamples = 5\\nvalue = [0, 3, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0]'),\n", " Text(0.38461538461538464, 0.0625, 'gini = 0.0\\nsamples = 14\\nvalue = [0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.4175824175824176, 0.1875, 'x[2] <= 182.0\\ngini = 0.5\\nsamples = 6\\nvalue = [0, 1, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0, 0]'),\n", " Text(0.4065934065934066, 0.0625, 'gini = 0.5\\nsamples = 2\\nvalue = [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.42857142857142855, 0.0625, 'gini = 0.0\\nsamples = 4\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]'),\n", " Text(0.4835164835164835, 0.3125, 'x[1] <= 131.0\\ngini = 0.691\\nsamples = 9\\nvalue = [0, 1, 0, 0, 4, 0, 0, 2, 0, 0, 2, 0, 0]'),\n", " Text(0.46153846153846156, 0.1875, 'x[2] <= 152.5\\ngini = 0.5\\nsamples = 4\\nvalue = [0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0]'),\n", " Text(0.45054945054945056, 0.0625, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]'),\n", " Text(0.4725274725274725, 0.0625, 'gini = 0.444\\nsamples = 3\\nvalue = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0]'),\n", " Text(0.5054945054945055, 0.1875, 'x[2] <= 198.5\\ngini = 0.32\\nsamples = 5\\nvalue = [0, 1, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.4945054945054945, 0.0625, 'gini = 0.0\\nsamples = 4\\nvalue = [0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.5164835164835165, 0.0625, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.5274725274725275, 0.4375, 'x[0] <= 165.5\\ngini = 0.083\\nsamples = 23\\nvalue = [0, 22, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 
0]'),\n", " Text(0.5164835164835165, 0.3125, 'gini = 0.0\\nsamples = 20\\nvalue = [0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.5384615384615384, 0.3125, 'x[2] <= 235.948\\ngini = 0.444\\nsamples = 3\\nvalue = [0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]'),\n", " Text(0.5274725274725275, 0.1875, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]'),\n", " Text(0.5494505494505495, 0.1875, 'gini = 0.0\\nsamples = 2\\nvalue = [0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.7760989010989011, 0.8125, 'x[2] <= 107.0\\ngini = 0.84\\nsamples = 120\\nvalue = [0, 1, 5, 5, 3, 19, 28, 7, 18, 0, 1, 7, 26]'),\n", " Text(0.6644917582417582, 0.875, ' False'),\n", " Text(0.6648351648351648, 0.6875, 'x[1] <= 157.0\\ngini = 0.729\\nsamples = 62\\nvalue = [0, 0, 5, 0, 0, 19, 2, 0, 17, 0, 0, 0, 19]'),\n", " Text(0.6153846153846154, 0.5625, 'x[1] <= 103.5\\ngini = 0.627\\nsamples = 41\\nvalue = [0, 0, 3, 0, 0, 18, 2, 0, 17, 0, 0, 0, 1]'),\n", " Text(0.5934065934065934, 0.4375, 'x[0] <= 252.5\\ngini = 0.461\\nsamples = 23\\nvalue = [0, 0, 0, 0, 0, 5, 2, 0, 16, 0, 0, 0, 0]'),\n", " Text(0.5824175824175825, 0.3125, 'x[1] <= 67.5\\ngini = 0.39\\nsamples = 21\\nvalue = [0, 0, 0, 0, 0, 3, 2, 0, 16, 0, 0, 0, 0]'),\n", " Text(0.5714285714285714, 0.1875, 'gini = 0.0\\nsamples = 11\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0]'),\n", " Text(0.5934065934065934, 0.1875, 'x[2] <= 40.0\\ngini = 0.62\\nsamples = 10\\nvalue = [0, 0, 0, 0, 0, 3, 2, 0, 5, 0, 0, 0, 0]'),\n", " Text(0.5824175824175825, 0.0625, 'gini = 0.0\\nsamples = 3\\nvalue = [0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.6043956043956044, 0.0625, 'gini = 0.408\\nsamples = 7\\nvalue = [0, 0, 0, 0, 0, 0, 2, 0, 5, 0, 0, 0, 0]'),\n", " Text(0.6043956043956044, 0.3125, 'gini = 0.0\\nsamples = 2\\nvalue = [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.6373626373626373, 0.4375, 'x[0] <= 192.0\\ngini = 0.444\\nsamples = 18\\nvalue = [0, 0, 3, 0, 0, 13, 0, 0, 1, 0, 0, 0, 1]'),\n", " 
Text(0.6263736263736264, 0.3125, 'gini = 0.0\\nsamples = 3\\nvalue = [0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.6483516483516484, 0.3125, 'x[2] <= 96.498\\ngini = 0.24\\nsamples = 15\\nvalue = [0, 0, 0, 0, 0, 13, 0, 0, 1, 0, 0, 0, 1]'),\n", " Text(0.6373626373626373, 0.1875, 'x[0] <= 208.0\\ngini = 0.133\\nsamples = 14\\nvalue = [0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.6263736263736264, 0.0625, 'gini = 0.5\\nsamples = 2\\nvalue = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.6483516483516484, 0.0625, 'gini = 0.0\\nsamples = 12\\nvalue = [0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.6593406593406593, 0.1875, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]'),\n", " Text(0.7142857142857143, 0.5625, 'x[2] <= 89.0\\ngini = 0.254\\nsamples = 21\\nvalue = [0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 18]'),\n", " Text(0.6813186813186813, 0.4375, 'x[0] <= 251.5\\ngini = 0.111\\nsamples = 17\\nvalue = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 16]'),\n", " Text(0.6703296703296703, 0.3125, 'gini = 0.0\\nsamples = 14\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14]'),\n", " Text(0.6923076923076923, 0.3125, 'x[2] <= 16.0\\ngini = 0.444\\nsamples = 3\\nvalue = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2]'),\n", " Text(0.6813186813186813, 0.1875, 'x[1] <= 210.0\\ngini = 0.5\\nsamples = 2\\nvalue = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.6703296703296703, 0.0625, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.6923076923076923, 0.0625, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.7032967032967034, 0.1875, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.7472527472527473, 0.4375, 'x[0] <= 238.5\\ngini = 0.5\\nsamples = 4\\nvalue = [0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]'),\n", " Text(0.7362637362637363, 0.3125, 'x[0] <= 215.5\\ngini = 0.444\\nsamples = 3\\nvalue = [0, 0, 1, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 2]'),\n", " Text(0.7252747252747253, 0.1875, 'gini = 0.5\\nsamples = 2\\nvalue = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.7472527472527473, 0.1875, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.7582417582417582, 0.3125, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.8873626373626373, 0.6875, 'x[1] <= 207.5\\ngini = 0.744\\nsamples = 58\\nvalue = [0, 1, 0, 5, 3, 0, 26, 7, 1, 0, 1, 7, 7]'),\n", " Text(0.8296703296703297, 0.5625, 'x[0] <= 221.5\\ngini = 0.485\\nsamples = 35\\nvalue = [0, 0, 0, 0, 2, 0, 24, 7, 1, 0, 1, 0, 0]'),\n", " Text(0.8021978021978022, 0.4375, 'x[2] <= 176.5\\ngini = 0.627\\nsamples = 13\\nvalue = [0, 0, 0, 0, 2, 0, 3, 7, 1, 0, 0, 0, 0]'),\n", " Text(0.7802197802197802, 0.3125, 'x[0] <= 215.0\\ngini = 0.56\\nsamples = 5\\nvalue = [0, 0, 0, 0, 1, 0, 3, 0, 1, 0, 0, 0, 0]'),\n", " Text(0.7692307692307693, 0.1875, 'x[2] <= 175.0\\ngini = 0.375\\nsamples = 4\\nvalue = [0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.7582417582417582, 0.0625, 'gini = 0.0\\nsamples = 3\\nvalue = [0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.7802197802197802, 0.0625, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.7912087912087912, 0.1875, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]'),\n", " Text(0.8241758241758241, 0.3125, 'x[1] <= 186.0\\ngini = 0.219\\nsamples = 8\\nvalue = [0, 0, 0, 0, 1, 0, 0, 7, 0, 0, 0, 0, 0]'),\n", " Text(0.8131868131868132, 0.1875, 'gini = 0.0\\nsamples = 7\\nvalue = [0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0]'),\n", " Text(0.8351648351648352, 0.1875, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.8571428571428571, 0.4375, 'x[2] <= 226.5\\ngini = 0.087\\nsamples = 22\\nvalue = [0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 1, 0, 0]'),\n", " Text(0.8461538461538461, 0.3125, 'gini = 0.0\\nsamples = 
18\\nvalue = [0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.8681318681318682, 0.3125, 'x[2] <= 235.5\\ngini = 0.375\\nsamples = 4\\nvalue = [0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0]'),\n", " Text(0.8571428571428571, 0.1875, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]'),\n", " Text(0.8791208791208791, 0.1875, 'gini = 0.0\\nsamples = 3\\nvalue = [0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.945054945054945, 0.5625, 'x[2] <= 193.0\\ngini = 0.756\\nsamples = 23\\nvalue = [0, 1, 0, 5, 1, 0, 2, 0, 0, 0, 0, 7, 7]'),\n", " Text(0.9230769230769231, 0.4375, 'x[0] <= 220.5\\ngini = 0.569\\nsamples = 12\\nvalue = [0, 0, 0, 5, 0, 0, 1, 0, 0, 0, 0, 0, 6]'),\n", " Text(0.9120879120879121, 0.3125, 'x[1] <= 217.5\\ngini = 0.449\\nsamples = 7\\nvalue = [0, 0, 0, 5, 0, 0, 1, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.9010989010989011, 0.1875, 'gini = 0.0\\nsamples = 3\\nvalue = [0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.9230769230769231, 0.1875, 'x[1] <= 227.0\\ngini = 0.625\\nsamples = 4\\nvalue = [0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.9120879120879121, 0.0625, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.9340659340659341, 0.0625, 'gini = 0.444\\nsamples = 3\\nvalue = [0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.9340659340659341, 0.3125, 'gini = 0.0\\nsamples = 5\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5]'),\n", " Text(0.967032967032967, 0.4375, 'x[0] <= 194.5\\ngini = 0.562\\nsamples = 11\\nvalue = [0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 7, 1]'),\n", " Text(0.9560439560439561, 0.3125, 'gini = 0.0\\nsamples = 1\\nvalue = [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'),\n", " Text(0.978021978021978, 0.3125, 'x[2] <= 220.5\\ngini = 0.48\\nsamples = 10\\nvalue = [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 7, 1]'),\n", " Text(0.967032967032967, 0.1875, 'x[2] <= 207.5\\ngini = 0.72\\nsamples = 5\\nvalue = [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 2, 1]'),\n", " 
Text(0.9560439560439561, 0.0625, 'gini = 0.444\\nsamples = 3\\nvalue = [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0]'),\n", " Text(0.978021978021978, 0.0625, 'gini = 0.5\\nsamples = 2\\nvalue = [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]'),\n", " Text(0.989010989010989, 0.1875, 'gini = 0.0\\nsamples = 5\\nvalue = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0]')]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tree.plot_tree(dt)\n" ] }, { "cell_type": "markdown", "id": "063e0a3b", "metadata": {}, "source": [ "# Wednesday - Day 3: Unsupervised learning and model validation" ] }, { "cell_type": "code", "execution_count": 29, "id": "08437bee-8af0-42b6-b14a-13923eca3d77", "metadata": {}, "outputs": [], "source": [ "X = filtered_dataset[['R', 'G', 'B']]  # feature matrix: only the R, G, B columns\n", "y = filtered_dataset['colour']  # target: the one-dimensional colour column" ] }, { "cell_type": "code", "execution_count": 36, "id": "1bef2f7a", "metadata": {}, "outputs": [], "source": [ "# Split the dataset into training and testing sets\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)" ] }, { "cell_type": "code", "execution_count": 30, "id": "31933a0c", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier" ] }, { "cell_type": "code", "execution_count": 31, "id": "91557abe", "metadata": {}, "outputs": [], "source": [ "rf = RandomForestClassifier(n_estimators=100)  # n_estimators sets how many trees to use\n", "# Optional: try different depths\n", "# Setting max_depth limits the growth of each tree, which helps prevent overfitting\n", "rf = RandomForestClassifier(n_estimators=500, max_depth=5)" ] }, { "cell_type": "code", "execution_count": 32, "id": "c5b6bace", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RandomForestClassifier(max_depth=5, n_estimators=500)
" ], "text/plain": [ "RandomForestClassifier(max_depth=5, n_estimators=500)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fit the model to the training data\n", "rf.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 33, "id": "f57f13c4", "metadata": {}, "outputs": [], "source": [ "# Use the trained random forest (rf, not the decision tree dt) to make predictions on the test data\n", "y_predicted = rf.predict(X_test)" ] }, { "cell_type": "code", "execution_count": null, "id": "d8f14df5", "metadata": {}, "outputs": [], "source": [ "# Measure how well the random forest performs on the test data\n", "accuracy_score(y_test, y_predicted)" ] }, { "cell_type": "code", "execution_count": null, "id": "b9eff482", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 5 }