{
"cells": [
{
"cell_type": "markdown",
"id": "1a0e63ce",
"metadata": {},
"source": [
"# Lab #1"
]
},
{
"cell_type": "markdown",
"id": "4df7a712",
"metadata": {},
"source": [
"The purpose of this laboratory is to get you acquainted with Python. \n",
"More specifically, you will learn how to:\n",
"- read different types of datasets (CSV and JSON). \n",
"- extract some useful information (mean and standard deviation) from the datasets while only using basic Python.\n",
"- create a simple rule-based classifier capable of performing some basic classification.\n"
]
},
{
"cell_type": "markdown",
"id": "6bdb4439",
"metadata": {},
"source": [
"## Preliminaries\n",
"### Python availability\n",
"Make sure that Python 3 is installed on your device with the command `python --version`. The version should be in the form `3.x.x`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2bd291ae",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:04.672718500Z",
"start_time": "2023-10-09T09:26:04.536268400Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python 3.12.6\n"
]
}
],
"source": [
"! python --version"
]
},
{
"cell_type": "markdown",
"id": "05e6a809",
"metadata": {},
"source": [
"### Dataset Download\n",
"For this lab, three different datasets will be used. Here, you will learn more about them and how to\n",
"retrieve them."
]
},
{
"cell_type": "markdown",
"id": "b43b2a4d",
"metadata": {},
"source": [
"#### Iris\n",
"Iris is a particularly famous *toy dataset* (i.e. a dataset with a small number of rows and columns, mostly\n",
"used for initial small-scale tests and proofs of concept). \n",
"This specific dataset contains information about the **Iris**, a genus that includes 260-300 species of plants. \n",
"The Iris dataset contains measurements for 150 Iris flowers, each belonging to one of three species (50 flowers each): \n",
"\n",
"Iris Virginica | Iris Versicolor | Iris Setosa |\n",
":-------------------------:|:-------------------------:|:---------------|\n",
"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/f/f8/Iris_virginica_2.jpg/1200px-Iris_virginica_2.jpg\" alt=\"Iris Virginica\" width=\"200\" /> | <img src=\"https://www.waternursery.it/document/img_prodotti/616/1646318149.jpeg\" alt=\"Iris Versicolor\" width=\"200\" /> |<img src=\"https://d2j6dbq0eux0bg.cloudfront.net/images/28296135/2323483832.jpg\" alt=\"Iris Setosa\" width=\"200\" />|\n",
"\n",
"Each of the 150 flowers contained in the Iris dataset is represented by 5 values:\n",
"- sepal length, in cm\n",
"- sepal width, in cm\n",
"- petal length, in cm\n",
"- petal width, in cm\n",
"- Iris species, one of: Iris-setosa, Iris-versicolor, Iris-virginica (the label)\n",
"\n",
"Each row of the dataset represents a distinct flower (as such, the dataset will have 150 rows). Each\n",
"row then contains 5 values (4 measurements and a species label).\n",
"The dataset is described in more detail on the [UCI Machine Learning Repository website](https://archive.ics.uci.edu/dataset/53/iris). The dataset\n",
"can either be downloaded directly from there (iris.data file), or from a terminal, using the `wget` tool. The\n",
"following command downloads the dataset from the original URL and stores it in a file named iris.csv.\n",
"\n",
"`wget \"https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data\" -O iris.csv`\n",
"\n",
"The dataset is available as a Comma-Separated Values (CSV) file. These files are typically used to\n",
"represent tabular data. \n",
"- Each row of the table is stored on a separate line. \n",
"- Each row contains a fixed number of columns. \n",
"- The column values (in each row) are separated by a comma (,).\n",
"\n",
"To read CSV files, Python offers a module called `csv` (see the official [doc](https://docs.python.org/3/library/csv.html)). This module allows using `csv.reader()`, which\n",
"reads a file row by row. For each row, it returns a list of columns that can be processed as needed. \n",
"\n",
"\n",
"Let's download the dataset and print the first three rows.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "94138fd0",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:05.995857200Z",
"start_time": "2023-10-09T09:26:04.671709900Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading first lines of IRIS dataset\n",
"['5.1', '3.5', '1.4', '0.2', 'Iris-setosa']\n",
"['4.9', '3.0', '1.4', '0.2', 'Iris-setosa']\n",
"['4.7', '3.2', '1.3', '0.2', 'Iris-setosa']\n",
"['4.6', '3.1', '1.5', '0.2', 'Iris-setosa']\n",
"['5.0', '3.6', '1.4', '0.2', 'Iris-setosa']\n"
]
}
],
"source": [
"! wget \"https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data\" -O iris.csv\n",
"\n",
"print(\"Reading first lines of IRIS dataset\")\n",
"import csv \n",
"with open(\"iris.csv\") as f:\n",
"    for i, cols in enumerate(csv.reader(f)):\n",
"        print(cols)\n",
"        if i >= 4:\n",
"            break"
]
},
{
"cell_type": "markdown",
"id": "3ecb2df8",
"metadata": {},
"source": [
"Note that, by default, `csv.reader` returns every field as a string (`str`). \n",
"If you want to treat them as numbers, remember to cast them accordingly!"
]
},
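{
"cell_type": "markdown",
"id": "c7e1a2f4",
"metadata": {},
"source": [
"For example, a minimal sketch of such a cast (with a hard-coded row for illustration, not tied to the file on disk):\n",
"\n",
"```python\n",
"# Hypothetical row, as csv.reader would return it: all fields are strings\n",
"row = [\"5.1\", \"3.5\", \"1.4\", \"0.2\", \"Iris-setosa\"]\n",
"\n",
"measurements = [float(v) for v in row[:4]]  # cast the 4 numeric columns\n",
"label = row[4]                              # the species label stays a string\n",
"\n",
"print(measurements, label)\n",
"```"
]
},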
{
"cell_type": "markdown",
"id": "5725ff17",
"metadata": {},
"source": [
"#### MNIST\n",
"The MNIST dataset is another particularly famous dataset. It contains several thousands of hand-written\n",
"digits (0 to 9). \n",
"- Each hand-written digit is represented as a $28 \times 28$ 8-bit grayscale image. \n",
"- This means that each digit has $784$ ($28^2$) pixels.\n",
"- Each pixel has a value that ranges from 0 (black) to 255 (white).\n",
"\n",
"<img src=\"https://machinelearningmastery.com/wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-MNIST-Dataset.png\" alt=\"MNIST images\" width=\"500\" />\n",
"\n",
"The dataset can be downloaded from the following link:\n",
"\n",
"[https://raw.githubusercontent.com/dbdmg/data-science-lab/master/datasets/mnist_test.csv](https://raw.githubusercontent.com/dbdmg/data-science-lab/master/datasets/mnist_test.csv)\n",
"\n",
"In this case, MNIST is represented as a CSV file. Similarly to the Iris dataset, each row of the MNIST\n",
"dataset represents the pixels of the image depicting a digit. For the sake of simplicity, this dataset contains only a small fraction (10,000\n",
"digits out of 70,000) of the full MNIST dataset. \n",
"\n",
"For each digit, 785 values are available: \n",
"- the first one is the numerical value depicted in the image (e.g. for Figure 2 it would be 5). \n",
"- the following 784 columns represent the grayscale image in row-major order (for more information about row- and column-major order of matrices, see [Wikipedia](https://en.wikipedia.org/wiki/Row-_and_column-major_order)).\n",
"\n",
"The MNIST dataset in CSV format can be read with the same approach used for Iris, keeping in mind\n",
"that, in this case, the digit label (i.e. the first column) is an integer from 0 to 9, while the following 784\n",
"values are integers between 0 and 255."
]
},
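{
"cell_type": "markdown",
"id": "8f2d91ab",
"metadata": {},
"source": [
"As a sketch of what row-major order means for these 784 values (using a stand-in list, not an actual MNIST row): the pixel at row `r` and column `c` of the $28 \times 28$ image sits at flat index `r * 28 + c`.\n",
"\n",
"```python\n",
"# Stand-in for the 784 pixel values of one digit (label excluded)\n",
"pixels = list(range(784))\n",
"\n",
"def pixel_at(pixels, r, c):\n",
"    # row-major order: the 28 rows are stored one after another\n",
"    return pixels[r * 28 + c]\n",
"\n",
"print(pixel_at(pixels, 0, 27))  # last pixel of the first row -> index 27\n",
"print(pixel_at(pixels, 1, 0))   # first pixel of the second row -> index 28\n",
"```"
]
},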
{
"cell_type": "markdown",
"id": "7532fe60",
"metadata": {},
"source": [
"## Exercises\n",
"Note that exercises marked with a (*) are optional; you should focus on completing the other ones first.\n",
"### Iris analysis\n",
"1. Load the previously downloaded Iris dataset as a list of lists (each of the 150 lists should have 5 elements). You can make use of the csv module presented earlier."
]
},
{
"cell_type": "code",
"execution_count": 108,
"id": "9379bd5669ca8db9",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:05.996857100Z",
"start_time": "2023-10-09T09:26:05.992854900Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset loaded. Number of lines: 150\n"
]
}
],
"source": [
"import csv\n",
"\n",
"iris_list = []\n",
"\n",
"with open(\"iris.csv\") as f:\n",
"    for cols in csv.reader(f):\n",
"        if len(cols) != 5:\n",
"            continue\n",
"        iris_list.append(cols)\n",
"\n",
"print(f\"Dataset loaded. Number of lines: {len(iris_list)}\")\n"
]
},
{
"cell_type": "markdown",
"id": "1bbcdf393c1b88a7",
"metadata": {},
"source": [
"2. Compute and print the mean and the standard deviation for each of the 4 measurement columns (i.e. sepal length and width, petal length and width). Remember that, for a given list of n values $x = (x_1, x_2, ..., x_n)$, the mean $\\mu$ and the standard deviation $\\sigma$ are defined respectively as:\n",
"$$\mu = {1 \over n} \sum_{i=1}^n x_i $$\n",
"\n",
"$$ \sigma = \sqrt{ {1 \over n} \sum_{i=1}^n (x_i - \mu)^2} $$"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "a8c4f766f46b3e23",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:06.007483400Z",
"start_time": "2023-10-09T09:26:05.995857200Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sepal length mean: 5.843333333333334, std_dev: 0.8253012917851409\n",
"Sepal width mean: 3.0540000000000003, std_dev: 0.43214658007054346\n",
"Petal length mean: 3.758666666666666, std_dev: 1.758529183405521\n",
"Petal width mean: 1.1986666666666668, std_dev: 0.7606126185881716\n"
]
}
],
"source": [
"from math import sqrt\n",
"\n",
"def mean(items):\n",
"    return sum(items)/len(items)\n",
"\n",
"def std_dev(items, mu=None):\n",
"    if mu is None:\n",
"        mu = mean(items)\n",
"\n",
"    return sqrt(mean([(x - mu)**2 for x in items]))\n",
"\n",
"sepal_length = []\n",
"sepal_width = []\n",
"petal_length = []\n",
"petal_width = []\n",
"\n",
"for iris in iris_list:\n",
"    sepal_length.append(float(iris[0]))\n",
"    sepal_width.append(float(iris[1]))\n",
"    petal_length.append(float(iris[2]))\n",
"    petal_width.append(float(iris[3]))\n",
"\n",
"sepal_length_metrics = (mean(sepal_length), std_dev(sepal_length))\n",
"sepal_width_metrics = (mean(sepal_width), std_dev(sepal_width))\n",
"petal_length_metrics = (mean(petal_length), std_dev(petal_length))\n",
"petal_width_metrics = (mean(petal_width), std_dev(petal_width))\n",
"\n",
"print(f\"Sepal length mean: {sepal_length_metrics[0]}, std_dev: {sepal_length_metrics[1]}\")\n",
"print(f\"Sepal width mean: {sepal_width_metrics[0]}, std_dev: {sepal_width_metrics[1]}\")\n",
"print(f\"Petal length mean: {petal_length_metrics[0]}, std_dev: {petal_length_metrics[1]}\")\n",
"print(f\"Petal width mean: {petal_width_metrics[0]}, std_dev: {petal_width_metrics[1]}\")"
]
},
{
"cell_type": "markdown",
"id": "1015b464055be7bd",
"metadata": {},
"source": [
"\n",
"3. Compute and print the mean and the standard deviation for each of the 4 measurement columns, separately for each of the three Iris species (versicolor, virginica and setosa)."
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "8372b07413d00161",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:06.008490200Z",
"start_time": "2023-10-09T09:26:06.000855600Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Species found: {'Iris-versicolor', 'Iris-setosa', 'Iris-virginica'}\n",
"Metrics for species Iris-versicolor\n",
"Sepal length mean: 5.936, std_dev: 0.5109833656783751\n",
"Sepal width mean: 2.77, std_dev: 0.31064449134018135\n",
"Petal length mean: 4.26, std_dev: 0.4651881339845203\n",
"Petal width mean: 1.3259999999999998, std_dev: 0.19576516544063705\n",
"\n",
"Metrics for species Iris-setosa\n",
"Sepal length mean: 5.006, std_dev: 0.3489469873777391\n",
"Sepal width mean: 3.418, std_dev: 0.37719490982779713\n",
"Petal length mean: 1.464, std_dev: 0.17176728442867112\n",
"Petal width mean: 0.24400000000000002, std_dev: 0.10613199329137281\n",
"\n",
"Metrics for species Iris-virginica\n",
"Sepal length mean: 6.587999999999999, std_dev: 0.6294886813914926\n",
"Sepal width mean: 2.9739999999999998, std_dev: 0.3192553836664309\n",
"Petal length mean: 5.5520000000000005, std_dev: 0.546347874526844\n",
"Petal width mean: 2.026, std_dev: 0.2718896835115301\n",
"\n"
]
}
],
"source": [
"metrics_species = {}\n",
"species = set([iris[4] for iris in iris_list])\n",
"print(f\"Species found: {species}\")\n",
"\n",
"for specie in species:\n",
"    irises_filtered = filter(lambda s: s[4] == specie, iris_list)\n",
"\n",
"    sepal_length = []\n",
"    sepal_width = []\n",
"    petal_length = []\n",
"    petal_width = []\n",
"\n",
"    for iris in irises_filtered:\n",
"        sepal_length.append(float(iris[0]))\n",
"        sepal_width.append(float(iris[1]))\n",
"        petal_length.append(float(iris[2]))\n",
"        petal_width.append(float(iris[3]))\n",
"\n",
"    metrics_species[specie] = {}\n",
"\n",
"    metrics_species[specie][\"sepal_length_metrics\"] = (mean(sepal_length), std_dev(sepal_length))\n",
"    metrics_species[specie][\"sepal_width_metrics\"] = (mean(sepal_width), std_dev(sepal_width))\n",
"    metrics_species[specie][\"petal_length_metrics\"] = (mean(petal_length), std_dev(petal_length))\n",
"    metrics_species[specie][\"petal_width_metrics\"] = (mean(petal_width), std_dev(petal_width))\n",
"\n",
"    print(f\"Metrics for species {specie}\")\n",
"    print(f\"Sepal length mean: {metrics_species[specie]['sepal_length_metrics'][0]}, std_dev: {metrics_species[specie]['sepal_length_metrics'][1]}\")\n",
"    print(f\"Sepal width mean: {metrics_species[specie]['sepal_width_metrics'][0]}, std_dev: {metrics_species[specie]['sepal_width_metrics'][1]}\")\n",
"    print(f\"Petal length mean: {metrics_species[specie]['petal_length_metrics'][0]}, std_dev: {metrics_species[specie]['petal_length_metrics'][1]}\")\n",
"    print(f\"Petal width mean: {metrics_species[specie]['petal_width_metrics'][0]}, std_dev: {metrics_species[specie]['petal_width_metrics'][1]}\")\n",
"    print()"
]
},
{
"cell_type": "markdown",
"id": "a8a57e3baea8eae1",
"metadata": {},
"source": [
"\n",
"4. Based on the results of exercises 2 and 3, which of the 4 measurements would you considering as being the most characterizing one for the three species? (In other words, which measurement would you consider “best”, if you were to guess the Iris species based only on those four values?)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "aff6224527faf6ef",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:06.009315Z",
"start_time": "2023-10-09T09:26:06.003855200Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Metrics for species Iris-versicolor\n",
"Range for sepal_length_metrics: [5.425016634321625, 6.446983365678375]\n",
"Range for sepal_width_metrics: [2.4593555086598187, 3.0806444913401814]\n",
"Range for petal_length_metrics: [3.7948118660154795, 4.7251881339845205]\n",
"Range for petal_width_metrics: [1.1302348345593627, 1.521765165440637]\n",
"\n",
"Metrics for species Iris-setosa\n",
"Range for sepal_length_metrics: [4.657053012622261, 5.3549469873777396]\n",
"Range for sepal_width_metrics: [3.040805090172203, 3.7951949098277975]\n",
"Range for petal_length_metrics: [1.292232715571329, 1.635767284428671]\n",
"Range for petal_width_metrics: [0.1378680067086272, 0.35013199329137284]\n",
"\n",
"Metrics for species Iris-virginica\n",
"Range for sepal_length_metrics: [5.958511318608506, 7.217488681391492]\n",
"Range for sepal_width_metrics: [2.654744616333569, 3.2932553836664304]\n",
"Range for petal_length_metrics: [5.005652125473157, 6.098347874526844]\n",
"Range for petal_width_metrics: [1.7541103164884697, 2.2978896835115297]\n",
"\n"
]
}
],
"source": [
"for specie, content in metrics_species.items():\n",
"    print(f\"Metrics for species {specie}\")\n",
"\n",
"    for metric, values in content.items():\n",
"        print(f\"Range for {metric}: [{values[0]-values[1]}, {values[0]+values[1]}]\")\n",
"    print()\n",
"\n",
"# Petal length looks like the best measurement: its [mean - std, mean + std] ranges barely overlap across the three species"
]
},
{
"cell_type": "markdown",
"id": "be6f631cb71fb4c9",
"metadata": {},
"source": [
"\n",
"5. Based on the considerations of Exercise 3, assign the flowers with the following measurements to the species you consider most likely.\n",
"````\n",
"5.2, 3.1, 4.0, 1.2: versicolor\n",
"4.9, 2.5, 5.6, 2.0: virginica\n",
"5.4, 3.2, 1.9, 0.4: setosa\n",
"````"
]
},
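{
"cell_type": "markdown",
"id": "3b7c55e9",
"metadata": {},
"source": [
"One way to check these guesses is a minimal sketch that assigns each flower to the species with the closest mean petal length (the means are hard-coded here, copied from the Exercise 3 results):\n",
"\n",
"```python\n",
"# Mean petal lengths per species, copied from the Exercise 3 output\n",
"petal_length_means = {\n",
"    \"Iris-setosa\": 1.464,\n",
"    \"Iris-versicolor\": 4.26,\n",
"    \"Iris-virginica\": 5.552,\n",
"}\n",
"\n",
"samples = [\n",
"    (5.2, 3.1, 4.0, 1.2),\n",
"    (4.9, 2.5, 5.6, 2.0),\n",
"    (5.4, 3.2, 1.9, 0.4),\n",
"]\n",
"\n",
"for sample in samples:\n",
"    petal_length = sample[2]\n",
"    guess = min(petal_length_means,\n",
"                key=lambda s: abs(petal_length_means[s] - petal_length))\n",
"    print(sample, \"->\", guess)\n",
"```"
]
},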
{
"cell_type": "markdown",
"id": "42b9f6d1c6a4304c",
"metadata": {},
"source": [
"\n",
"6. (*) Create a rule-based classifier similar to the one seen in class. This classifier will receive some rules and classify each sample into one of the three species."
]
},
{
"cell_type": "code",
"execution_count": 116,
"id": "d9e8b272c7ff4e78",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:06.010829Z",
"start_time": "2023-10-09T09:26:06.008490200Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"'Iris-virginica'"
]
},
"execution_count": 116,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def classify_iris(row):\n",
"    petal_length = float(row[2])\n",
"    diffs = {}\n",
"\n",
"    for specie in species:\n",
"        diffs[specie] = abs(metrics_species[specie][\"petal_length_metrics\"][0] - petal_length)\n",
"\n",
"    # pick the species whose mean petal length is closest to this sample's\n",
"    return min(diffs, key=diffs.get)\n",
"\n",
"classify_iris(iris_list[120])"
]
},
{
"cell_type": "markdown",
"id": "4acda37f8fa9e27d",
"metadata": {},
"source": [
"7. (*) Compute predictions for all the elements in the dataset and store them in a list. Then, compute the accuracy of the classifier that you created. Remember that the accuracy metric is:\n",
"\n",
"$$ {\\text{number of correct predictions (TP + TN)} \\over \\text{total number of predictions (TP+TN+FP+FN)}} $$\n",
"\n",
"Where one can check whether the prediction is correct by looking at the label of the sample ($5^{th}$ column)"
]
},
{
"cell_type": "code",
"execution_count": 120,
"id": "120f43768967753c",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:06.020829300Z",
"start_time": "2023-10-09T09:26:06.010105700Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy is 0.9466666666666667\n"
]
}
],
"source": [
"correct_predictions = 0.0\n",
"\n",
"for row in iris_list:\n",
"    guess = classify_iris(row)\n",
"    if guess == row[4]:\n",
"        correct_predictions += 1\n",
"\n",
"print(f\"Accuracy is {correct_predictions / len(iris_list)}\")\n"
]
},
{
"cell_type": "markdown",
"id": "46442757",
"metadata": {},
"source": [
"### MNIST Analysis\n",
"\n",
"1. Load the previously downloaded MNIST dataset. You can make use of the csv module already presented."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "70e93e04",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:06.021830200Z",
"start_time": "2023-10-09T09:26:06.015834100Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset loaded. Number of lines: 10000\n"
]
}
],
"source": [
"# ! wget https://raw.githubusercontent.com/dbdmg/data-science-lab/master/datasets/mnist_test.csv -O mnist.csv\n",
"\n",
"import csv\n",
"\n",
"mnist_dataset = []\n",
"\n",
"with open(\"mnist.csv\") as f:\n",
"    for cols in csv.reader(f):\n",
"        mnist_dataset.append((int(cols[0]), [int(value) for value in cols[1:]]))\n",
"\n",
"print(f\"Dataset loaded. Number of lines: {len(mnist_dataset)}\")\n"
]
},
{
"cell_type": "markdown",
"id": "725ddad1ee7b1c1b",
"metadata": {},
"source": [
"2. Create a function that, given a position $1 \leq k \leq 10000$, prints the $k^{th}$ sample of the dataset (i.e. the $k^{th}$ row of the csv file) as a grid of $28 \times 28$ characters. More specifically, you should map each range of pixel values to the following characters:\n",
" - [0; 64) &rarr; \" \"\n",
" - [64; 128) &rarr; \".\"\n",
" - [128; 192) &rarr; \"*\"\n",
" - [192; 256) &rarr; \"#\"\n",
"So, for example, you should map the sequence `0, 72, 192, 138, 250` to the string `\" .#*#\"` (note the leading space produced by the 0).\n",
"*Note*: Remember to start a new line every time you read 28 characters.\n",
"\n",
"Example of output: \n",
"```\n",
" .# **\n",
" .##..*#####\n",
" #########*.\n",
" #####***.\n",
" ##*\n",
" *##\n",
" ##\n",
" .##\n",
" ###*\n",
" .#####.\n",
" *###*\n",
" *###*\n",
" ###\n",
" .##\n",
" ###\n",
" .###\n",
" . *###.\n",
" .# .*###*\n",
" .######.\n",
" *##*.\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "77672741a499b98d",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:06.055875500Z",
"start_time": "2023-10-09T09:26:06.016829800Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" *# \n",
" .### \n",
" ####* \n",
" *######. \n",
" ###*#### \n",
" .## .#### \n",
" .#. *##* \n",
" *#* ###* \n",
" .## .*#### \n",
" #####* ##* \n",
" .###* *## \n",
" .** *#* \n",
" .## \n",
" *#. \n",
" ## \n",
" .#* \n",
" ## \n",
" #* \n",
" *# \n",
" #. \n",
" \n",
" \n"
]
}
],
"source": [
"def print_k_element(dataset, k):\n",
"    assert 1 <= k <= 10000\n",
"    digit = dataset[k-1][1]\n",
"\n",
"    for i, c in enumerate(digit):\n",
"        newline = \"\"\n",
"        if i % 28 == 27:\n",
"            newline = \"\\n\"\n",
"\n",
"        if 64 <= c < 128:\n",
"            printable_char = \".\"\n",
"        elif 128 <= c < 192:\n",
"            printable_char = \"*\"\n",
"        elif 192 <= c < 256:\n",
"            printable_char = \"#\"\n",
"        else:\n",
"            printable_char = \" \"\n",
"\n",
"        print(printable_char, end=newline)\n",
"\n",
"print_k_element(mnist_dataset, 8)"
]
},
{
"cell_type": "markdown",
"id": "abc23e028f2b3979",
"metadata": {},
"source": [
"3. Compute the Euclidean distance between each pair of the 784-dimensional vectors of the digits at\n",
"the following positions: $26^{th}$, $30^{th}$, $32^{nd}$, $35^{th}$.\n",
"\n",
"*Note*: Remember that Python arrays are indexed from 0, so the $k^{th}$ value will be at position $k-1$"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bb3ddc8c6c9571d1",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:06.056876Z",
"start_time": "2023-10-09T09:26:06.019828700Z"
}
},
"outputs": [],
"source": [
"from math import sqrt\n",
"\n",
"def euclidean_distance(v1, v2):\n",
"    return sqrt(sum((a - b)**2 for a, b in zip(v1, v2)))\n",
"\n",
"values_index = [25, 29, 31, 34]\n",
"\n",
"for i in values_index:\n",
"    for j in values_index:\n",
"        if i >= j:\n",
"            continue\n",
"\n",
"        print(f\"Distance between item {i} and {j}: {euclidean_distance(mnist_dataset[i][1], mnist_dataset[j][1])}\")"
]
},
{
"cell_type": "markdown",
"id": "c2a988794dbd7294",
"metadata": {},
"source": [
"4. Based on the distances computed in the previous step and knowing that the digits listed in Exercise 3 are (not necessarily in this order) $0, 1, 1, 7$ can you assign the correct label to each of the digits of Exercise 3?"
]
},
{
"cell_type": "markdown",
"id": "4c6c2fbf5d334780",
"metadata": {},
"source": [
"Items 29 and 31 have the lowest distance, so they probably depict the same digit (\"1\"). The distances 29-34 and 31-34 are similar to each other, and lower than the distances 29-25 and 31-25. Since the digit 1 looks similar to the digit 7, item 34 is the 7, and item 25 is therefore the 0."
]
},
{
"cell_type": "markdown",
"id": "c19bf11e0d25a0af",
"metadata": {},
"source": [
"5. There are 1,135 images representing 1s and 980 images representing 0s in the dataset. For all 0s and 1s separately, count the number of times each of the 784 pixels is black (use 128 as the threshold value). You can do this by building a list `Z` and a list `O`, each with 784 elements, holding respectively the counts for the 0s and for the 1s. `Z[i]` and `O[i]` contain the number of times the $i^{th}$ pixel was black for either class. For each value $i$, compute `abs(Z[i] - O[i])`. The $i$ with the highest value represents the pixel that best separates the digits “0” and “1” (i.e. the pixel that is most often black for one class and white for the other). Where is this pixel located within the grid? Why is that?"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "d0b1ca74cb547174",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:06.056876Z",
"start_time": "2023-10-09T09:26:06.022830500Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n",
"0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n",
"0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t1\t3\t2\t1\t0\t0\t0\t0\t1\t1\t0\t0\t0\t0\t0\t0\t0\n",
"0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t2\t1\t0\t5\t9\t17\t25\t14\t9\t5\t5\t1\t3\t2\t0\t0\t0\t0\n",
"0\t0\t0\t0\t0\t0\t1\t1\t1\t7\t20\t34\t50\t59\t19\t70\t82\t47\t24\t14\t12\t0\t4\t3\t0\t0\t0\t0\n",
"0\t0\t0\t0\t0\t0\t2\t3\t9\t30\t72\t136\t185\t181\t161\t140\t181\t174\t145\t81\t61\t32\t18\t2\t5\t2\t0\t0\n",
"0\t0\t0\t0\t0\t3\t4\t15\t40\t98\t192\t295\t353\t325\t225\t185\t212\t260\t298\t258\t203\t135\t74\t18\t5\t2\t0\t0\n",
"0\t0\t0\t0\t1\t3\t11\t32\t92\t188\t327\t449\t466\t353\t199\t113\t145\t222\t346\t395\t362\t305\t170\t54\t8\t0\t0\t0\n",
"0\t0\t0\t0\t4\t5\t19\t75\t169\t309\t467\t560\t541\t355\t34\t113\t47\t94\t315\t486\t513\t411\t287\t120\t18\t1\t0\t0\n",
"0\t0\t0\t0\t3\t8\t46\t140\t251\t436\t570\t633\t540\t218\t196\t414\t313\t64\t242\t495\t590\t531\t376\t184\t35\t1\t0\t0\n",
"0\t0\t0\t0\t4\t21\t89\t210\t366\t532\t661\t646\t432\t23\t467\t696\t497\t173\t167\t458\t603\t571\t466\t244\t66\t1\t0\t0\n",
"0\t0\t0\t0\t4\t44\t150\t308\t490\t646\t670\t566\t330\t181\t748\t901\t552\t214\t125\t403\t583\t600\t511\t294\t101\t2\t0\t0\n",
"0\t0\t0\t1\t7\t78\t234\t407\t590\t670\t617\t455\t184\t440\t995\t979\t534\t129\t139\t373\t579\t615\t528\t328\t121\t7\t0\t0\n",
"0\t0\t0\t1\t17\t126\t328\t495\t655\t658\t546\t311\t20\t717\t1100\t978\t424\t53\t153\t368\t556\t622\t530\t347\t120\t6\t0\t0\n",
"0\t0\t0\t0\t30\t192\t397\t578\t671\t625\t425\t193\t181\t896\t1113\t924\t280\t5\t176\t408\t578\t620\t509\t330\t112\t7\t0\t0\n",
"0\t0\t0\t0\t39\t249\t451\t615\t678\t554\t321\t50\t385\t967\t1104\t795\t160\t70\t245\t484\t628\t609\t463\t279\t101\t6\t0\t0\n",
"0\t0\t0\t0\t57\t295\t488\t654\t665\t485\t220\t100\t532\t971\t1041\t592\t79\t138\t373\t565\t655\t575\t412\t213\t58\t3\t0\t0\n",
"0\t0\t0\t0\t84\t312\t517\t657\t642\t452\t129\t232\t603\t931\t875\t402\t35\t316\t529\t640\t641\t506\t319\t151\t33\t2\t0\t0\n",
"0\t0\t0\t1\t85\t305\t514\t650\t626\t423\t109\t282\t580\t797\t639\t201\t217\t482\t631\t656\t561\t388\t221\t72\t9\t1\t0\t0\n",
"0\t0\t0\t1\t73\t267\t470\t605\t627\t456\t121\t188\t440\t519\t320\t55\t420\t607\t664\t572\t405\t248\t123\t36\t2\t0\t0\t0\n",
"0\t0\t0\t2\t43\t201\t384\t558\t618\t492\t273\t21\t126\t179\t37\t244\t503\t603\t569\t415\t259\t132\t52\t8\t0\t0\t0\t0\n",
"0\t0\t0\t3\t22\t114\t258\t400\t506\t507\t389\t268\t170\t101\t159\t329\t458\t469\t372\t233\t124\t52\t15\t2\t0\t0\t0\t0\n",
"0\t0\t0\t0\t5\t47\t124\t221\t318\t358\t360\t351\t301\t229\t210\t244\t273\t247\t179\t105\t42\t9\t1\t0\t0\t0\t0\t0\n",
"0\t0\t0\t0\t0\t2\t21\t32\t103\t146\t178\t232\t234\t153\t60\t44\t60\t74\t52\t22\t4\t1\t1\t0\t0\t0\t0\t0\n",
"0\t0\t0\t0\t0\t3\t7\t3\t6\t6\t26\t50\t51\t19\t4\t2\t6\t5\t5\t3\t2\t1\t1\t0\t0\t0\t0\t0\n",
"0\t0\t0\t0\t0\t0\t0\t1\t2\t3\t4\t5\t1\t0\t3\t2\t1\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n",
"0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n",
"0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\t0\n",
"The pixel located at 14:14 has the maximum value of 1113\n"
]
}
],
"source": [
"Z = [0]*784\n",
"O = [0]*784\n",
"\n",
"for digit in mnist_dataset:\n",
"    if digit[0] != 0:\n",
"        continue\n",
"\n",
"    for i, value in enumerate(digit[1]):\n",
"        if value > 128:\n",
"            Z[i] += 1\n",
"\n",
"for digit in mnist_dataset:\n",
"    if digit[0] != 1:\n",
"        continue\n",
"\n",
"    for i, value in enumerate(digit[1]):\n",
"        if value > 128:\n",
"            O[i] += 1\n",
"\n",
"max_val = 0\n",
"max_pos = -1\n",
"\n",
"for i in range(784):\n",
"    endline = \"\\t\"\n",
"    if i % 28 == 27:\n",
"        endline = \"\\n\"\n",
"\n",
"    diff = abs(Z[i] - O[i])\n",
"    if diff > max_val:\n",
"        max_val = diff\n",
"        max_pos = i\n",
"\n",
"    print(diff, end=endline)\n",
"\n",
"print(f\"The pixel located at {max_pos//28}:{max_pos%28} has the maximum value of {max_val}\")"
]
},
{
"cell_type": "markdown",
"id": "b36c9eda73610bca",
"metadata": {},
"source": [
"6. (*) Extract a subset of the MNIST dataset composed of only 0 and 1 digits. Create a Rule-based classifier that take as input the rule that you discovered in ex. 5. As previously then, compute the prediction of such a classifier on all the samples in the dataset"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "e56ccd823714b00f",
"metadata": {
"ExecuteTime": {
"end_time": "2023-10-09T09:26:06.091154700Z",
"start_time": "2023-10-09T09:26:06.028359400Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy is 0.9895981087470449\n"
]
}
],
"source": [
"def classify_number(row):\n",
"    # rule from ex. 5: pixel (14, 14) -> flat index 14*28 + 14 = 406\n",
"    if row[406] > 128:\n",
"        return 1\n",
"    else:\n",
"        return 0\n",
"\n",
"\n",
"correct_predictions = 0.0\n",
"dataset_len = 0\n",
"\n",
"for row in mnist_dataset:\n",
"    if row[0] != 0 and row[0] != 1:\n",
"        continue\n",
"\n",
"    dataset_len += 1\n",
"    guess = classify_number(row[1])\n",
"    if guess == row[0]:\n",
"        correct_predictions += 1\n",
"\n",
"print(f\"Accuracy is {correct_predictions / dataset_len}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}