|
| 1 | +import psycopg2 |
| 2 | +from configparser import RawConfigParser |
| 3 | + |
| 4 | +config = RawConfigParser() |
| 5 | +config.read('config.properties') |
| 6 | + |
| 7 | + |
| 8 | +def get_conn(): |
| 9 | + if not hasattr(get_conn, 'db_details'): |
| 10 | + get_conn.db_details = dict(config.items('DB')) |
| 11 | + conn = psycopg2.connect(**get_conn.db_details) |
| 12 | + return conn |
| 13 | + |
| 14 | + |
| 15 | +def batch(iterable, size=1): |
| 16 | + for i in range(0, len(iterable), size): |
| 17 | + yield iterable[i: i + size] |
| 18 | + |
| 19 | + |
| 20 | +def insert_rows_batch(table, rows, batch_size=500, target_fields=None): |
| 21 | + """ |
| 22 | + NOTE: Handle data type for columns yourself if using this method. |
| 23 | + Otherwise look at insert_rows() |
| 24 | +
|
| 25 | + A utility method to insert batch of set of tuples into a table, |
| 26 | + a new transaction is created after every batch size |
| 27 | +
|
| 28 | + :param table: Name of the target table |
| 29 | + :type table: str |
| 30 | + :param rows: The rows to insert into the table |
| 31 | + :type rows: iterable of tuples |
| 32 | + :param batch_size: The size of batch of rows to insert at a time |
| 33 | + :type batch_size: int |
| 34 | + :param target_fields: The names of the columns to fill in the table |
| 35 | + :type target_fields: iterable of strings |
| 36 | + """ |
| 37 | + if target_fields: |
| 38 | + target_fields = ", ".join(target_fields) |
| 39 | + target_fields = "({})".format(target_fields) |
| 40 | + else: |
| 41 | + target_fields = '' |
| 42 | + |
| 43 | + conn = get_conn() |
| 44 | + cur = conn.cursor() |
| 45 | + count = 0 |
| 46 | + |
| 47 | + for mini_batch in batch(rows, batch_size): |
| 48 | + mini_batch_size = len(mini_batch) |
| 49 | + count += mini_batch_size |
| 50 | + record_template = ','.join(["%s"] * mini_batch_size) |
| 51 | + sql = "INSERT INTO {0} {1} VALUES {2};".format( |
| 52 | + table, |
| 53 | + target_fields, |
| 54 | + record_template) |
| 55 | + cur.execute(sql, mini_batch) |
| 56 | + conn.commit() |
| 57 | + print("Loaded {} rows into {} so far".format(count, table)) |
| 58 | + print("Done loading. Loaded a total of {} rows".format(count)) |
| 59 | + cur.close() |
| 60 | + conn.close() |
| 61 | + |
| 62 | + |
| 63 | +def upsert(table, pk_fields, all_fields, rows, pk_name=None, schema=None, target_fields=None, |
| 64 | + batch_size=100): |
| 65 | + """ |
| 66 | + Implements Insert + Update (UPSERT) for Postgres database. |
| 67 | + |
| 68 | + Make sure to pass pk_fields and all_fields params. |
| 69 | + |
| 70 | + NOTE: Maintain the order of the all_fields parameter as per the order of column names in database. |
| 71 | +
|
| 72 | + :param table: The name of the table |
| 73 | + :param pk_fields: A list of primary key field(s) |
| 74 | + :param all_fields: A list of all fields or column names of the database in correct order |
| 75 | + :param rows: A list of tuples of rows to be insetred or updated |
| 76 | + :param pk_name: The name of the table primary key. Don't pass it if primary key name has not been set manually, |
| 77 | + in that case will use the default primary key name as TABLE-NAME_pkey |
| 78 | + :param schema: The schema used for the table. |
| 79 | + :param target_fields: A list of all coulmn names (Optional). |
| 80 | + :param batch_size: The size of the batch to perform upsert upon, default 100 |
| 81 | + :return: None |
| 82 | + """ |
| 83 | + |
| 84 | + assert len(pk_fields) > 0 and len(all_fields) > 1 |
| 85 | + |
| 86 | + other_fields = [field for field in all_fields if field not in pk_fields] |
| 87 | + |
| 88 | + if target_fields: |
| 89 | + target_fields = ", ".join(target_fields) |
| 90 | + target_fields = "({})".format(target_fields) |
| 91 | + else: |
| 92 | + target_fields = '' |
| 93 | + |
| 94 | + if not pk_name: |
| 95 | + pk_name = table + "_pkey" |
| 96 | + |
| 97 | + if schema and '.' not in table: |
| 98 | + table = '%s.%s' % (schema, table) |
| 99 | + |
| 100 | + field_bracket = "{}" if len(other_fields) == 1 else "({})" |
| 101 | + |
| 102 | + insert_sql = "INSERT INTO {} {}".format(table, target_fields) + " VALUES {}" + \ |
| 103 | + " ON CONFLICT ON CONSTRAINT {}".format(pk_name) + \ |
| 104 | + " DO UPDATE SET " + field_bracket.format(', '.join(other_fields)) + \ |
| 105 | + " = ({}) ;".format(', '.join(['EXCLUDED.' + col for col in other_fields])) |
| 106 | + |
| 107 | + conn = get_conn() |
| 108 | + cur = conn.cursor() |
| 109 | + count = 0 |
| 110 | + |
| 111 | + for mini_batch in batch(iterable=rows, size=batch_size): |
| 112 | + mini_batch_size = len(mini_batch) |
| 113 | + record_template = ','.join(["%s"] * mini_batch_size) |
| 114 | + cur.execute(insert_sql.format(record_template), mini_batch) |
| 115 | + conn.commit() |
| 116 | + count += mini_batch_size |
| 117 | + print("Commit done on {} row(s) for UPSERT so far.".format(count)) |
| 118 | + print("Commit done on all {} rows for UPSERT.".format(len(rows))) |
| 119 | + print("UPSERT Done.") |
| 120 | + cur.close() |
| 121 | + conn.close() |
0 commit comments