update to 1.0.2
This commit is contained in:
89
source/database/mysql_data_source.py
Normal file
89
source/database/mysql_data_source.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import mysql.connector
|
||||
|
||||
from .data_source_db import DataSourceDb
|
||||
|
||||
|
||||
class MySQLDataSource(DataSourceDb):
|
||||
def __init__(self, mysql_connection_info: Dict[str, str]) -> None:
|
||||
self.user = mysql_connection_info["user"]
|
||||
self.host = mysql_connection_info["host"]
|
||||
self.password = mysql_connection_info["password"]
|
||||
|
||||
def save_record(
|
||||
self,
|
||||
database: str,
|
||||
tablename: str,
|
||||
data: Dict[str, Any],
|
||||
use_obj_connection: bool = False,
|
||||
close_obj_connection: bool = True,
|
||||
):
|
||||
query, values = self.create_query(tablename, data)
|
||||
|
||||
if use_obj_connection:
|
||||
try:
|
||||
self.cursor.execute(query, values)
|
||||
|
||||
if close_obj_connection:
|
||||
self.db.commit()
|
||||
self.db.close()
|
||||
|
||||
except mysql.connector.Error:
|
||||
self.db.commit()
|
||||
self.db.close()
|
||||
raise
|
||||
|
||||
else:
|
||||
db = mysql.connector.connect(
|
||||
user=self.user,
|
||||
host=self.host,
|
||||
password=self.password,
|
||||
database=database,
|
||||
)
|
||||
try:
|
||||
cursor = db.cursor()
|
||||
cursor.execute(query, values)
|
||||
db.commit()
|
||||
db.close()
|
||||
except mysql.connector.Error:
|
||||
db.close()
|
||||
raise
|
||||
|
||||
def bulk_save(self, export_data: List[Dict[str, Any]]) -> None:
|
||||
for index, unit in enumerate(export_data):
|
||||
database, table_name, data = unit.values()
|
||||
|
||||
if (
|
||||
not hasattr(self, "db")
|
||||
or not hasattr(self, "cursor")
|
||||
or not self.db.is_connected()
|
||||
):
|
||||
self.db = mysql.connector.connect(
|
||||
user=self.user,
|
||||
host=self.host,
|
||||
password=self.password,
|
||||
database=database,
|
||||
)
|
||||
try:
|
||||
self.cursor = self.db.cursor()
|
||||
except:
|
||||
self.db.close()
|
||||
raise
|
||||
|
||||
self.save_record(
|
||||
database, table_name, data, True, index + 1 == len(export_data)
|
||||
)
|
||||
|
||||
def create_query(self, table_name: str, data: dict) -> Tuple[str, list]:
|
||||
columns = list(data.keys())
|
||||
values = list(data.values())
|
||||
|
||||
query = (
|
||||
f"REPLACE INTO {table_name} ("
|
||||
+ ", ".join(columns)
|
||||
+ ") VALUES ("
|
||||
+ ", ".join(["%s"] * len(columns))
|
||||
+ ")"
|
||||
)
|
||||
return (query, values)
|
||||
Reference in New Issue
Block a user