A Python script that converts CFR (InfoFer) XML files into a SQLite database.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

267 lines
7.7 KiB

3 years ago
#! /usr/bin/env python3
import xml.etree.ElementTree as ET
from itertools import takewhile
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable=None, *args, **kwargs):
        """Fallback used when tqdm is not installed: return *iterable* unchanged.

        The first parameter is named and defaulted like ``tqdm.tqdm``'s, so a
        call that passes ``iterable=`` by keyword behaves the same with or
        without tqdm installed. All progress-bar options are accepted and
        ignored.
        """
        return iterable
def get_database_connection(path='trains.sqlite'):
    """Open the SQLite database the importer writes to, creating it if needed.

    :param path: database file; by default ``trains.sqlite``, resolved against
        the current working directory. ``':memory:'`` opens a throwaway
        in-memory database.
    :return: an open ``sqlite3.Connection``.
    """
    import sqlite3
    return sqlite3.connect(path)
def create_tables(con):
    """Create every table the importer uses, leaving existing tables untouched.

    Each statement uses ``create table if not exists``, so calling this on a
    database left by an earlier run is a no-op for the tables it already has.
    SQLite itself resolves the name check, including its case-insensitive
    matching of table names. Changes are committed before returning.

    :param con: open sqlite3 connection.
    """
    statements = (
        'create table if not exists Meta (Versiune int)',
        'create table if not exists Companii (Id integer primary key, NumeLegal text, NumeComun text)',
        'create table if not exists Trenuri (Number integer primary key, IdCompanie int, CategorieTren, KmCum int, Lungime int, Numar, Operator, Proprietar, Putere, Rang, Servicii, Tonaj)',
        'create table if not exists Trase (NumarTren int, Id int, Tip, CodStatieInitiala int, CodStatieFinala int)',
        'create table if not exists ElementeTrasa (NumarTren int, IdTrasa int, Secventa int, Ajustari, CodStaDest int, CodStaOrigine int, DenStaDestinatie, DenStaOrigine, Km int, Lungime int, OraP int, OraS int, Rci, Rco, Restrictie, StationareSecunde int, TipOprire, Tonaj, VitezaLivret int)',
        'create table if not exists Statii (CodStatie integer primary key, Denumire text)',
    )
    cursor = con.cursor()
    for statement in statements:
        cursor.execute(statement)
    con.commit()
def insert(con, table, *args, _commit=True, **kargs):
    """Insert a single row into *table*.

    Values are given either positionally, one per column in table order, or
    as ``column=value`` keyword arguments. Exactly one of the two forms must
    be used. Values are bound as SQL parameters, but *table* and the keyword
    names are pasted into the SQL text, so they must be trusted identifiers.

    :param con: open sqlite3 connection.
    :param table: destination table name.
    :param _commit: commit right away (default); pass ``False`` to batch
        several inserts into the caller's transaction.
    :raises ValueError: if both forms or neither form is used (ValueError is
        an ``Exception`` subclass, so broader existing handlers still match).
    """
    if bool(args) == bool(kargs):
        raise ValueError('Provide args XOR kargs')
    if args:
        columns, values = '', args
    else:
        columns, values = '(' + ', '.join(kargs) + ') ', tuple(kargs.values())
    placeholders = ', '.join(['?'] * len(values))
    cursor = con.cursor()
    cursor.execute(f'insert into {table} {columns}values ({placeholders})', values)
    if _commit:
        con.commit()
def get_data_folder():
    """Return the relative path of the folder with the input XML files and ``mapping.json``."""
    from os.path import join
    return join('.', 'datafiles')
def get_xml_files(data_folder=None):
    """Yield the path of every regular ``.xml`` file directly inside the data folder.

    Paths come out in sorted order. ``os.listdir`` returns entries in an
    arbitrary, platform-dependent order, and the importer skips train numbers
    that an earlier file already inserted. A fixed order therefore keeps the
    resulting database reproducible.

    :param data_folder: directory to scan; defaults to ``get_data_folder()``.
    """
    import os
    if data_folder is None:
        data_folder = get_data_folder()
    for entry in sorted(os.listdir(data_folder)):
        path = os.path.join(data_folder, entry)
        # Only regular files; a directory named "*.xml" is ignored.
        if os.path.isfile(path) and os.path.splitext(path)[1] == '.xml':
            yield path
def get_mappings(data_folder=None):
    """Load the file-to-company mappings from ``mapping.json`` in the data folder.

    The file is read as ``{"mappings": [...]}`` and the list is returned
    unchanged. The importer looks entries up by their ``filename``,
    ``legalName`` and ``commonName`` keys. Mappings are optional: a missing,
    unreadable or malformed file yields an empty list.

    :param data_folder: directory holding ``mapping.json``; defaults to
        ``get_data_folder()``.
    """
    from os.path import join
    import json
    if data_folder is None:
        data_folder = get_data_folder()
    mappings_file = join(data_folder, 'mapping.json')
    try:
        # JSON text is UTF-8; don't depend on the platform's locale encoding
        # (company names contain Romanian diacritics).
        with open(mappings_file, encoding='utf-8') as f:
            return json.load(f)['mappings']
    except (OSError, ValueError, KeyError, TypeError):
        # OSError: missing/unreadable file. ValueError: bad UTF-8 or invalid
        # JSON (JSONDecodeError/UnicodeDecodeError). KeyError/TypeError: the
        # top level is not an object with a "mappings" key.
        return []
def train_number_stoi(s):
    """Return the integer formed by the leading decimal digits of *s*.

    Everything from the first non-digit on is ignored, e.g. ``'1234-1'`` ->
    1234. ``str.isdecimal`` is used rather than ``str.isnumeric`` because it
    accepts exactly the characters ``int()`` can parse. ``isnumeric`` also
    admits characters such as ``'½'`` or ``'²'`` that make ``int()`` fail.

    :raises ValueError: if *s* does not start with a digit.
    """
    digits = ''.join(takewhile(str.isdecimal, s))
    if not digits:
        raise ValueError(f'train number {s!r} does not start with a digit')
    return int(digits)
def find_trains(con):
    """Return the set of train numbers (``Trenuri.Number``) already stored."""
    rows = con.cursor().execute('select Number from Trenuri;').fetchall()
    return {number for (number,) in rows}
def find_trase(con, train_number=None):
    """List the ``(NumarTren, Id)`` keys in ``Trase``, for one train or for all of them."""
    cur = con.cursor()
    if train_number is not None:
        cur.execute('select NumarTren, Id from Trase where NumarTren = ?', (train_number,))
    else:
        cur.execute('select NumarTren, Id from Trase')
    pairs = []
    for numar, trasa_id in cur.fetchall():
        pairs.append((numar, trasa_id))
    return pairs
def find_secvente(con, train_number=None, id_trasa=None):
    """List ``(NumarTren, IdTrasa, Secventa)`` keys from ``ElementeTrasa``.

    With no filter every row is listed. *train_number* narrows the result to
    one train, and *id_trasa* further narrows it to one route of that train.
    *id_trasa* is ignored when *train_number* is ``None``.
    """
    base = 'select NumarTren, IdTrasa, Secventa from ElementeTrasa'
    if train_number is None:
        sql, params = base, ()
    elif id_trasa is None:
        sql, params = base + ' where NumarTren = ?', (train_number,)
    else:
        sql, params = base + ' where NumarTren = ? and IdTrasa = ?', (train_number, id_trasa)
    cur = con.cursor()
    cur.execute(sql, params)
    return [(numar, trasa, secventa) for numar, trasa, secventa in cur.fetchall()]
def find_station_ids(con):
    """Return the set of station codes (``Statii.CodStatie``) already stored."""
    cursor = con.cursor()
    cursor.execute('select CodStatie from Statii;')
    codes = set()
    for (code,) in cursor.fetchall():
        codes.add(code)
    return codes
def find_companies(con):
    """Return every ``(Id, NumeLegal, NumeComun)`` row of ``Companii`` as a list."""
    return con.cursor().execute('select Id, NumeLegal, NumeComun from Companii').fetchall()
def main():
    """Import every XML file from the data folder into ``trains.sqlite``.

    For each file, the owning company is looked up in ``mapping.json`` by file
    name; an unmapped file gets the anonymous ``(None, None)`` company. Every
    ``<Tren>`` whose number is not stored yet is inserted together with its
    routes (``Trase``), its route elements (``ElementeTrasa``) and any
    station code seen for the first time. Trains already in the database,
    from a previous run or an earlier file, are skipped, so re-running only
    adds new data. Each file is committed as one transaction.
    """
    from os.path import basename

    con = get_database_connection()
    try:
        create_tables(con)
        _record_schema_version(con, 2)
        station_ids = find_station_ids(con)
        companies = find_companies(con)
        mappings = get_mappings()
        trains = find_trains(con)
        # Load the existing keys once. Trase/ElementeTrasa have no index, so
        # querying them per train and per route would rescan the growing
        # tables each time (quadratic overall).
        known_trase = set(find_trase(con))
        known_secvente = set(find_secvente(con))

        for path in get_xml_files():
            legal_name, common_name = _company_names(mappings, basename(path))
            company_id, companies = _get_company_id(con, companies, legal_name, common_name)

            # ElementTree.find() rewrites a leading '/' to './' anyway but
            # emits a FutureWarning, so spell the relative path directly.
            el_trenuri = ET.parse(path).find('./XmlMts/Mt/Trenuri')
            if el_trenuri is None:
                print(f'Skipping {path}: no XmlMts/Mt/Trenuri element')
                continue

            print(f'Adding {common_name or path}...')
            for el_tren in tqdm(el_trenuri.findall('./Tren')):
                train_number_str = el_tren.attrib['Numar']
                train_number = train_number_stoi(train_number_str)
                if train_number in trains:
                    continue
                # Record it immediately: a second <Tren> in the same file that
                # maps to the same number would otherwise be inserted again
                # and violate the Trenuri primary key.
                trains.add(train_number)
                _insert_train(con, el_tren, train_number, company_id,
                              known_trase, known_secvente, station_ids)
            con.commit()
    finally:
        con.close()


# XML attributes copied verbatim into the columns that follow each table's
# key columns, listed in table column order.
_TREN_ATTRS = ('CategorieTren', 'KmCum', 'Lungime', 'Numar', 'Operator',
               'Proprietar', 'Putere', 'Rang', 'Servicii', 'Tonaj')
_TRASA_ATTRS = ('Tip', 'CodStatieInitiala', 'CodStatieFinala')
_ELEMENT_TRASA_ATTRS = (
    'Ajustari', 'CodStaDest', 'CodStaOrigine', 'DenStaDestinatie',
    'DenStaOrigine', 'Km', 'Lungime', 'OraP', 'OraS', 'Rci', 'Rco',
    'Restrictie', 'StationareSecunde', 'TipOprire', 'Tonaj', 'VitezaLivret',
)


def _record_schema_version(con, version):
    """Insert *version* into ``Meta`` unless that version is already recorded."""
    row = con.execute('select 1 from Meta where Versiune = ?', (version,)).fetchone()
    if row is None:
        insert(con, 'Meta', version)


def _company_names(mappings, filename):
    """Return ``(legalName, commonName)`` of the mapping for *filename*, or ``(None, None)``."""
    for mapping in mappings:
        try:
            if mapping['filename'] == filename:
                return mapping['legalName'], mapping['commonName']
        except (KeyError, TypeError):
            # Malformed entry (not a dict, or missing keys): ignore it and
            # keep looking instead of failing the whole import.
            continue
    return None, None


def _get_company_id(con, companies, legal_name, common_name):
    """Return ``(company_id, companies)``, inserting the company first if it is new.

    *companies* is the cached list of ``(Id, NumeLegal, NumeComun)`` rows.
    After an insert it is re-read so the new row's Id is known, and the
    possibly refreshed list is returned for the caller to reuse.
    """
    def match(rows):
        return next((cid for cid, legal, common in rows
                     if legal == legal_name and common == common_name), None)

    company_id = match(companies)
    if company_id is None:
        insert(con, 'Companii', NumeLegal=legal_name, NumeComun=common_name)
        companies = find_companies(con)
        company_id = match(companies)
    return company_id, companies


def _insert_train(con, el_tren, train_number, company_id,
                  known_trase, known_secvente, station_ids):
    """Insert one ``<Tren>`` with its routes, route elements and new stations (no commit).

    *known_trase*, *known_secvente* and *station_ids* hold keys already in
    the database. Rows whose key is present are skipped, and every key
    inserted here is added so later elements see it.
    """
    insert(con, 'Trenuri', train_number, company_id,
           *(el_tren.attrib[name] for name in _TREN_ATTRS), _commit=False)
    for el_trasa in el_tren.findall('./Trase/Trasa'):
        id_trasa = int(el_trasa.attrib['Id'])
        if (train_number, id_trasa) in known_trase:
            continue
        known_trase.add((train_number, id_trasa))
        insert(con, 'Trase', train_number, id_trasa,
               *(el_trasa.attrib[name] for name in _TRASA_ATTRS), _commit=False)
        for el_element in el_trasa.findall('./ElementTrasa'):
            attrs = el_element.attrib
            secventa = int(attrs['Secventa'])
            key = (train_number, id_trasa, secventa)
            if key in known_secvente:
                continue
            known_secvente.add(key)
            insert(con, 'ElementeTrasa', train_number, id_trasa, secventa,
                   *(attrs[name] for name in _ELEMENT_TRASA_ATTRS), _commit=False)
            _add_station(con, station_ids, attrs['CodStaOrigine'], attrs['DenStaOrigine'])
            _add_station(con, station_ids, attrs['CodStaDest'], attrs['DenStaDestinatie'])


def _add_station(con, station_ids, code, name):
    """Insert station *code*/*name* into ``Statii`` if the code is numeric and unseen (no commit)."""
    # isdecimal() rather than isnumeric(): it accepts exactly what int() parses.
    if not code.isdecimal():
        return
    code_number = int(code)
    if code_number not in station_ids:
        station_ids.add(code_number)
        insert(con, 'Statii', code_number, name, _commit=False)
if __name__ == '__main__':
    # Run the import only when executed as a script, not when imported.
    main()