forked from gsingers/search_with_machine_learning_course
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
file for students to implement level 1 task 1
- Loading branch information
1 parent
ce493ba
commit e8325fd
Showing
1 changed file
with
52 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
import argparse | ||
import xml.etree.ElementTree as ET | ||
import pandas as pd | ||
|
||
# Prepare fastText training data: read the category taxonomy and the query
# training set, attach a "__label__<category>" label to each query, and write
# one "label query" row per training example.

# Default input/output locations.
categories_file_name = r'/workspace/datasets/product_data/categories/categories_0001_abcat0010000_to_pcmcat99300050000.xml'

queries_file_name = r'/workspace/datasets/train.csv'
output_file_name = r'/workspace/datasets/fasttext/labeled_query_data.txt'

parser = argparse.ArgumentParser(description='Process arguments.')
general = parser.add_argument_group("general")
# type=int so that a value supplied on the command line is a number; without
# it argparse would hand back a string, which cannot be compared to counts.
general.add_argument("--min_queries", default=1, type=int, help="The minimum number of queries per category label (default is 1)")
general.add_argument("--output", default=output_file_name, help="the file to output to")

args = parser.parse_args()
output_file_name = args.output

# Minimum number of queries a category label must have; intended for the
# category roll-up step below.
min_queries = args.min_queries

# The root category, named Best Buy with id cat00000, doesn't have a parent.
root_category_id = 'cat00000'

tree = ET.parse(categories_file_name)
root = tree.getroot()

# Map of category ids to parent category ids.
parents = {}

# Parse the category XML file to map each category id to its parent category id.
for child in root:
    cat_path = child.find('path')
    if cat_path is None:
        # No path element, so there is nothing to derive a parent from.
        continue
    cat_path_ids = [cat.find('id').text for cat in cat_path]
    if len(cat_path_ids) < 2:
        # A path with fewer than two entries has no parent entry to record.
        continue
    # The last path entry is the category itself; the one before it is its parent.
    leaf_id = cat_path_ids[-1]
    if leaf_id != root_category_id:
        parents[leaf_id] = cat_path_ids[-2]

# Read the training data into pandas.
df = pd.read_csv(queries_file_name)[['category', 'query']]

# Create labels in fastText format.
df['label'] = '__label__' + df['category']

# IMPLEMENT ME: Trim the queries (some are quoted strings), convert them to lowercase, and optionally
# implement other normalization, like using the nltk stemmer.

# IMPLEMENT ME: Roll up categories to ancestors to satisfy the minimum number of queries per category.

# Output labeled query data as a space-separated file.
# NOTE(review): with pandas' default quoting, a query that contains the
# separator (a space) is written wrapped in double quotes -- confirm this is
# acceptable for the downstream fastText training step.
df[['label', 'query']].to_csv(output_file_name, header=False, sep=' ', index=False)